ai_factory: added optional 'def_ai' and 'def_model' arguments to 'create_ai'

This commit is contained in:
juk0de 2023-09-22 07:41:15 +02:00
parent 6b4ce8448f
commit 6d085a6c80

View File

@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration.
""" """
import argparse import argparse
from typing import cast from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments and configuration file.
and configuration file. If AI has not been set in the If AI has not been set in the arguments, it searches for the ID 'default'. If
arguments, it searches for the ID 'default'. If that that is not found, it uses the first AI in the list. It's also possible to
is not found, it uses the first AI in the list. specify a default AI and model using 'def_ai' and 'def_model'.
""" """
ai_conf: AIConfig ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI: if hasattr(args, 'AI') and args.AI:
@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai_conf = config.ais[args.AI] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais: elif 'default' in config.ais:
ai_conf = config.ais['default'] ai_conf = config.ais['default']
else: else:
@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model: if hasattr(args, 'model') and args.model:
ai.config.model = args.model ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens: if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature: if hasattr(args, 'temperature') and args.temperature: