diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index b809567..89c72c7 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -285,7 +285,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: """ - Test single answer with no errors + Test single answer with no errors. """ mock_create_ai.return_value = self.ai expected_question = self.input_message(self.args) @@ -313,7 +313,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: """ - Test single answer with no errors (mocked ChatDB version) + Test single answer with no errors (mocked ChatDB version). """ chat = MagicMock(spec=ChatDB) mock_from_dir.return_value = chat @@ -373,7 +373,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None: """ - Repeat a single question + Repeat a single question. """ # 1. ask a question mock_create_ai.return_value = self.ai @@ -399,3 +399,115 @@ class TestQuestionCmd(TestQuestionCmdBase): cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assert_messages_equal(cached_msg, expected_responses) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question and overwrite the old one. + """ + # 1. ask a question + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + assert cached_msg[0].file_path + cached_msg_file_id = cached_msg[0].file_path.stem + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + + # 2. repeat the last question (WITH overwriting) + # -> expect a single message afterwards + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = True + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + # also check that the file ID has not been changed + assert cached_msg[0].file_path + self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question after an error. + """ + # 1. ask a question and provoke an error + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + self.ai.request.side_effect = AIError + with self.assertRaises(AIError): + question_cmd(self.args, self.config) + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + assert cached_msg[0].file_path + cached_msg_file_id = cached_msg[0].file_path.stem + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, [expected_question]) + + # 2. repeat the last question (without overwriting) + # -> expect a single message because if the original has + # no answer, it should be overwritten by default + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = False + self.ai.request.side_effect = self.mock_request + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + # also check that the file ID has not been changed + assert cached_msg[0].file_path + self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question with new arguments. + """ + # 1. ask a question + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + + # 2. repeat the last question with new arguments (without overwriting) + # -> expect two messages with identical question and answer, but different metadata + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = False + self.args.output_tags = ['newtag'] + self.args.AI = 'newai' + self.args.model = 'newmodel' + new_expected_question = Message(question=Question(expected_question.question), + tags=set(self.args.output_tags), + ai=self.args.AI, + model=self.args.model) + expected_responses += self.mock_request(new_expected_question, + Chat([]), + self.args.num_answers, + set(self.args.output_tags)).messages + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 2) + self.assert_messages_equal(cached_msg, expected_responses)