diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ad68cba..1a04b94 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,15 +6,26 @@ import yaml import sys import argcomplete import argparse -from .utils import terminal_width, pp, process_tags, display_chat -from .storage import save_answers, create_chat, get_tags +import pathlib +from .utils import terminal_width, process_tags, display_chat, display_source_code +from .storage import save_answers, create_chat, get_tags, read_file, dump_data from .api_client import ai, openai_api_key def run_print_command(args: argparse.Namespace, config: dict) -> None: - with open(args.print, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - pp(data) + fname = pathlib.Path(args.print) + if fname.suffix == '.yaml': + with open(args.print, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + elif fname.suffix == '.txt': + data = read_file(fname) + else: + print(f"Unknown file type: {args.print}") + sys.exit(1) + if args.only_source_code: + display_source_code(data['answer']) + else: + print(dump_data(data).strip()) def process_and_display_chat(args: argparse.Namespace, @@ -74,7 +85,7 @@ def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") group = parser.add_mutually_exclusive_group(required=True) - group.add_argument('-p', '--print', help='YAML file to print') + group.add_argument('-p', '--print', help='File to print') group.add_argument('-q', '--question', nargs='*', help='Question to ask') group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true') diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index d0d05ae..5f2af92 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -40,24 +40,28 @@ def message_to_chat(message: Dict[str, str], append_message(chat, 'assistant', message['answer']) +def display_source_code(content: str) -> None: + code_block_count = 0 + for line in content.splitlines(): + if line.strip().startswith('```'): + code_block_count += 1 + elif code_block_count == 1: + print(line) + + def display_chat(chat, dump=False, source_code=False) -> None: if dump: pp(chat) return for message in chat: - if message['role'] == 'user' and not source_code: - print('-' * (terminal_width())) - if len(message['content']) > terminal_width() - len(message['role']) - 2: - if not source_code: - print(f"{message['role'].upper()}:") - if source_code: - out = 0 - for line in message['content'].splitlines(): - if line.strip().startswith('```'): - out += 1 - elif out == 1: - print(f"{line}") - else: - print(message['content']) - elif not source_code: + text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 + if source_code: + display_source_code(message['content']) + continue + if message['role'] == 'user': + print('-' * terminal_width()) + if text_too_long: + print(f"{message['role'].upper()}:") + print(message['content']) + else: print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py index 9fe4a6b..eca160f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -201,9 +201,9 @@ class TestCreateParser(unittest.TestCase): parser = create_parser() self.assertIsInstance(parser, argparse.ArgumentParser) mock_add_mutually_exclusive_group.assert_called_once_with(required=True) - mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print') + mock_group.add_argument.assert_any_call('-p', '--print', help='File to print') mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true') self.assertTrue('.config.yaml' in parser.get_default('config')) - self.assertEqual(parser.get_default('number'), 3) + self.assertEqual(parser.get_default('number'), 1)