configuration: made 'default' AI ID optional

This commit is contained in:
juk0de 2023-09-11 07:38:49 +02:00
parent c143c001f9
commit d4021eeb11
4 changed files with 22 additions and 17 deletions

View File

@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration.
import argparse import argparse
from typing import cast from typing import cast
from .configuration import Config, OpenAIConfig, default_ai_ID from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments
and configuration file. and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
""" """
ai_conf: AIConfig
if args.AI: if args.AI:
try: try:
ai_conf = config.ais[args.AI] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif default_ai_ID in config.ais: elif 'default' in config.ais:
ai_conf = config.ais[default_ai_ID] ai_conf = config.ais['default']
else: else:
raise AIError("No AI name given and no default exists") try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))

View File

@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai'] supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml' default_config_path = '.config.yaml'
@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig):
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'default' ID: str = 'myopenai'
api_key: str = '0123456789' api_key: str = '0123456789'
model: str = 'gpt-3.5-turbo-16k' model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0 temperature: float = 1.0

View File

@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase): class TestCreateAI(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.AI = 'default' self.args.AI = 'myopenai'
self.args.model = None self.args.model = None
self.args.max_tokens = None self.args.max_tokens = None
self.args.temperature = None self.args.temperature = None
@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase):
def test_create_ai_from_args(self) -> None: def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration # Create an AI with the default configuration
config = Config() config = Config()
self.args.AI = 'default' self.args.AI = 'myopenai'
ai = create_ai(self.args, config) ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI) self.assertIsInstance(ai, OpenAI)

View File

@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase):
source_dict = { source_dict = {
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'myopenai': {
'name': 'openai', 'name': 'openai',
'system': 'Custom system', 'system': 'Custom system',
'api_key': '9876543210', 'api_key': '9876543210',
@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase):
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added # check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default') self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None: def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name)) Config.create_default(Path(self.test_file.name))
@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase):
config = Config( config = Config(
db='./test_db/', db='./test_db/',
ais={ ais={
'default': OpenAIConfig( 'myopenai': OpenAIConfig(
ID='default', ID='myopenai',
system='Custom system', system='Custom system',
api_key='9876543210', api_key='9876543210',
model='custom_model', model='custom_model',
@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase):
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {