diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..96b2fdf --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,111 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source = None + self.args.source_code_only = False + self.args.ai = None + self.args.model = None + self.args.output_tags = None + # create some files for sourcing + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. +``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + + def tearDown(self) -> None: + os.remove(self.source_file1.name) + os.remove(self.source_file2.name) + os.remove(self.source_file3.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.question = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.question = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? + +Is it good?""")) + + def test_single_question_with_text_only_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains no source code + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file2_content}"""))