From 6d085a6c808d563f6623936e1e084e400ee97f16 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 22 Sep 2023 07:41:15 +0200 Subject: [PATCH] ai_factory: added optional 'def_ai' and 'def_model' arguments to 'create_ai' --- chatmastermind/ai_factory.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index 36a987b..42b27c1 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration. """ import argparse -from typing import cast +from typing import cast, Optional from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 +def create_ai(args: argparse.Namespace, config: Config, # noqa: 11 + def_ai: Optional[str] = None, + def_model: Optional[str] = None) -> AI: """ - Creates an AI subclass instance from the given arguments - and configuration file. If AI has not been set in the - arguments, it searches for the ID 'default'. If that - is not found, it uses the first AI in the list. + Creates an AI subclass instance from the given arguments and configuration file. + If AI has not been set in the arguments, it searches for the ID 'default'. If + that is not found, it uses the first AI in the list. It's also possible to + specify a default AI and model using 'def_ai' and 'def_model'. """ ai_conf: AIConfig if hasattr(args, 'AI') and args.AI: @@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") + elif def_ai: + ai_conf = config.ais[def_ai] elif 'default' in config.ais: ai_conf = config.ais['default'] else: @@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 ai = OpenAI(cast(OpenAIConfig, ai_conf)) if hasattr(args, 'model') and args.model: ai.config.model = args.model + elif def_model: + ai.config.model = def_model if hasattr(args, 'max_tokens') and args.max_tokens: ai.config.max_tokens = args.max_tokens if hasattr(args, 'temperature') and args.temperature: