diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 89c72c7..c51d5fd 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -282,6 +282,9 @@ class TestQuestionCmd(TestQuestionCmdBase): # exclude '.next' return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) + +class TestQuestionCmdAsk(TestQuestionCmd): + @mock.patch('chatmastermind.commands.question.create_ai') def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: """ @@ -370,6 +373,9 @@ class TestQuestionCmd(TestQuestionCmdBase): self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assert_messages_equal(cached_msg, [expected_question]) + +class TestQuestionCmdRepeat(TestQuestionCmd): + @mock.patch('chatmastermind.commands.question.create_ai') def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None: """ @@ -511,3 +517,44 @@ class TestQuestionCmd(TestQuestionCmdBase): cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assert_messages_equal(cached_msg, expected_responses) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: + """ + Repeat multiple questions. + """ + # 1. create some questions / messages + # cached message without an answer + message1 = Message(Question('Question 1'), + ai='foo', + model='bla', + file_path=Path(self.cache_dir.name) / '0001.txt') + # cached message with an answer + message2 = Message(Question('Question 2'), + Answer('Answer 2'), + ai='openai', + model='gpt-3.5-turbo', + file_path=Path(self.cache_dir.name) / '0002.txt') + # DB message without an answer + message3 = Message(Question('Question 3'), + ai='openai', + model='gpt-3.5-turbo', + file_path=Path(self.db_dir.name) / '0003.txt') + message1.to_file() + message2.to_file() + message3.to_file() + # chat = ChatDB.from_dir(Path(self.cache_dir.name), + # Path(self.db_dir.name)) + + # 2. repeat all three questions (without overwriting) + self.args.ask = None + self.args.repeat = ['0001', '0002', '0003'] + self.args.overwrite = False + question_cmd(self.args, self.config) + # two new files should be in the cache directory + # * the repeated cached message with answer + # * the repeated DB message + # -> the cached message wihtout answer should be overwritten + self.assertEqual(len(self.message_list(self.cache_dir)), 4) + self.assertEqual(len(self.message_list(self.db_dir)), 1) + # FIXME: also compare actual content!