openai: added test module
This commit is contained in:
parent
a7345cbc41
commit
1b080eade2
82
tests/test_ais_openai.py
Normal file
82
tests/test_ais_openai.py
Normal file
@ -0,0 +1,82 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from chatmastermind.ais.openai import OpenAI
|
||||
from chatmastermind.message import Message, Question, Answer
|
||||
from chatmastermind.chat import Chat
|
||||
from chatmastermind.ai import AIResponse, Tokens
|
||||
from chatmastermind.configuration import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAITest(unittest.TestCase):
|
||||
|
||||
@mock.patch('openai.ChatCompletion.create')
|
||||
def test_request(self, mock_create: mock.MagicMock) -> None:
|
||||
# Create a test instance of OpenAI
|
||||
config = OpenAIConfig()
|
||||
openai = OpenAI(config)
|
||||
|
||||
# Set up the mock response from openai.ChatCompletion.create
|
||||
mock_response = {
|
||||
'choices': [
|
||||
{
|
||||
'message': {
|
||||
'content': 'Answer 1'
|
||||
}
|
||||
},
|
||||
{
|
||||
'message': {
|
||||
'content': 'Answer 2'
|
||||
}
|
||||
}
|
||||
],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 20,
|
||||
'total_tokens': 30
|
||||
}
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
question = Message(Question('Question'))
|
||||
chat = Chat([
|
||||
Message(Question('Question 1'), answer=Answer('Answer 1')),
|
||||
Message(Question('Question 2'), answer=Answer('Answer 2')),
|
||||
Message(Question('Question 3'), answer=Answer('Answer 3'))
|
||||
])
|
||||
|
||||
# Make the request
|
||||
response = openai.request(question, chat, num_answers=2)
|
||||
|
||||
# Assert the AIResponse
|
||||
self.assertIsInstance(response, AIResponse)
|
||||
self.assertEqual(len(response.messages), 2)
|
||||
self.assertEqual(response.messages[0].answer, 'Answer 1')
|
||||
self.assertEqual(response.messages[1].answer, 'Answer 2')
|
||||
self.assertIsNotNone(response.tokens)
|
||||
self.assertIsInstance(response.tokens, Tokens)
|
||||
assert response.tokens
|
||||
self.assertEqual(response.tokens.prompt, 10)
|
||||
self.assertEqual(response.tokens.completion, 20)
|
||||
self.assertEqual(response.tokens.total, 30)
|
||||
|
||||
# Assert the mock call to openai.ChatCompletion.create
|
||||
mock_create.assert_called_once_with(
|
||||
model=f'{config.model}',
|
||||
messages=[
|
||||
{'role': 'system', 'content': f'{config.system}'},
|
||||
{'role': 'user', 'content': 'Question 1'},
|
||||
{'role': 'assistant', 'content': 'Answer 1'},
|
||||
{'role': 'user', 'content': 'Question 2'},
|
||||
{'role': 'assistant', 'content': 'Answer 2'},
|
||||
{'role': 'user', 'content': 'Question 3'},
|
||||
{'role': 'assistant', 'content': 'Answer 3'},
|
||||
{'role': 'user', 'content': 'Question'}
|
||||
],
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
top_p=config.top_p,
|
||||
n=2,
|
||||
frequency_penalty=config.frequency_penalty,
|
||||
presence_penalty=config.presence_penalty
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user