diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index bc4583c..420b287 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration. import argparse from typing import cast -from .configuration import Config, OpenAIConfig, default_ai_ID +from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 """ Creates an AI subclass instance from the given arguments - and configuration file. + and configuration file. If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. """ + ai_conf: AIConfig if args.AI: try: ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") - elif default_ai_ID in config.ais: - ai_conf = config.ais[default_ai_ID] + elif 'default' in config.ais: + ai_conf = config.ais['default'] else: - raise AIError("No AI name given and no default exists") + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 08f6cbe..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') supported_ais: list[str] = ['openai'] -default_ai_ID: str = 'default' default_config_path = '.config.yaml' @@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig): # all members have default values, so we can easily create # a default configuration - ID: str = 'default' + ID: str = 'myopenai' api_key: str = '0123456789' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d00b319..9cb94d3 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.AI = 'default' + self.args.AI = 'myopenai' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.AI = 'default' + self.args.AI = 'myopenai' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f3f9a98..ba8a5aa 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase): source_dict = { 'db': './test_db/', 'ais': { - 'default': { + 'myopenai': { 'name': 'openai', 'system': 'Custom system', 'api_key': '9876543210', @@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase): config = Config.from_dict(source_dict) self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) - self.assertEqual(config.ais['default'].name, 'openai') - self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') # check that 'ID' has been added - self.assertEqual(config.ais['default'].ID, 'default') + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') def test_create_default_should_create_default_config(self) -> None: Config.create_default(Path(self.test_file.name)) @@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase): config = Config( db='./test_db/', ais={ - 'default': OpenAIConfig( - ID='default', + 'myopenai': OpenAIConfig( + ID='myopenai', system='Custom system', api_key='9876543210', model='custom_model', @@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase): saved_config = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) - self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = {