test_question_cmd: added a new test case and made the old cases more explicit (easier to read)

This commit is contained in:
juk0de 2023-09-24 08:53:37 +02:00
parent 87b25993be
commit 601ebe731a

View File

@ -7,6 +7,7 @@ from unittest import mock
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AIError from chatmastermind.ai import AIError
@ -343,15 +344,14 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
# since the message's answer is modified, we use a copy here Answer('Answer 0'),
# -> the original is used for comparison below ai=message.ai,
expected_response = fake_ai.request(copy(message), model=message.model,
Chat([]), tags=message.tags,
self.args.num_answers, file_path=Path('<NOT COMPARED>'))
set(self.args.output_tags)).messages
# we expect the original message + the one with the new response # we expect the original message + the one with the new response
expected_responses = [message] + expected_response expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@ -381,19 +381,20 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
# repeat the last question (WITH overwriting) # repeat the last question (WITH overwriting)
# -> expect a single message afterwards # -> expect a single message afterwards (with a new answer)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = True self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
set(self.args.output_tags)).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -424,15 +425,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config) expected_response = Message(Question(message.question),
expected_response = fake_ai.request(message, Answer('Answer 0'),
Chat([]), ai=message.ai,
self.args.num_answers, model=message.model,
self.args.output_tags).messages tags=message.tags,
file_path=Path('<NOT COMPARED>'))
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_msgs_equal_except_file_path(cached_msg, expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -457,26 +459,60 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
assert cached_msg[0].file_path assert cached_msg[0].file_path
# repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata # -> expect two messages with identical question but different metadata and new answer
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
self.args.output_tags = ['newtag'] self.args.output_tags = ['newtag']
self.args.AI = 'newai' self.args.AI = 'newai'
self.args.model = 'newmodel' self.args.model = 'newmodel'
new_expected_question = Message(question=Question(message.question), new_expected_response = Message(Question(message.question),
tags=set(self.args.output_tags), Answer('Answer 0'),
ai=self.args.AI, ai='newai',
model=self.args.model) model='newmodel',
fake_ai = self.mock_create_ai(self.args, self.config) tags={Tag('newtag')},
new_expected_response = fake_ai.request(new_expected_question, file_path=Path('<NOT COMPARED>'))
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_msgs_equal_except_file_path(cached_msg, [message] + new_expected_response) self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
    """
    Repeat a single question with new arguments, overwriting the old one.

    Afterwards the cache must still hold exactly one message whose metadata
    (AI, model, tags) reflects the new arguments and whose answer has been
    replaced (old 'Old Answer' -> new 'Answer 0').
    """
    # Route AI construction through the test's fake-AI factory so that no
    # real backend is contacted.
    mock_create_ai.side_effect = self.mock_create_ai
    # create a message
    # (seed the cache with one already-answered message and persist it so
    # question_cmd can find it)
    message = Message(Question(self.args.ask[0]),
                      Answer('Old Answer'),
                      tags=set(self.args.output_tags),
                      ai=self.args.AI,
                      model=self.args.model,
                      file_path=Path(self.cache_dir.name) / '0001.txt')
    message.to_file()
    chat = ChatDB.from_dir(Path(self.cache_dir.name),
                           Path(self.db_dir.name))
    cached_msg = chat.msg_gather(loc='cache')
    assert cached_msg[0].file_path
    # repeat the last question with new arguments
    # NOTE(review): ask=None + repeat=[] appears to select "repeat the last
    # question" (matches the sibling tests' comments); overwrite=True should
    # replace the cached message instead of appending a new one.
    self.args.ask = None
    self.args.repeat = []
    self.args.overwrite = True
    self.args.output_tags = ['newtag']
    self.args.AI = 'newai'
    self.args.model = 'newmodel'
    # 'Answer 0' is presumably the fake AI's canned first answer — confirm
    # against mock_create_ai. The '<NOT COMPARED>' file_path placeholder is
    # ignored by assert_msgs_equal_except_file_path.
    new_expected_response = Message(Question(message.question),
                                    Answer('Answer 0'),
                                    ai='newai',
                                    model='newmodel',
                                    tags={Tag('newtag')},
                                    file_path=Path('<NOT COMPARED>'))
    question_cmd(self.args, self.config)
    cached_msg = chat.msg_gather(loc='cache')
    # overwrite=True -> still exactly one message on disk, with the new content
    self.assertEqual(len(self.message_list(self.cache_dir)), 1)
    self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: