Compare commits

..

3 Commits

2 changed files with 71 additions and 9 deletions

View File

@ -1,6 +1,6 @@
import yaml import yaml
from pathlib import Path from pathlib import Path
from typing import Type, TypeVar, Any, Optional, Final from typing import Type, TypeVar, Any, Optional, ClassVar
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
@ -22,8 +22,19 @@ class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
# the name of the AI the config class represents
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str ID: str
name: str
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name':
raise AttributeError("'{name}' is not allowed to be changed")
else:
super().__setattr__(name, value)
@dataclass @dataclass
@ -31,8 +42,8 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
# the name must not be changed name: ClassVar[str] = 'openai'
name: Final[str] = 'openai'
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
ID: str = 'default' ID: str = 'default'
@ -65,9 +76,6 @@ class OpenAIConfig(AIConfig):
res.ID = source['ID'] res.ID = source['ID']
return res return res
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
""" """
@ -139,4 +147,7 @@ class Config:
yaml.dump(data, f, sort_keys=False) yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) res = asdict(self)
for ID, conf in res['ais'].items():
conf.update({'name': self.ais[ID].name})
return res

View File

@ -4,7 +4,7 @@ import yaml
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase): class TestAIConfigInstance(unittest.TestCase):
@ -86,3 +86,54 @@ class TestConfig(unittest.TestCase):
default_config = yaml.load(f, Loader=yaml.FullLoader) default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config() config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db) self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
db='./test_db/',
ais={
'default': OpenAIConfig(
ID='default',
system='Custom system',
api_key='9876543210',
model='custom_model',
max_tokens=5000,
temperature=0.5,
top_p=0.8,
frequency_penalty=0.7,
presence_penalty=0.2
)
}
)
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')