Compare commits

..

6 Commits

2 changed files with 19 additions and 13 deletions

View File

@ -2,16 +2,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import yaml
import sys import sys
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist from .storage import save_answers, create_chat_hist, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config from .configuration import Config
from .chat import ChatDB from .chat import ChatDB
from .message import Message, MessageFilter, MessageError from .message import Message, MessageFilter
from itertools import zip_longest from itertools import zip_longest
from typing import Any from typing import Any
@ -127,13 +128,18 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
Handler for the 'print' command. Handler for the 'print' command.
""" """
fname = Path(args.file) fname = Path(args.file)
try: if fname.suffix == '.yaml':
message = Message.from_file(fname) with open(args.file, 'r') as f:
if message: data = yaml.load(f, Loader=yaml.FullLoader)
print(message.to_str(source_code_only=args.source_code_only)) elif fname.suffix == '.txt':
except MessageError: data = read_file(fname)
print(f"File is not a valid message: {args.file}") else:
print(f"Unknown file type: {args.file}")
sys.exit(1) sys.exit(1)
if args.source_code_only:
display_source_code(data['answer'])
else:
print(dump_data(data).strip())
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
@ -217,11 +223,11 @@ def create_parser() -> argparse.ArgumentParser:
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.", help="Print files.",
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true') action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)

View File

@ -398,10 +398,10 @@ class Message():
""" """
output: list[str] = [] output: list[str] = []
if source_code_only: if source_code_only:
# use the source code from answer only output.extend(self.question.source_code(include_delims=True))
if self.answer: if self.answer:
output.extend(self.answer.source_code(include_delims=True)) output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 1 else '' return '\n'.join(output)
if with_tags: if with_tags:
output.append(self.tags_str()) output.append(self.tags_str())
if with_file: if with_file: