ai_factory: added optional 'def_ai' and 'def_model' arguments to 'create_ai'

This commit is contained in:
juk0de 2023-09-22 07:41:15 +02:00
parent 6b4ce8448f
commit 6d085a6c80

View File

@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast
from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
"""
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
Creates an AI subclass instance from the given arguments and configuration file.
If AI has not been set in the arguments, it searches for the ID 'default'. If
that is not found, it uses the first AI in the list. It's also possible to
specify a default AI and model using 'def_ai' and 'def_model'.
"""
ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI:
@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model:
ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature: