diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 4936d8f..d143792 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,11 +3,52 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, MessageFilter, Question, source_code +from ..message import Message, MessageFilter, MessageError, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse +def add_file_as_text(question_parts: list[str], file: str) -> None: + """ + Add the given file as plain text to the question part list. + If the file is a Message, add the answer. + """ + file_path = Path(file) + content: str + try: + message = Message.from_file(file_path) + if message and message.answer: + content = message.answer + except MessageError: + with open(file) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + + +def add_file_as_code(question_parts: list[str], file: str) -> None: + """ + Add all source code from the given file. If no code segments can be extracted, + the whole content is added as source code segment. If the file is a Message, + extract the source code from the answer. + """ + file_path = Path(file) + content: str + try: + message = Message.from_file(file_path) + if message and message.answer: + content = message.answer + except MessageError: + with open(file) as r: + content = r.read().strip() + # extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + else: + question_parts.append(f"```\n{content}\n```") + + def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. @@ -17,26 +58,13 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: text_files = args.source_text if args.source_text is not None else [] code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): + for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) - if source is not None and len(source) > 0: - with open(source) as r: - content = r.read().strip() - if len(content) > 0: - question_parts.append(content) - if code is not None and len(code) > 0: - with open(code) as r: - content = r.read().strip() - if len(content) == 0: - continue - # try to extract and add source code - code_parts = source_code(content, include_delims=True) - if len(code_parts) > 0: - question_parts += code_parts - # if there's none, add the whole file - else: - question_parts.append(f"```\n{content}\n```") + if text_file is not None and len(text_file) > 0: + add_file_as_text(question_parts, text_file) + if code_file is not None and len(code_file) > 0: + add_file_as_code(question_parts, code_file) full_question = '\n\n'.join(question_parts) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 40ea4d8..b8e7874 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -5,7 +5,7 @@ import tempfile from pathlib import Path from unittest.mock import MagicMock from chatmastermind.commands.question import create_message -from chatmastermind.message import Message, Question +from chatmastermind.message import Message, Question, Answer from chatmastermind.chat import ChatDB @@ -20,6 +20,12 @@ class TestMessageCreate(unittest.TestCase): self.cache_path = tempfile.TemporaryDirectory() self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), db_path=Path(self.db_path.name)) + # create some messages + self.message_text = Message(Question("What is this?"), + Answer("It is pure text")) + self.message_code = Message(Question("What is this?"), + Answer("Text\n```\nIt is embedded code\n```\ntext")) + self.chat.add_to_db([self.message_text, self.message_code]) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source_text = None @@ -160,3 +166,29 @@ This is embedded source code. This is embedded source code. ``` """)) + + def test_single_question_with_text_only_message(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_text = [f"{self.chat.messages[0].file_path}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains no source code (only text) + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question(f"""What is this? + +{self.message_text.answer}""")) + + def test_single_question_with_message_and_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_text = [f"{self.chat.messages[1].file_path}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains no source code (only text) + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +It is embedded code +```"""))