Compare commits

..

1 Commits

3 changed files with 21 additions and 65 deletions

View File

@ -3,20 +3,18 @@ Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast, Optional
from typing import cast
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
"""
Creates an AI subclass instance from the given arguments and configuration file.
If AI has not been set in the arguments, it searches for the ID 'default'. If
that is not found, it uses the first AI in the list. It's also possible to
specify a default AI and model using 'def_ai' and 'def_model'.
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI:
@ -24,8 +22,6 @@ def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
@ -38,8 +34,6 @@ def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature:

View File

@ -2,7 +2,6 @@ import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
@ -106,37 +105,21 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(response.tokens)
def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None # noqa: W503
or len(args.output_tags) == 0): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if (args.output_tags is None or len(args.output_tags) == 0) and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
ai = create_ai(make_msg_args(msg, args), config)
ai_args = args
# if AI or model have not been specified, use those from the original message
if args.AI is None or args.model is None:
ai_args = args.copy()
if args.AI is None and msg.ai is not None:
ai_args.AI = msg.ai
if args.model is None and msg.model is not None:
ai_args.model = msg.model
ai = create_ai(ai_args, config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
@ -151,34 +134,13 @@ def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namesp
make_request(ai, chat, message, args)
def modify_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the tags for this command:
* not tags specified -> no tags are selected
* empty tags specified -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
if args.exclude_tags is None:
args.exclude_tags = set()
elif len(args.exclude_tags) == 0:
args.exclude_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
modify_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter)

View File

@ -34,13 +34,13 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',