From 83448fa0961b2b1ec23a828a0de75247fb4d32e2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 10:40:22 +0200 Subject: [PATCH] configuration: added tests --- chatmastermind/configuration.py | 2 +- tests/test_configuration.py | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_configuration.py diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d82f913..398fa03 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -87,7 +87,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> else: return OpenAIConfig.from_dict(conf_dict) else: - raise ConfigError(f"AI '{name}' is not supported") + raise ConfigError(f"Unknown AI '{name}'") def create_default_ai_configs() -> dict[str, AIConfig]: diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..f3f9a98 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['default'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['default'].ID, 'default') + + def test_create_default_should_create_default_config(self) -> None: + Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'default': OpenAIConfig( + ID='default', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, + presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name)