diff --git a/tests/test_ais_openai.py b/tests/test_ais_openai.py new file mode 100644 index 0000000..7c903d5 --- /dev/null +++ b/tests/test_ais_openai.py @@ -0,0 +1,82 @@ +import unittest +from unittest import mock +from chatmastermind.ais.openai import OpenAI +from chatmastermind.message import Message, Question, Answer +from chatmastermind.chat import Chat +from chatmastermind.ai import AIResponse, Tokens +from chatmastermind.configuration import OpenAIConfig + + +class OpenAITest(unittest.TestCase): + + @mock.patch('openai.ChatCompletion.create') + def test_request(self, mock_create: mock.MagicMock) -> None: + # Create a test instance of OpenAI + config = OpenAIConfig() + openai = OpenAI(config) + + # Set up the mock response from openai.ChatCompletion.create + mock_response = { + 'choices': [ + { + 'message': { + 'content': 'Answer 1' + } + }, + { + 'message': { + 'content': 'Answer 2' + } + } + ], + 'usage': { + 'prompt_tokens': 10, + 'completion_tokens': 20, + 'total_tokens': 30 + } + } + mock_create.return_value = mock_response + + # Create test data + question = Message(Question('Question')) + chat = Chat([ + Message(Question('Question 1'), answer=Answer('Answer 1')), + Message(Question('Question 2'), answer=Answer('Answer 2')), + Message(Question('Question 3'), answer=Answer('Answer 3')) + ]) + + # Make the request + response = openai.request(question, chat, num_answers=2) + + # Assert the AIResponse + self.assertIsInstance(response, AIResponse) + self.assertEqual(len(response.messages), 2) + self.assertEqual(response.messages[0].answer, 'Answer 1') + self.assertEqual(response.messages[1].answer, 'Answer 2') + self.assertIsNotNone(response.tokens) + self.assertIsInstance(response.tokens, Tokens) + assert response.tokens + self.assertEqual(response.tokens.prompt, 10) + self.assertEqual(response.tokens.completion, 20) + self.assertEqual(response.tokens.total, 30) + + # Assert the mock call to openai.ChatCompletion.create + mock_create.assert_called_once_with( + model=f'{config.model}', + messages=[ + {'role': 'system', 'content': f'{config.system}'}, + {'role': 'user', 'content': 'Question 1'}, + {'role': 'assistant', 'content': 'Answer 1'}, + {'role': 'user', 'content': 'Question 2'}, + {'role': 'assistant', 'content': 'Answer 2'}, + {'role': 'user', 'content': 'Question 3'}, + {'role': 'assistant', 'content': 'Answer 3'}, + {'role': 'user', 'content': 'Question'} + ], + temperature=config.temperature, + max_tokens=config.max_tokens, + top_p=config.top_p, + n=2, + frequency_penalty=config.frequency_penalty, + presence_penalty=config.presence_penalty + )