173 lines
7.0 KiB
Python
173 lines
7.0 KiB
Python
import os
|
|
import unittest
|
|
import yaml
|
|
from tempfile import NamedTemporaryFile
|
|
from pathlib import Path
|
|
from typing import cast
|
|
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
|
|
|
|
|
|
class TestAIConfigInstance(unittest.TestCase):
    """Tests for the ``ai_config_instance`` factory."""

    def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
        """Without a configuration dict, every checked field matches a fresh ``OpenAIConfig()``."""
        created = cast(OpenAIConfig, ai_config_instance('openai'))
        defaults = OpenAIConfig()
        checked_fields = (
            'ID',
            'name',
            'api_key',
            'system',
            'model',
            'temperature',
            'max_tokens',
            'top_p',
            'frequency_penalty',
            'presence_penalty',
        )
        for field_name in checked_fields:
            self.assertEqual(getattr(created, field_name), getattr(defaults, field_name))

    def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
        """Values passed in the configuration dict end up on the created instance."""
        overrides = {
            'system': 'Custom system',
            'api_key': '9876543210',
            'model': 'custom_model',
            'max_tokens': 5000,
            'temperature': 0.5,
            'top_p': 0.8,
            'frequency_penalty': 0.7,
            'presence_penalty': 0.2,
        }
        # Snapshot the expectations before the call, so the comparison does not
        # depend on whether the factory modifies the dict it receives.
        expected = list(overrides.items())
        created = cast(OpenAIConfig, ai_config_instance('openai', overrides))
        for key, wanted in expected:
            # Floats are compared approximately; strings and ints exactly.
            if isinstance(wanted, float):
                self.assertAlmostEqual(getattr(created, key), wanted)
            else:
                self.assertEqual(getattr(created, key), wanted)

    def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
        """An unknown AI name is rejected with ``ConfigError``."""
        self.assertRaises(ConfigError, ai_config_instance, 'invalid_name')
|
|
|
|
|
|
class TestConfig(unittest.TestCase):
|
|
def setUp(self) -> None:
|
|
self.test_file = NamedTemporaryFile(delete=False)
|
|
|
|
def tearDown(self) -> None:
|
|
os.remove(self.test_file.name)
|
|
|
|
def test_from_dict_should_create_config_from_dict(self) -> None:
|
|
source_dict = {
|
|
'cache': '.',
|
|
'db': './test_db/',
|
|
'ais': {
|
|
'myopenai': {
|
|
'name': 'openai',
|
|
'system': 'Custom system',
|
|
'api_key': '9876543210',
|
|
'model': 'custom_model',
|
|
'max_tokens': 5000,
|
|
'temperature': 0.5,
|
|
'top_p': 0.8,
|
|
'frequency_penalty': 0.7,
|
|
'presence_penalty': 0.2
|
|
}
|
|
},
|
|
'glossaries': './glossaries/'
|
|
}
|
|
config = Config.from_dict(source_dict)
|
|
self.assertEqual(config.cache, '.')
|
|
self.assertEqual(config.db, './test_db/')
|
|
self.assertEqual(config.glossaries, './glossaries/')
|
|
self.assertEqual(len(config.ais), 1)
|
|
self.assertEqual(config.ais['myopenai'].name, 'openai')
|
|
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
|
|
# check that 'ID' has been added
|
|
self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
|
|
|
|
def test_create_default_should_create_default_config(self) -> None:
|
|
Config.create_default(Path(self.test_file.name))
|
|
with open(self.test_file.name, 'r') as f:
|
|
default_config = yaml.load(f, Loader=yaml.FullLoader)
|
|
config_reference = Config()
|
|
self.assertEqual(default_config['db'], config_reference.db)
|
|
|
|
def test_from_file_should_load_config_from_file(self) -> None:
|
|
source_dict = {
|
|
'cache': './test_cache/',
|
|
'db': './test_db/',
|
|
'ais': {
|
|
'default': {
|
|
'name': 'openai',
|
|
'system': 'Custom system',
|
|
'api_key': '9876543210',
|
|
'model': 'custom_model',
|
|
'max_tokens': 5000,
|
|
'temperature': 0.5,
|
|
'top_p': 0.8,
|
|
'frequency_penalty': 0.7,
|
|
'presence_penalty': 0.2
|
|
}
|
|
# omit glossaries, since it's optional
|
|
}
|
|
}
|
|
with open(self.test_file.name, 'w') as f:
|
|
yaml.dump(source_dict, f)
|
|
config = Config.from_file(self.test_file.name)
|
|
self.assertIsInstance(config, Config)
|
|
self.assertEqual(config.cache, './test_cache/')
|
|
self.assertEqual(config.db, './test_db/')
|
|
# missing 'glossaries' should result in 'None'
|
|
self.assertEqual(config.glossaries, None)
|
|
self.assertEqual(len(config.ais), 1)
|
|
self.assertIsInstance(config.ais['default'], AIConfig)
|
|
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
|
|
|
|
def test_to_file_should_save_config_to_file(self) -> None:
|
|
config = Config(
|
|
cache='./test_cache/',
|
|
db='./test_db/',
|
|
ais={
|
|
'myopenai': OpenAIConfig(
|
|
ID='myopenai',
|
|
system='Custom system',
|
|
api_key='9876543210',
|
|
model='custom_model',
|
|
max_tokens=5000,
|
|
temperature=0.5,
|
|
top_p=0.8,
|
|
frequency_penalty=0.7,
|
|
presence_penalty=0.2
|
|
)
|
|
}
|
|
)
|
|
config.to_file(Path(self.test_file.name))
|
|
with open(self.test_file.name, 'r') as f:
|
|
saved_config = yaml.load(f, Loader=yaml.FullLoader)
|
|
self.assertEqual(saved_config['cache'], './test_cache/')
|
|
self.assertEqual(saved_config['db'], './test_db/')
|
|
self.assertEqual(len(saved_config['ais']), 1)
|
|
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
|
|
|
|
def test_from_file_error_unknown_ai(self) -> None:
|
|
source_dict = {
|
|
'cache': './test_cache/',
|
|
'db': './test_db/',
|
|
'ais': {
|
|
'default': {
|
|
'name': 'foobla',
|
|
'system': 'Custom system',
|
|
'api_key': '9876543210',
|
|
'model': 'custom_model',
|
|
'max_tokens': 5000,
|
|
'temperature': 0.5,
|
|
'top_p': 0.8,
|
|
'frequency_penalty': 0.7,
|
|
'presence_penalty': 0.2
|
|
}
|
|
}
|
|
}
|
|
with open(self.test_file.name, 'w') as f:
|
|
yaml.dump(source_dict, f)
|
|
with self.assertRaises(ConfigError):
|
|
Config.from_file(self.test_file.name)
|