"""
Tests for the 'question' command: message creation via ``create_message``
and the ask / repeat modes of ``question_cmd``.
"""
import os
import argparse
import tempfile
from copy import copy
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call

from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AIError

from .test_common import TestWithFakeAI


# suffix of the message files written by the code under test
msg_suffix = Message.file_suffix_write

# NOTE(review): the whitespace inside the multi-line string fixtures and the
# expected questions in 'TestMessageCreate' is reproduced exactly as it
# appeared in the reviewed text. If line breaks inside those literals were
# collapsed before review, restore them from the canonical file before
# relying on these assertions.


class TestMessageCreate(TestWithFakeAI):
    """
    Test if messages created by the 'question' command have
    the correct format.
    """

    def setUp(self) -> None:
        """
        Create a ChatDB with two messages, an argument mock and four
        temporary source files used as text / code input.
        """
        # create ChatDB structure (directories are removed again after each test)
        self.db_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_dir.cleanup)
        self.cache_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_dir.cleanup)
        self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
                                    db_path=Path(self.db_dir.name))
        # create some messages
        self.message_text = Message(Question("What is this?"),
                                    Answer("It is pure text"))
        self.message_code = Message(Question("What is this?"),
                                    Answer("Text\n```\nIt is embedded code\n```\ntext"))
        self.chat.db_add([self.message_text, self.message_code])
        # create arguments mock
        self.args = MagicMock(spec=argparse.Namespace)
        self.args.source_text = None
        self.args.source_code = None
        self.args.AI = None
        self.args.model = None
        self.args.output_tags = None

        # The source files are written through their own handle and closed
        # right away, so no open handle is left behind when 'tearDown'
        # removes them. Only the '.name' attribute is used afterwards.

        # File 1 : no source code block, only text
        self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f1:
            f1.write(self.source_file1_content)
        self.source_file1 = f1
        # File 2 : one embedded source code block
        self.source_file2_content = ("This is just text. ``` This is embedded source code. \n"
                                     "``` And some text again.")
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f2:
            f2.write(self.source_file2_content)
        self.source_file2 = f2
        # File 3 : all source code
        self.source_file3_content = """This is all source code. Yes, really. Language is called 'brainfart'."""
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f3:
            f3.write(self.source_file3_content)
        self.source_file3 = f3
        # File 4 : two source code blocks
        self.source_file4_content = ("This is just text. ``` This is embedded source code. ``` "
                                     "And some text again. ``` This is embedded source code. ``` "
                                     "Aaaand again some text.")
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f4:
            f4.write(self.source_file4_content)
        self.source_file4 = f4

    def tearDown(self) -> None:
        """
        Remove the temporary source files created in 'setUp'.
        """
        os.remove(self.source_file1.name)
        os.remove(self.source_file2.name)
        os.remove(self.source_file3.name)
        os.remove(self.source_file4.name)

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
        """
        Return all message files (by suffix) in the given directory.
        """
        # exclude '.next'
        return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))

    def test_message_file_created(self) -> None:
        """
        Creating a message writes exactly one message file to the cache.
        """
        self.args.ask = ["What is this?"]
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 0)
        create_message(self.chat, self.args)
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 1)
        message = Message.from_file(cache_dir_files[0])
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))  # type: ignore [union-attr]

    def test_single_question(self) -> None:
        """
        A single question part is used unchanged and has no source code.
        """
        self.args.ask = ["What is this?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))
        self.assertEqual(len(message.question.source_code()), 0)

    def test_multipart_question(self) -> None:
        """
        Several question parts are combined into a single question.
        """
        self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question,
                         Question("""What is this 'bard' thing? Is it good?"""))

    def test_single_question_with_text_only_file(self) -> None:
        """
        A text-only file given as source text is added to the question.
        """
        self.args.ask = ["What is this?"]
        self.args.source_text = [f"{self.source_file1.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question,
                         Question(f"""What is this? {self.source_file1_content}"""))

    def test_single_question_with_text_file_and_embedded_code(self) -> None:
        """
        From a file with one embedded code block, only the block is added.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file2.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question("""What is this? ``` This is embedded source code. ``` """))

    def test_single_question_with_code_only_file(self) -> None:
        """
        A file without code delimiters is added completely as source code.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file3.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file is complete source code
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question(f"""What is this? ``` {self.source_file3_content} ```"""))

    def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
        """
        From a file with two embedded code blocks, both blocks are added.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file4.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 2 source code blocks
        # -> expect them in the question
        self.assertEqual(len(message.question.source_code()), 2)
        self.assertEqual(message.question,
                         Question("What is this? ``` This is embedded source code. ``` "
                                  "``` This is embedded source code. \n"
                                  "``` "))

    def test_single_question_with_text_only_message(self) -> None:
        """
        A message file given as source text contributes its answer.
        """
        self.args.ask = ["What is this?"]
        self.args.source_text = [f"{self.chat.messages[0].file_path}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question,
                         Question(f"""What is this? {self.message_text.answer}"""))

    def test_single_question_with_message_and_embedded_code(self) -> None:
        """
        A message file given as source code contributes the code block
        embedded in its answer.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.chat.messages[1].file_path}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # answer contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question("""What is this? ``` It is embedded code ``` """))


class TestQuestionCmd(TestWithFakeAI):
    """
    Common fixture for the 'question_cmd' tests: temporary DB / cache
    directories, a configuration pointing to them and default arguments.
    """

    def setUp(self) -> None:
        """
        Create the directories, the configuration and the arguments.
        """
        # create DB and cache (directories are removed again after each test)
        self.db_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_dir.cleanup)
        self.cache_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_dir.cleanup)
        # create configuration
        self.config = Config()
        self.config.cache = self.cache_dir.name
        self.config.db = self.db_dir.name
        # create a mock argparse.Namespace
        self.args = argparse.Namespace(
            ask=['What is the meaning of life?'],
            num_answers=1,
            output_tags=['science'],
            AI='FakeAI',
            model='FakeModel',
            or_tags=None,
            and_tags=None,
            exclude_tags=None,
            source_text=None,
            source_code=None,
            create=None,
            repeat=None,
            process=None,
            overwrite=None
        )

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
        """
        Return all message files (by suffix) in the given directory, sorted.
        """
        # exclude '.next'
        return sorted(Path(tmp_dir.name).glob(f'*{msg_suffix}'))


class TestQuestionCmdAsk(TestQuestionCmd):
    """
    Tests for asking a new question ('ask' argument).
    """

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
        """
        Test single answer with no errors.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=set(self.args.output_tags),
                                    ai=self.args.AI,
                                    model=self.args.model,
                                    file_path=Path(''))
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_responses = fake_ai.request(expected_question,
                                             Chat([]),
                                             self.args.num_answers,
                                             self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
        """
        Test single answer with no errors (mocked ChatDB version).
        """
        chat = MagicMock(spec=ChatDB)
        mock_from_dir.return_value = chat

        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=set(self.args.output_tags),
                                    ai=self.args.AI,
                                    model=self.args.model,
                                    file_path=Path(''))
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_responses = fake_ai.request(expected_question,
                                             Chat([]),
                                             self.args.num_answers,
                                             self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for the correct ChatDB calls:
        # - initial question has been written (prior to the actual request)
        # - responses have been written (after the request)
        chat.cache_write.assert_has_calls([call([expected_question]),
                                           call(expected_responses)],
                                          any_order=False)

        # check that the messages have not been added to the internal message list
        chat.cache_add.assert_not_called()

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
        """
        Provoke an error during the AI request and verify that the question
        has been correctly stored in the cache.
        """
        mock_create_ai.side_effect = self.mock_create_ai_with_error
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=set(self.args.output_tags),
                                    ai=self.args.AI,
                                    model=self.args.model,
                                    file_path=Path(''))

        # execute the command
        with self.assertRaises(AIError):
            question_cmd(self.args, self.config)

        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_question])


class TestQuestionCmdRepeat(TestQuestionCmd):
    """
    Tests for repeating existing questions ('repeat' argument).
    """

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # create a message
        message = Message(Question(self.args.ask[0]),
                          Answer('Old Answer'),
                          tags=set(self.args.output_tags),
                          ai=self.args.AI,
                          model=self.args.model,
                          file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        chat.msg_write([message])

        # repeat the last question (without overwriting)
        # -> expect two identical messages (except for the file_path)
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        expected_response = Message(Question(message.question),
                                    Answer('Answer 0'),
                                    ai=message.ai,
                                    model=message.model,
                                    tags=message.tags,
                                    file_path=Path(''))
        # we expect the original message + the one with the new response
        expected_responses = [message] + [expected_response]
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_msgs_equal_except_file_path(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question and overwrite the old one.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # create a message
        message = Message(Question(self.args.ask[0]),
                          Answer('Old Answer'),
                          tags=set(self.args.output_tags),
                          ai=self.args.AI,
                          model=self.args.model,
                          file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        chat.msg_write([message])
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (WITH overwriting)
        # -> expect a single message afterwards (with a new answer)
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = True
        expected_response = Message(Question(message.question),
                                    Answer('Answer 0'),
                                    ai=message.ai,
                                    model=message.model,
                                    tags=message.tags,
                                    file_path=Path(''))
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question after an error.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # create a question WITHOUT an answer
        # -> just like after an error, which is tested above
        message = Message(Question(self.args.ask[0]),
                          tags=set(self.args.output_tags),
                          ai=self.args.AI,
                          model=self.args.model,
                          file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        chat.msg_write([message])
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (without overwriting)
        # -> expect a single message because if the original has
        #    no answer, it should be overwritten by default
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        expected_response = Message(Question(message.question),
                                    Answer('Answer 0'),
                                    ai=message.ai,
                                    model=message.model,
                                    tags=message.tags,
                                    file_path=Path(''))
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [expected_response])
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question with new arguments.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # create a message
        message = Message(Question(self.args.ask[0]),
                          Answer('Old Answer'),
                          tags=set(self.args.output_tags),
                          ai=self.args.AI,
                          model=self.args.model,
                          file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        chat.msg_write([message])
        # make sure the message has actually been written to the cache
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path

        # repeat the last question with new arguments (without overwriting)
        # -> expect two messages with identical question but different metadata and new answer
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        self.args.output_tags = ['newtag']
        self.args.AI = 'newai'
        self.args.model = 'newmodel'
        new_expected_response = Message(Question(message.question),
                                        Answer('Answer 0'),
                                        ai='newai',
                                        model='newmodel',
                                        tags={Tag('newtag')},
                                        file_path=Path(''))
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_msgs_equal_except_file_path(cached_msg, [message] + [new_expected_response])

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_new_args_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question with new arguments, overwriting the old one.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # create a message
        message = Message(Question(self.args.ask[0]),
                          Answer('Old Answer'),
                          tags=set(self.args.output_tags),
                          ai=self.args.AI,
                          model=self.args.model,
                          file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        chat.msg_write([message])
        # make sure the message has actually been written to the cache
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path

        # repeat the last question with new arguments
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = True
        self.args.output_tags = ['newtag']
        self.args.AI = 'newai'
        self.args.model = 'newmodel'
        new_expected_response = Message(Question(message.question),
                                        Answer('Answer 0'),
                                        ai='newai',
                                        model='newmodel',
                                        tags={Tag('newtag')},
                                        file_path=Path(''))
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_msgs_equal_except_file_path(cached_msg, [new_expected_response])

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat multiple questions.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        # 1. === create three questions ===
        # NOTE(review): 'tags' is passed as a list here, whereas the other
        # tests in this class wrap 'output_tags' in 'set(...)' - confirm
        # that this is intended.
        # cached message without an answer
        message1 = Message(Question(self.args.ask[0]),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
        # cached message with an answer
        message2 = Message(Question(self.args.ask[0]),
                           Answer('Old Answer'),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
        # DB message without an answer
        message3 = Message(Question(self.args.ask[0]),
                           tags=self.args.output_tags,
                           ai=self.args.AI,
                           model=self.args.model,
                           file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
        chat.msg_write([message1, message2, message3])
        questions = [message1, message2, message3]
        expected_responses: list[Message] = []
        fake_ai = self.mock_create_ai(self.args, self.config)
        for question in questions:
            # since the message's answer is modified, we use a copy
            # -> the original is used for comparison below
            expected_responses += fake_ai.request(copy(question),
                                                  Chat([]),
                                                  self.args.num_answers,
                                                  set(self.args.output_tags)).messages

        # 2. === repeat all three questions (without overwriting) ===
        self.args.ask = None
        self.args.repeat = ['0001', '0002', '0003']
        self.args.overwrite = False
        question_cmd(self.args, self.config)
        # two new files should be in the cache directory
        # * the repeated cached message with answer
        # * the repeated DB message
        # -> the cached message without answer should be overwritten
        self.assertEqual(len(self.message_list(self.cache_dir)), 4)
        self.assertEqual(len(self.message_list(self.db_dir)), 1)
        expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
        cached_msg = chat.msg_gather(loc='cache')
        self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
        # check that the DB message has not been modified at all
        db_msg = chat.msg_gather(loc='db')
        self.assert_msgs_all_equal(db_msg, [message3])