"""Unit tests for ``chatmastermind.configuration``."""

import os
import unittest
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast

import yaml

from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config


class TestAIConfigInstance(unittest.TestCase):
    """Tests for the ``ai_config_instance`` factory function."""

    def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
        # Without a configuration dict, every field must equal the value of a
        # freshly constructed OpenAIConfig.
        ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
        ai_reference = OpenAIConfig()
        self.assertEqual(ai_config.ID, ai_reference.ID)
        self.assertEqual(ai_config.name, ai_reference.name)
        self.assertEqual(ai_config.api_key, ai_reference.api_key)
        self.assertEqual(ai_config.system, ai_reference.system)
        self.assertEqual(ai_config.model, ai_reference.model)
        self.assertEqual(ai_config.temperature, ai_reference.temperature)
        self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
        self.assertEqual(ai_config.top_p, ai_reference.top_p)
        self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
        self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)

    def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
        # Values passed in the configuration dict must override the defaults.
        conf_dict = {
            'system': 'Custom system',
            'api_key': '9876543210',
            'model': 'custom_model',
            'max_tokens': 5000,
            'temperature': 0.5,
            'top_p': 0.8,
            'frequency_penalty': 0.7,
            'presence_penalty': 0.2
        }
        ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
        self.assertEqual(ai_config.system, 'Custom system')
        self.assertEqual(ai_config.api_key, '9876543210')
        self.assertEqual(ai_config.model, 'custom_model')
        self.assertEqual(ai_config.max_tokens, 5000)
        self.assertAlmostEqual(ai_config.temperature, 0.5)
        self.assertAlmostEqual(ai_config.top_p, 0.8)
        self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
        self.assertAlmostEqual(ai_config.presence_penalty, 0.2)

    def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
        with self.assertRaises(ConfigError):
            ai_config_instance('invalid_name')


class TestConfig(unittest.TestCase):
    """Tests for creating, loading and saving ``Config`` objects."""

    def setUp(self) -> None:
        """Create an empty temporary file whose name the tests read/write."""
        self.test_file = NamedTemporaryFile(delete=False)
        # Only the file *name* is used; every test reopens the file itself.
        # Close the handle immediately so it is not leaked (ResourceWarning)
        # and so os.remove() in tearDown also works on Windows, where an
        # open file cannot be deleted.
        self.test_file.close()

    def tearDown(self) -> None:
        """Delete the temporary file created in setUp."""
        os.remove(self.test_file.name)

    def test_from_dict_should_create_config_from_dict(self) -> None:
        source_dict = {
            'cache': '.',
            'db': './test_db/',
            'ais': {
                'myopenai': {
                    'name': 'openai',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            },
            'glossaries': './glossaries/'
        }
        config = Config.from_dict(source_dict)
        self.assertEqual(config.cache, '.')
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(config.glossaries, './glossaries/')
        self.assertEqual(len(config.ais), 1)
        self.assertEqual(config.ais['myopenai'].name, 'openai')
        self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
        # check that 'ID' has been added
        self.assertEqual(config.ais['myopenai'].ID, 'myopenai')

    def test_create_default_should_create_default_config(self) -> None:
        Config.create_default(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            default_config = yaml.load(f, Loader=yaml.FullLoader)
        config_reference = Config()
        self.assertEqual(default_config['db'], config_reference.db)

    def test_from_file_should_load_config_from_file(self) -> None:
        source_dict = {
            'cache': './test_cache/',
            'db': './test_db/',
            'ais': {
                'default': {
                    'name': 'openai',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
                # omit glossaries, since it's optional
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        config = Config.from_file(self.test_file.name)
        self.assertIsInstance(config, Config)
        self.assertEqual(config.cache, './test_cache/')
        self.assertEqual(config.db, './test_db/')
        # missing 'glossaries' should result in 'None'
        self.assertIsNone(config.glossaries)
        self.assertEqual(len(config.ais), 1)
        self.assertIsInstance(config.ais['default'], AIConfig)
        self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')

    def test_to_file_should_save_config_to_file(self) -> None:
        config = Config(
            cache='./test_cache/',
            db='./test_db/',
            ais={
                'myopenai': OpenAIConfig(
                    ID='myopenai',
                    system='Custom system',
                    api_key='9876543210',
                    model='custom_model',
                    max_tokens=5000,
                    temperature=0.5,
                    top_p=0.8,
                    frequency_penalty=0.7,
                    presence_penalty=0.2
                )
            }
        )
        config.to_file(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            saved_config = yaml.load(f, Loader=yaml.FullLoader)
        self.assertEqual(saved_config['cache'], './test_cache/')
        self.assertEqual(saved_config['db'], './test_db/')
        self.assertEqual(len(saved_config['ais']), 1)
        self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')

    def test_from_file_error_unknown_ai(self) -> None:
        # An AI entry whose 'name' is not a known AI type must be rejected.
        source_dict = {
            'cache': './test_cache/',
            'db': './test_db/',
            'ais': {
                'default': {
                    'name': 'foobla',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        with self.assertRaises(ConfigError):
            Config.from_file(self.test_file.name)


if __name__ == '__main__':
    unittest.main()