Compare commits
No commits in common. "1b080eade2180bf3958a6fade326d270df35ea7f" and "310cb9421ed0e6fa841ce41567208fb438838000" have entirely different histories.
1b080eade2
...
310cb9421e
@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
|||||||
is not found, it uses the first AI in the list.
|
is not found, it uses the first AI in the list.
|
||||||
"""
|
"""
|
||||||
ai_conf: AIConfig
|
ai_conf: AIConfig
|
||||||
if hasattr(args, 'AI') and args.AI:
|
if 'AI' in args and args.AI:
|
||||||
try:
|
try:
|
||||||
ai_conf = config.ais[args.AI]
|
ai_conf = config.ais[args.AI]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
|||||||
|
|
||||||
if ai_conf.name == 'openai':
|
if ai_conf.name == 'openai':
|
||||||
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
||||||
if hasattr(args, 'model') and args.model:
|
if 'model' in args and args.model:
|
||||||
ai.config.model = args.model
|
ai.config.model = args.model
|
||||||
if hasattr(args, 'max_tokens') and args.max_tokens:
|
if 'max_tokens' in args and args.max_tokens:
|
||||||
ai.config.max_tokens = args.max_tokens
|
ai.config.max_tokens = args.max_tokens
|
||||||
if hasattr(args, 'temperature') and args.temperature:
|
if 'temperature' in args and args.temperature:
|
||||||
ai.config.temperature = args.temperature
|
ai.config.temperature = args.temperature
|
||||||
return ai
|
return ai
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,82 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest import mock
|
|
||||||
from chatmastermind.ais.openai import OpenAI
|
|
||||||
from chatmastermind.message import Message, Question, Answer
|
|
||||||
from chatmastermind.chat import Chat
|
|
||||||
from chatmastermind.ai import AIResponse, Tokens
|
|
||||||
from chatmastermind.configuration import OpenAIConfig
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAITest(unittest.TestCase):
|
|
||||||
|
|
||||||
@mock.patch('openai.ChatCompletion.create')
|
|
||||||
def test_request(self, mock_create: mock.MagicMock) -> None:
|
|
||||||
# Create a test instance of OpenAI
|
|
||||||
config = OpenAIConfig()
|
|
||||||
openai = OpenAI(config)
|
|
||||||
|
|
||||||
# Set up the mock response from openai.ChatCompletion.create
|
|
||||||
mock_response = {
|
|
||||||
'choices': [
|
|
||||||
{
|
|
||||||
'message': {
|
|
||||||
'content': 'Answer 1'
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'message': {
|
|
||||||
'content': 'Answer 2'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
'usage': {
|
|
||||||
'prompt_tokens': 10,
|
|
||||||
'completion_tokens': 20,
|
|
||||||
'total_tokens': 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mock_create.return_value = mock_response
|
|
||||||
|
|
||||||
# Create test data
|
|
||||||
question = Message(Question('Question'))
|
|
||||||
chat = Chat([
|
|
||||||
Message(Question('Question 1'), answer=Answer('Answer 1')),
|
|
||||||
Message(Question('Question 2'), answer=Answer('Answer 2')),
|
|
||||||
Message(Question('Question 3'), answer=Answer('Answer 3'))
|
|
||||||
])
|
|
||||||
|
|
||||||
# Make the request
|
|
||||||
response = openai.request(question, chat, num_answers=2)
|
|
||||||
|
|
||||||
# Assert the AIResponse
|
|
||||||
self.assertIsInstance(response, AIResponse)
|
|
||||||
self.assertEqual(len(response.messages), 2)
|
|
||||||
self.assertEqual(response.messages[0].answer, 'Answer 1')
|
|
||||||
self.assertEqual(response.messages[1].answer, 'Answer 2')
|
|
||||||
self.assertIsNotNone(response.tokens)
|
|
||||||
self.assertIsInstance(response.tokens, Tokens)
|
|
||||||
assert response.tokens
|
|
||||||
self.assertEqual(response.tokens.prompt, 10)
|
|
||||||
self.assertEqual(response.tokens.completion, 20)
|
|
||||||
self.assertEqual(response.tokens.total, 30)
|
|
||||||
|
|
||||||
# Assert the mock call to openai.ChatCompletion.create
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
model=f'{config.model}',
|
|
||||||
messages=[
|
|
||||||
{'role': 'system', 'content': f'{config.system}'},
|
|
||||||
{'role': 'user', 'content': 'Question 1'},
|
|
||||||
{'role': 'assistant', 'content': 'Answer 1'},
|
|
||||||
{'role': 'user', 'content': 'Question 2'},
|
|
||||||
{'role': 'assistant', 'content': 'Answer 2'},
|
|
||||||
{'role': 'user', 'content': 'Question 3'},
|
|
||||||
{'role': 'assistant', 'content': 'Answer 3'},
|
|
||||||
{'role': 'user', 'content': 'Question'}
|
|
||||||
],
|
|
||||||
temperature=config.temperature,
|
|
||||||
max_tokens=config.max_tokens,
|
|
||||||
top_p=config.top_p,
|
|
||||||
n=2,
|
|
||||||
frequency_penalty=config.frequency_penalty,
|
|
||||||
presence_penalty=config.presence_penalty
|
|
||||||
)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user