From 820d4aed218d97125e4f73bb741ac6d58a482c83 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 19 Sep 2023 15:19:41 +0200 Subject: [PATCH] test_question_cmd: added more testcases for '--repeat' --- tests/test_question_cmd.py | 81 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index b809567..06d4929 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -285,7 +285,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: """ - Test single answer with no errors + Test single answer with no errors. """ mock_create_ai.return_value = self.ai expected_question = self.input_message(self.args) @@ -313,7 +313,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: """ - Test single answer with no errors (mocked ChatDB version) + Test single answer with no errors (mocked ChatDB version). """ chat = MagicMock(spec=ChatDB) mock_from_dir.return_value = chat @@ -373,7 +373,7 @@ class TestQuestionCmd(TestQuestionCmdBase): @mock.patch('chatmastermind.commands.question.create_ai') def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None: """ - Repeat a single question + Repeat a single question. """ # 1. ask a question mock_create_ai.return_value = self.ai @@ -399,3 +399,78 @@ class TestQuestionCmd(TestQuestionCmdBase): cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assert_messages_equal(cached_msg, expected_responses) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question and overwrite the old one. + """ + # 1. ask a question + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + assert cached_msg[0].file_path + cached_msg_file_id = cached_msg[0].file_path.stem + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + + # 2. repeat the last question (WITH overwriting) + # -> expect a single message afterwards + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = True + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + # also check that the file ID has not been changed + assert cached_msg[0].file_path + self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question after an error. + """ + # 1. ask a question + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + self.ai.request.side_effect = AIError + + # execute the command + with self.assertRaises(AIError): + question_cmd(self.args, self.config) + + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + assert cached_msg[0].file_path + cached_msg_file_id = cached_msg[0].file_path.stem + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, [expected_question]) + + # 2. repeat the last question (without overwriting) + # -> expect a single message because if the original has + # no answer, it should be overwritten by default + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = False + self.ai.request.side_effect = self.mock_request + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + # also check that the file ID has not been changed + assert cached_msg[0].file_path + self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)