configuration: added tests

This commit is contained in:
juk0de 2023-09-08 10:40:22 +02:00
parent d7055163f9
commit fafad145e5

View File

@ -0,0 +1,88 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
ai_reference = OpenAIConfig()
self.assertEqual(ai_config.ID, ai_reference.ID)
self.assertEqual(ai_config.name, ai_reference.name)
self.assertEqual(ai_config.api_key, ai_reference.api_key)
self.assertEqual(ai_config.system, ai_reference.system)
self.assertEqual(ai_config.model, ai_reference.model)
self.assertEqual(ai_config.temperature, ai_reference.temperature)
self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
self.assertEqual(ai_config.top_p, ai_reference.top_p)
self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)
def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
conf_dict = {
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
self.assertEqual(ai_config.system, 'Custom system')
self.assertEqual(ai_config.api_key, '9876543210')
self.assertEqual(ai_config.model, 'custom_model')
self.assertEqual(ai_config.max_tokens, 5000)
self.assertAlmostEqual(ai_config.temperature, 0.5)
self.assertAlmostEqual(ai_config.top_p, 0.8)
self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
self.assertAlmostEqual(ai_config.presence_penalty, 0.2)
def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
with self.assertRaises(ConfigError):
ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.test_file = NamedTemporaryFile(delete=False)
def tearDown(self) -> None:
os.remove(self.test_file.name)
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['default'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['default'].ID, 'default')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)