diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 90b782b..756a051 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question +from ..message import Message, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ question_parts = [] question_list = args.ask if args.ask is not None else [] - source_list = args.source if args.source is not None else [] - code_list = args.source_code if args.source_code is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) if source is not None and len(source) > 0: @@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code is not None and len(code) > 0: with open(code) as r: content = r.read().strip() - if len(content) > 0: + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + # if there's none, add the whole file + else: question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index eadb095..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 35de3b9..7107c13 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -414,7 +414,7 @@ class Message(): return '\n'.join(output) def __str__(self) -> str: - return self.to_str(False, False, False) + return self.to_str(True, True, False) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index aa0dc25..40ea4d8 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase): db_path=Path(self.db_path.name)) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) - self.args.source = None + self.args.source_text = None self.args.source_code = None self.args.AI = None self.args.model = None self.args.output_tags = None - # create some files for sourcing + # File 1 : no source code block, only text self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!""" with open(self.source_file1.name, 'w') as f: f.write(self.source_file1_content) + # File 2 : one embedded source code block self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2_content = """This is just text. ``` @@ -42,12 +43,26 @@ This is embedded source code. And some text again.""" with open(self.source_file2.name, 'w') as f: f.write(self.source_file2_content) + # File 3 : all source code self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3_content = """This is all source code. Yes, really. Language is called 'brainfart'.""" with open(self.source_file3.name, 'w') as f: f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) def tearDown(self) -> None: os.remove(self.source_file1.name) @@ -86,40 +101,62 @@ Language is called 'brainfart'.""" Is it good?""")) - def test_single_question_with_text_only_source(self) -> None: + def test_single_question_with_text_only_file(self) -> None: self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file1.name}"] + self.args.source_text = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains no source code + # file contains no source code (only text) # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_source(self) -> None: - self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file2.name}"] - message = create_message(self.chat, self.args) - self.assertIsInstance(message, Message) - # source file contains 1 source code block - # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question(f"""What is this? - -{self.source_file2_content}""")) - - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_text_file_and_embedded_code(self) -> None: self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains 1 source code block + # file contains 1 source code block # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) self.assertEqual(message.question, Question(f"""What is this? ``` -{self.source_file2_content} +{self.source_file3_content} ```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. +``` +"""))