import os
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call
from typing import IO, Optional, Union

from chatmastermind.configuration import Config, AIConfig
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError


class FakeAI(AI):
    """
    A mocked version of the 'AI' class.

    It never contacts a real backend: 'request()' either returns
    deterministic answers ("Answer 0", "Answer 1", ...) or raises
    'AIError', depending on the 'error' flag given to the constructor.
    """
    ID: str
    name: str
    config: AIConfig

    def models(self) -> list[str]:
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        # fixed dummy value
        return 123

    def print(self) -> None:
        pass

    def print_models(self) -> None:
        pass

    def __init__(self, ID: str, model: str, error: bool = False):
        """
        Store the AI ID and model name that are written into every answer.
        If 'error' is True, each call to 'request()' raises 'AIError'.
        """
        self.ID = ID
        self.model = model
        self.error = error

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function by either returning fake
        answers or raising an exception.

        NOTE: the given 'question' is modified in place (answer, tags, ai
        and model are set) and returned as the first message of the response.
        Further answers (if 'num_answers' > 1) are new 'Message' instances.
        """
        if self.error:
            raise AIError
        question.answer = Answer("Answer 0")
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.model
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   tags=otags,
                                   ai=self.ID,
                                   model=self.model))
        return AIResponse(answers, Tokens(10, 10, 20))


class TestQuestionCmdBase(unittest.TestCase):
    """
    Helpers shared by all test cases of the 'question' command.
    """

    def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using more than just Question and Answer.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
        """
        Return the sorted list of message files found in the given directory.
        """
        # exclude '.next'
        return sorted(Path(tmp_dir.name).glob('*.[ty]*'))

    def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance.
        """
        return FakeAI(args.AI, args.model)

    def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance which
        raises an 'AIError' on every request.
        """
        return FakeAI(args.AI, args.model, error=True)


class TestMessageCreate(TestQuestionCmdBase):
    """
    Test if messages created by the 'question' command have
    the correct format.
    """

    def create_source_file(self, content: str) -> IO[str]:
        """
        Write 'content' to a new named temporary file and close it.

        The file is NOT deleted on close (it is removed in 'tearDown()'),
        the returned object is only used for its 'name' attribute.
        """
        source_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
        # closing the handle right away avoids leaking an open file descriptor
        with source_file:
            source_file.write(content)
        return source_file

    def setUp(self) -> None:
        # create ChatDB structure
        self.db_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_dir.cleanup)
        self.cache_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_dir.cleanup)
        self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
                                    db_path=Path(self.db_dir.name))
        # create some messages
        self.message_text = Message(Question("What is this?"),
                                    Answer("It is pure text"))
        self.message_code = Message(Question("What is this?"),
                                    Answer("Text\n```\nIt is embedded code\n```\ntext"))
        self.chat.db_add([self.message_text, self.message_code])
        # create arguments mock
        self.args = MagicMock(spec=argparse.Namespace)
        self.args.source_text = None
        self.args.source_code = None
        self.args.AI = None
        self.args.model = None
        self.args.output_tags = None
        # File 1 : no source code block, only text
        self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
        self.source_file1 = self.create_source_file(self.source_file1_content)
        # File 2 : one embedded source code block
        self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
        self.source_file2 = self.create_source_file(self.source_file2_content)
        # File 3 : all source code
        self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
        self.source_file3 = self.create_source_file(self.source_file3_content)
        # File 4 : two source code blocks
        self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
        self.source_file4 = self.create_source_file(self.source_file4_content)

    def tearDown(self) -> None:
        # the temporary directories are removed via 'addCleanup()' (see 'setUp()')
        os.remove(self.source_file1.name)
        os.remove(self.source_file2.name)
        os.remove(self.source_file3.name)
        os.remove(self.source_file4.name)

    def test_message_file_created(self) -> None:
        """
        Creating a message is expected to add exactly one file to the cache.
        """
        self.args.ask = ["What is this?"]
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 0)
        create_message(self.chat, self.args)
        cache_dir_files = self.message_list(self.cache_dir)
        self.assertEqual(len(cache_dir_files), 1)
        message = Message.from_file(cache_dir_files[0])
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))  # type: ignore [union-attr]

    def test_single_question(self) -> None:
        """
        A single question without any source files.
        """
        self.args.ask = ["What is this?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("What is this?"))
        self.assertEqual(len(message.question.source_code()), 0)

    def test_multipart_question(self) -> None:
        """
        Multiple question parts are expected to be separated by empty lines.
        """
        self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        self.assertEqual(message.question, Question("""What is this

'bard' thing?

Is it good?"""))

    def test_single_question_with_text_only_file(self) -> None:
        """
        A text file given as 'source_text' is appended verbatim.
        """
        self.args.ask = ["What is this?"]
        self.args.source_text = [f"{self.source_file1.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question, Question(f"""What is this?

{self.source_file1_content}"""))

    def test_single_question_with_text_file_and_embedded_code(self) -> None:
        """
        From a 'source_code' file with one code block, only that block is added.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file2.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question, Question("""What is this?

```
This is embedded source code.
```
"""))

    def test_single_question_with_code_only_file(self) -> None:
        """
        A 'source_code' file without code blocks is added as one code block.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file3.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file is complete source code
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question, Question(f"""What is this?

```
{self.source_file3_content}
```"""))

    def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
        """
        From a 'source_code' file with two code blocks, both blocks are added.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.source_file4.name}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains 2 source code blocks
        # -> expect them in the question
        self.assertEqual(len(message.question.source_code()), 2)
        self.assertEqual(message.question, Question("""What is this?

```
This is embedded source code.
```


```
This is embedded source code.
```
"""))

    def test_single_question_with_text_only_message(self) -> None:
        """
        A message file given as 'source_text' contributes its answer.
        """
        self.args.ask = ["What is this?"]
        self.args.source_text = [f"{self.chat.messages[0].file_path}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # file contains no source code (only text)
        # -> don't expect any in the question
        self.assertEqual(len(message.question.source_code()), 0)
        self.assertEqual(message.question, Question(f"""What is this?

{self.message_text.answer}"""))

    def test_single_question_with_message_and_embedded_code(self) -> None:
        """
        A message file given as 'source_code' contributes the code block
        embedded in its answer.
        """
        self.args.ask = ["What is this?"]
        self.args.source_code = [f"{self.chat.messages[1].file_path}"]
        message = create_message(self.chat, self.args)
        self.assertIsInstance(message, Message)
        # answer contains 1 source code block
        # -> expect it in the question
        self.assertEqual(len(message.question.source_code()), 1)
        self.assertEqual(message.question, Question("""What is this?

```
It is embedded code
```
"""))


class TestQuestionCmd(TestQuestionCmdBase):
    """
    Common fixture for tests that execute 'question_cmd()': temporary
    DB / cache directories, a configuration and default arguments.
    """

    def setUp(self) -> None:
        # create DB and cache
        self.db_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_dir.cleanup)
        self.cache_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_dir.cleanup)
        # create configuration
        self.config = Config()
        self.config.cache = self.cache_dir.name
        self.config.db = self.db_dir.name
        # create a mock argparse.Namespace
        self.args = argparse.Namespace(
            ask=['What is the meaning of life?'],
            num_answers=1,
            output_tags=['science'],
            AI='FakeAI',
            model='FakeModel',
            or_tags=None,
            and_tags=None,
            exclude_tags=None,
            source_text=None,
            source_code=None,
            create=None,
            repeat=None,
            process=None,
            overwrite=None
        )

    def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
        """
        Create a message from the given arguments, write it to the file
        '0001.txt' in the cache directory and return it. If 'with_answer'
        is True, the message gets the answer 'Answer 0'.
        """
        message = Message(Question(args.ask[0]),
                          tags=set(args.output_tags) if args.output_tags is not None else None,
                          ai=args.AI,
                          model=args.model,
                          file_path=Path(self.cache_dir.name) / '0001.txt')
        if with_answer:
            message.answer = Answer('Answer 0')
        message.to_file()
        return message


class TestQuestionCmdAsk(TestQuestionCmd):
    """
    Tests for asking new questions.
    """

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
        """
        Test single answer with no errors.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=self.args.output_tags,
                                    ai=self.args.AI,
                                    model=self.args.model)
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_responses = fake_ai.request(expected_question,
                                             Chat([]),
                                             self.args.num_answers,
                                             self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
        """
        Test single answer with no errors (mocked ChatDB version).
        """
        chat = MagicMock(spec=ChatDB)
        mock_from_dir.return_value = chat

        mock_create_ai.side_effect = self.mock_create_ai
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=self.args.output_tags,
                                    ai=self.args.AI,
                                    model=self.args.model)
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_responses = fake_ai.request(expected_question,
                                             Chat([]),
                                             self.args.num_answers,
                                             self.args.output_tags).messages

        # execute the command
        question_cmd(self.args, self.config)

        # check for the correct ChatDB calls:
        # - initial question has been written (prior to the actual request)
        # - responses have been written (after the request)
        chat.cache_write.assert_has_calls([call([expected_question]),
                                           call(expected_responses)],
                                          any_order=False)

        # check that the messages have not been added to the internal message list
        chat.cache_add.assert_not_called()

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
        """
        Provoke an error during the AI request and verify that the question
        has been correctly stored in the cache.
        """
        mock_create_ai.side_effect = self.mock_create_ai_with_error
        expected_question = Message(Question(self.args.ask[0]),
                                    tags=self.args.output_tags,
                                    ai=self.args.AI,
                                    model=self.args.model)

        # execute the command
        with self.assertRaises(AIError):
            question_cmd(self.args, self.config)

        # check for the expected message files
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, [expected_question])


class TestQuestionCmdRepeat(TestQuestionCmd):
    """
    Tests for repeating existing questions.
    """

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        # create a message
        message = self.create_single_message(self.args)

        # repeat the last question (without overwriting)
        # -> expect two identical messages (except for the file_path)
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_response = fake_ai.request(message,
                                            Chat([]),
                                            self.args.num_answers,
                                            set(self.args.output_tags)).messages
        expected_responses = expected_response + expected_response
        question_cmd(self.args, self.config)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_messages_equal(cached_msg, expected_responses)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question and overwrite the old one.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        # create a message
        message = self.create_single_message(self.args)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (WITH overwriting)
        # -> expect a single message afterwards
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = True
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_response = fake_ai.request(message,
                                            Chat([]),
                                            self.args.num_answers,
                                            set(self.args.output_tags)).messages
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_response)
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question after an error.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        # create a question WITHOUT an answer
        # -> just like after an error, which is tested above
        question = self.create_single_message(self.args, with_answer=False)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path
        cached_msg_file_id = cached_msg[0].file_path.stem

        # repeat the last question (without overwriting)
        # -> expect a single message because if the original has
        #    no answer, it should be overwritten by default
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        fake_ai = self.mock_create_ai(self.args, self.config)
        expected_response = fake_ai.request(question,
                                            Chat([]),
                                            self.args.num_answers,
                                            self.args.output_tags).messages
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 1)
        self.assert_messages_equal(cached_msg, expected_response)
        # also check that the file ID has not been changed
        assert cached_msg[0].file_path
        self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat a single question with new arguments.
        """
        mock_create_ai.side_effect = self.mock_create_ai
        # create a message
        message = self.create_single_message(self.args)
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        assert cached_msg[0].file_path

        # repeat the last question with new arguments (without overwriting)
        # -> expect two messages with identical question and answer, but different metadata
        self.args.ask = None
        self.args.repeat = []
        self.args.overwrite = False
        self.args.output_tags = ['newtag']
        self.args.AI = 'newai'
        self.args.model = 'newmodel'
        new_expected_question = Message(question=Question(message.question),
                                        tags=set(self.args.output_tags),
                                        ai=self.args.AI,
                                        model=self.args.model)
        fake_ai = self.mock_create_ai(self.args, self.config)
        new_expected_response = fake_ai.request(new_expected_question,
                                                Chat([]),
                                                self.args.num_answers,
                                                set(self.args.output_tags)).messages
        question_cmd(self.args, self.config)
        cached_msg = chat.msg_gather(loc='cache')
        self.assertEqual(len(self.message_list(self.cache_dir)), 2)
        self.assert_messages_equal(cached_msg, [message] + new_expected_response)

    @mock.patch('chatmastermind.commands.question.create_ai')
    def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
        """
        Repeat multiple questions.
        """
        # the expectations below are computed with 'FakeAI', so the command
        # must use it as well (instead of the bare MagicMock of the patch)
        mock_create_ai.side_effect = self.mock_create_ai
        # 1. === create three questions ===
        # cached message without an answer
        message1 = Message(Question('Question 1'),
                           ai='foo',
                           model='bla',
                           file_path=Path(self.cache_dir.name) / '0001.txt')
        # cached message with an answer
        message2 = Message(Question('Question 2'),
                           Answer('Answer 0'),
                           ai='openai',
                           model='gpt-3.5-turbo',
                           file_path=Path(self.cache_dir.name) / '0002.txt')
        # DB message without an answer
        message3 = Message(Question('Question 3'),
                           ai='openai',
                           model='gpt-3.5-turbo',
                           file_path=Path(self.db_dir.name) / '0003.txt')
        message1.to_file()
        message2.to_file()
        message3.to_file()
        questions = [message1, message2, message3]
        expected_responses: list[Message] = []
        fake_ai = self.mock_create_ai(self.args, self.config)
        for question in questions:
            expected_responses += fake_ai.request(question,
                                                  Chat([]),
                                                  self.args.num_answers,
                                                  set(self.args.output_tags)).messages

        # 2. === repeat all three questions (without overwriting) ===
        self.args.ask = None
        self.args.repeat = ['0001', '0002', '0003']
        self.args.overwrite = False
        question_cmd(self.args, self.config)
        # two new files should be in the cache directory
        # * the repeated cached message with answer
        # * the repeated DB message
        # -> the cached message without answer should be overwritten
        self.assertEqual(len(self.message_list(self.cache_dir)), 4)
        self.assertEqual(len(self.message_list(self.db_dir)), 1)
        expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
        chat = ChatDB.from_dir(Path(self.cache_dir.name),
                               Path(self.db_dir.name))
        cached_msg = chat.msg_gather(loc='cache')
        self.assert_messages_equal(cached_msg, expected_cache_messages)