diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 1709a3c..818b1de 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -15,19 +15,21 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: question_parts = [] question_list = args.question if args.question is not None else [] source_list = args.source if args.source is not None else [] + code_list = args.source_code if args.source_code is not None else [] - # FIXME: don't surround all sourced files with ``` - # -> do it only if '--source-code-only' is True and no source code - # could be extracted from that file - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: + for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + if question is not None and len(question.strip()) > 0: question_parts.append(question) - elif source is not None: + if source is not None and len(source) > 0: with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 88121b4..f7163ab 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,9 +67,8 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 96b2fdf..06cc527 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -23,7 +23,7 @@ class TestMessageCreate(unittest.TestCase): # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source = None - self.args.source_code_only = False + self.args.source_code = None self.args.ai = None self.args.model = None self.args.output_tags = None @@ -94,11 +94,11 @@ Is it good?""")) # source file contains no source code # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_embedded_source_source(self) -> None: self.args.question = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) @@ -106,6 +106,20 @@ Is it good?""")) # source file contains 1 source code block # -> expect it in the question self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file2_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question(f"""What is this? + +``` +{self.source_file2_content} +```"""))