Compare commits

..

2 Commits

2 changed files with 93 additions and 38 deletions

View File

@ -62,15 +62,31 @@ class TestWithFakeAI(unittest.TestCase):
""" """
Base class for all tests that need to use the FakeAI. Base class for all tests that need to use the FakeAI.
""" """
def assert_messages_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None: def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
""" """
Compare messages using more than just Question and Answer. Compare messages using Question, Answer and all metadata excecot for the file_path.
""" """
self.assertEqual(len(msg1), len(msg2)) self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2): for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata # exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using Question, Answer and ALL metadata.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertTrue(m1.equals(m2, verbose=True))
def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using only Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
self.assertEqual(m1, m2)
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI: def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
""" """
Mocked 'create_ai' that returns a 'FakeAI' instance. Mocked 'create_ai' that returns a 'FakeAI' instance.

View File

@ -7,6 +7,7 @@ from unittest import mock
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AIError from chatmastermind.ai import AIError
@ -260,7 +261,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
@ -318,7 +319,7 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal_except_file_path(cached_msg, [expected_question]) self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd): class TestQuestionCmdRepeat(TestQuestionCmd):
@ -343,22 +344,21 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
# since the message's answer is modified, we use a copy here Answer('Answer 0'),
# -> the original is used for comparison below ai=message.ai,
expected_response = fake_ai.request(copy(message), model=message.model,
Chat([]), tags=message.tags,
self.args.num_answers, file_path=Path('<NOT COMPARED>'))
set(self.args.output_tags)).messages
# we expect the original message + the one with the new response # we expect the original message + the one with the new response
expected_responses = [message] + expected_response expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir)) print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal_except_file_path(cached_msg, expected_responses) self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None: def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
@ -381,19 +381,20 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting) # repeat the last question (WITH overwriting)
# -> expect a single message afterwards # -> expect a single message afterwards (with a new answer)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = True self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
set(self.args.output_tags)).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal_except_file_path(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -424,15 +425,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
self.args.output_tags).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal_except_file_path(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -457,26 +459,60 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata # -> expect two messages with identical question but different metadata and new answer
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
self.args.output_tags = ['newtag'] self.args.output_tags = ['newtag']
self.args.AI = 'newai' self.args.AI = 'newai'
self.args.model = 'newmodel' self.args.model = 'newmodel'
new_expected_question = Message(question=Question(message.question), new_expected_response = Message(Question(message.question),
tags=set(self.args.output_tags), Answer('Answer 0'),
ai=self.args.AI, ai='newai',
model=self.args.model) model='newmodel',
fake_ai = self.mock_create_ai(self.args, self.config) tags={Tag('newtag')},
new_expected_response = fake_ai.request(new_expected_question, file_path=Path('<NOT COMPARED>'))
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal_except_file_path(cached_msg, [message] + new_expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
# repeat the last question with new arguments
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_response = Message(Question(message.question),
Answer('Answer 0'),
ai='newai',
model='newmodel',
tags={Tag('newtag')},
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
@ -533,4 +569,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assert_messages_equal_except_file_path(cached_msg, expected_cache_messages) self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all
db_msg = chat.msg_gather(loc='db')
self.assert_msgs_all_equal(db_msg, [message3])