Compare commits

...

3 Commits

Author SHA1 Message Date
c4a7c07a0c fixed tests 2023-08-12 14:16:16 +02:00
22bebc16ed fixed min nr of expected arguments 2023-08-12 14:16:16 +02:00
f7ba0c000f renamed 'model' command to 'config' 2023-08-12 14:16:16 +02:00
2 changed files with 27 additions and 15 deletions

View File

@ -8,7 +8,7 @@ import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from itertools import zip_longest from itertools import zip_longest
@ -63,12 +63,20 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print_tags_frequency(get_tags(config, None)) print_tags_frequency(get_tags(config, None))
def model_cmd(args: argparse.Namespace, config: ConfigType) -> None: def config_cmd(args: argparse.Namespace, config: ConfigType) -> None:
""" """
Handler for the 'model' command. Handler for the 'config' command.
""" """
if args.list: if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
if args.list_models:
print_models() print_models()
elif args.show_model:
print(config['openai']['model'])
elif args.model:
config['openai']['model'] = args.model
write_config(args.config, config)
def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
@ -139,13 +147,13 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='*', tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tag names', metavar='TAGS') help='List of tag names', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='*', extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
help='List of tag names to exclude', metavar='EXTAGS') help='List of tag names to exclude', metavar='EXTAGS')
extag_arg.completer = tags_completer # type: ignore extag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='*', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OTAGS') help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
tag_parser.add_argument('-a', '--match-all-tags', tag_parser.add_argument('-a', '--match-all-tags',
@ -164,7 +172,7 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser.add_argument('-M', '--model', help='Model to use') ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1) default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
action='store_true') action='store_true')
@ -188,12 +196,16 @@ def create_parser() -> argparse.ArgumentParser:
tag_cmd_parser.add_argument('-l', '--list', help="List all tags and their frequency", tag_cmd_parser.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
# 'model' command parser # 'config' command parser
model_cmd_parser = cmdparser.add_parser('model', config_cmd_parser = cmdparser.add_parser('config',
help="Manage models.") help="Manage configuration")
model_cmd_parser.set_defaults(func=model_cmd) config_cmd_parser.set_defaults(func=config_cmd)
model_cmd_parser.add_argument('-l', '--list', help="List all available models", config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
config_group.add_argument('-L', '--list-models', help="List all available models",
action='store_true') action='store_true')
config_group.add_argument('-m', '--show-model', help="Show current model",
action='store_true')
config_group.add_argument('-M', '--model', help="Set model in the config file")
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',

View File

@ -222,6 +222,6 @@ class TestCreateParser(unittest.TestCase):
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY)
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY) mock_cmdparser.add_parser.assert_any_call('tag', help=ANY)
mock_cmdparser.add_parser.assert_any_call('model', help=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))