import os
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call, ANY
from typing import IO, Optional

from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError


class TestQuestionCmdBase(unittest.TestCase):
    """
    Shared helpers for the 'question' command tests.
    """
    def make_temp_dir(self) -> tempfile.TemporaryDirectory:
        """
        Create a temporary directory that is removed again after the test.

        The cleanup is registered with 'addCleanup', so it also runs if
        'setUp' or the test itself fails.
        """
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        return tmp_dir

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
        """
        Return the (sorted) message files contained in the given directory.
        """
        # exclude '.next'
        return sorted(Path(tmp_dir.name).glob('*.[ty]*'))

    def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using more than just Question and Answer.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))


class TestMessageCreate(TestQuestionCmdBase):
    """
    Test if messages created by the 'question' command have
    the correct format.
    """
    def make_source_file(self, content: str) -> IO[str]:
        """
        Write 'content' to a new named temporary file and close it.

        The file is kept on disk (delete=False) so that it can be passed
        to the command by name; it is removed again after the test.
        Returns the (closed) file object, whose 'name' is the file path.
        """
        source_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
        # register the removal first, so the file is deleted even if writing fails
        self.addCleanup(os.remove, source_file.name)
        # closing the wrapper does not delete the file because of delete=False
        with source_file:
            source_file.write(content)
        return source_file

    def setUp(self) -> None:
        # create ChatDB structure
        self.db_dir = self.make_temp_dir()
        self.cache_dir = self.make_temp_dir()
        self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
                                    db_path=Path(self.db_dir.name))
        # create some messages
        self.message_text = Message(Question("What is this?"),
                                    Answer("It is pure text"))
        self.message_code = Message(Question("What is this?"),
                                    Answer("Text\n```\nIt is embedded code\n```\ntext"))
        self.chat.db_add([self.message_text, self.message_code])
        # create arguments mock
        self.args = MagicMock(spec=argparse.Namespace)
        self.args.source_text = None
        self.args.source_code = None
        self.args.AI = None
        self.args.model = None
        self.args.output_tags = None
        # File 1 : no source code block, only text
        self.source_file1_content = ("This is just text.\n"
                                     "No source code.\n"
                                     "Nope. Go look elsewhere!")
        self.source_file1 = self.make_source_file(self.source_file1_content)
        # File 2 : one embedded source code block
        self.source_file2_content = ("This is just text.\n"
                                     "```\n"
                                     "This is embedded source code.\n"
                                     "```\n"
                                     "And some text again.")
        self.source_file2 = self.make_source_file(self.source_file2_content)
        # File 3 : all source code
        self.source_file3_content = ("This is all source code.\n"
                                     "Yes, really.\n"
                                     "Language is called 'brainfart'.")
        self.source_file3 = self.make_source_file(self.source_file3_content)
        # File 4 : two source code blocks
        self.source_file4_content = ("This is just text.\n"
                                     "```\n"
                                     "This is embedded source code.\n"
                                     "```\n"
                                     "And some text again.\n"
                                     "```\n"
                                     "This is embedded source code.\n"
                                     "```\n"
                                     "Aaaand again some text.")
        self.source_file4 = self.make_source_file(self.source_file4_content)

    def test_message_file_created(self) -> None:
        """
        'create_message' writes exactly one message file to the cache.
        """
        self.args.ask = ["What is this?"]
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 0)
        create_message(self.chat, self.args)
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 1)
        message = Message.from_file(cache_dir_files[0])
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))  # type: ignore [union-attr]

    def test_single_question(self) -> None:
        self.args.ask = ["What is this?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))
        self.assertEqual(len(message.question.source_code()), 0)

    def test_multipart_question(self) -> None:
        """
        Multiple question parts are separated by a blank line.
        """
        self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question,
                         Question("What is this\n\n'bard' thing?\n\nIs it good?"))

    def test_single_question_with_text_only_file(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_text = [self.source_file1.name]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question,
                         Question(f"What is this?\n\n{self.source_file1_content}"))

    def test_single_question_with_text_file_and_embedded_code(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_code = [self.source_file2.name]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question("What is this?\n\n"
                                  "```\nThis is embedded source code.\n```\n"))

    def test_single_question_with_code_only_file(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_code = [self.source_file3.name]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file is complete source code
        # -> expect the whole content wrapped in a single code block
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question(f"What is this?\n\n```\n{self.source_file3_content}\n```"))

    def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_code = [self.source_file4.name]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 2 source code blocks
        # -> expect them in the question (separated by a blank line)
        self.assertEqual(len(message.question.source_code()), 2)
        code_block = "```\nThis is embedded source code.\n```\n"
        self.assertEqual(message.question,
                         Question(f"What is this?\n\n{code_block}\n\n{code_block}"))

    def test_single_question_with_text_only_message(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_text = [str(self.chat.messages[0].file_path)]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question,
                         Question(f"What is this?\n\n{self.message_text.answer}"))

    def test_single_question_with_message_and_embedded_code(self) -> None:
        self.args.ask = ["What is this?"]
        self.args.source_code = [str(self.chat.messages[1].file_path)]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # answer contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question,
                         Question("What is this?\n\n```\nIt is embedded code\n```\n"))


class TestQuestionCmd(TestQuestionCmdBase):
    """
    Test 'question_cmd' with a mocked AI (and partly a mocked ChatDB).
    """
    def setUp(self) -> None:
        # create DB and cache
        self.db_dir = self.make_temp_dir()
        self.cache_dir = self.make_temp_dir()
        # create configuration
        self.config = Config()
        self.config.cache = self.cache_dir.name
        self.config.db = self.db_dir.name
        # create a mock argparse.Namespace
        self.args = argparse.Namespace(
            ask=['What is the meaning of life?'],
            num_answers=1,
            output_tags=['science'],
            AI='openai',
            model='gpt-3.5-turbo',
            or_tags=None,
            and_tags=None,
            exclude_tags=None,
            source_text=None,
            source_code=None,
            create=None,
            repeat=None,
            process=None,
            overwrite=None
        )
        # create a mock AI instance
        self.ai = MagicMock(spec=AI)
        self.ai.request.side_effect = self.mock_request

    def input_message(self, args: argparse.Namespace) -> Message:
        """
        Create the expected input message for a question using the given arguments.
        """
        # NOTE: we only use the first question from the "ask" list
        #       -> message creation using "question.create_message()" is
        #          tested above
        # the answer is always empty for the input message
        return Message(Question(args.ask[0]),
                       tags=args.output_tags,
                       ai=args.AI,
                       model=args.model)

    def mock_request(self,
                     question: Message,
                     chat: Chat,
                     num_answers: int = 1,
                     otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function.

        NOTE: the given 'question' message is modified in place and
        returned as the first answer.
        """
        question.answer = Answer("Answer 0")
        question.tags = set(otags) if otags else None
        question.ai = 'FakeAI'
        question.model = 'FakeModel'
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   tags=otags,
                                   ai='FakeAI',
                                   model='FakeModel'))
        return AIResponse(answers, Tokens(10, 10, 20))

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
        """
        Test single answer with no errors.
        """
        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        expected_responses = self.mock_request(expected_question,
                                               Chat([]),
                                               self.args.num_answers,
                                               self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for correct request call
        self.ai.request.assert_called_once_with(expected_question,
                                                ANY,
                                                self.args.num_answers,
                                                self.args.output_tags)
        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
        """
        Test single answer with no errors (mocked ChatDB version).
        """
        chat = MagicMock(spec=ChatDB)
        mock_from_dir.return_value = chat

        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        expected_responses = self.mock_request(expected_question,
                                               Chat([]),
                                               self.args.num_answers,
                                               self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for correct request call
        self.ai.request.assert_called_once_with(expected_question,
                                                chat,
                                                self.args.num_answers,
                                                self.args.output_tags)

        # check for the correct ChatDB calls:
        # - initial question has been written (prior to the actual request)
        # - responses have been written (after the request)
        chat.cache_write.assert_has_calls([call([expected_question]),
                                           call(expected_responses)],
                                          any_order=False)

        # check that the messages have not been added to the internal message list
        chat.cache_add.assert_not_called()

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
        """
        Provoke an error during the AI request and verify that the question
        has been correctly stored in the cache.
        """
        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        self.ai.request.side_effect = AIError

        # execute the command
        with self.assertRaises(AIError):
            question_cmd(self.args, self.config)

        # check for correct request call
        self.ai.request.assert_called_once_with(expected_question,
                                                ANY,
                                                self.args.num_answers,
                                                self.args.output_tags)
        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, [expected_question])

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question.
        """
        # 1. ask a question
        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        expected_responses = self.mock_request(expected_question,
                                               Chat([]),
                                               self.args.num_answers,
                                               self.args.output_tags).messages
        question_cmd(self.args, self.config)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)

        # 2. repeat the last question (without overwriting)
        #    -> expect two identical messages (except for the file_path)
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        expected_responses += expected_responses
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_messages_equal(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question and overwrite the old one.
        """
        # 1. ask a question
        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        expected_responses = self.mock_request(expected_question,
                                               Chat([]),
                                               self.args.num_answers,
                                               self.args.output_tags).messages
        question_cmd(self.args, self.config)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)

        # 2. repeat the last question (WITH overwriting)
        #    -> expect a single message afterwards
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = True
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question after an error.
        """
        # 1. ask a question
        mock_create_ai.return_value = self.ai
        expected_question = self.input_message(self.args)
        self.ai.request.side_effect = AIError
        # execute the command
        with self.assertRaises(AIError):
            question_cmd(self.args, self.config)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, [expected_question])

        # 2. repeat the last question (without overwriting)
        #    -> expect a single message because if the original has
        #       no answer, it should be overwritten by default
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        self.ai.request.side_effect = self.mock_request
        expected_responses = self.mock_request(expected_question,
                                               Chat([]),
                                               self.args.num_answers,
                                               self.args.output_tags).messages
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)