Compare commits

..

2 Commits

Author SHA1 Message Date
1b080eade2 openai: added test module 2023-09-13 08:49:06 +02:00
a7345cbc41 ai_factory: fixed argument parsing bug 2023-09-13 07:52:05 +02:00
2 changed files with 86 additions and 4 deletions

View File

@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if 'AI' in args and args.AI:
if hasattr(args, 'AI') and args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if 'model' in args and args.model:
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
if 'max_tokens' in args and args.max_tokens:
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if 'temperature' in args and args.temperature:
if hasattr(args, 'temperature') and args.temperature:
ai.config.temperature = args.temperature
return ai
else:

82
tests/test_ais_openai.py Normal file
View File

@ -0,0 +1,82 @@
import unittest
from unittest import mock
from chatmastermind.ais.openai import OpenAI
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AIResponse, Tokens
from chatmastermind.configuration import OpenAIConfig
class OpenAITest(unittest.TestCase):
@mock.patch('openai.ChatCompletion.create')
def test_request(self, mock_create: mock.MagicMock) -> None:
# Create a test instance of OpenAI
config = OpenAIConfig()
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_response = {
'choices': [
{
'message': {
'content': 'Answer 1'
}
},
{
'message': {
'content': 'Answer 2'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
}
}
mock_create.return_value = mock_response
# Create test data
question = Message(Question('Question'))
chat = Chat([
Message(Question('Question 1'), answer=Answer('Answer 1')),
Message(Question('Question 2'), answer=Answer('Answer 2')),
Message(Question('Question 3'), answer=Answer('Answer 3'))
])
# Make the request
response = openai.request(question, chat, num_answers=2)
# Assert the AIResponse
self.assertIsInstance(response, AIResponse)
self.assertEqual(len(response.messages), 2)
self.assertEqual(response.messages[0].answer, 'Answer 1')
self.assertEqual(response.messages[1].answer, 'Answer 2')
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 10)
self.assertEqual(response.tokens.completion, 20)
self.assertEqual(response.tokens.total, 30)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
model=f'{config.model}',
messages=[
{'role': 'system', 'content': f'{config.system}'},
{'role': 'user', 'content': 'Question 1'},
{'role': 'assistant', 'content': 'Answer 1'},
{'role': 'user', 'content': 'Question 2'},
{'role': 'assistant', 'content': 'Answer 2'},
{'role': 'user', 'content': 'Question 3'},
{'role': 'assistant', 'content': 'Answer 3'},
{'role': 'user', 'content': 'Question'}
],
temperature=config.temperature,
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)