Compare commits

..

3 Commits

2 changed files with 9 additions and 71 deletions

View File

@ -1,6 +1,6 @@
import yaml
from pathlib import Path
from typing import Type, TypeVar, Any, Optional, ClassVar
from typing import Type, TypeVar, Any, Optional, Final
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config')
@ -22,19 +22,8 @@ class AIConfig:
"""
The base class of all AI configurations.
"""
# the name of the AI the config class represents
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name':
raise AttributeError("'{name}' is not allowed to be changed")
else:
super().__setattr__(name, value)
name: str
@dataclass
@ -42,8 +31,8 @@ class OpenAIConfig(AIConfig):
"""
The OpenAI section of the configuration file.
"""
name: ClassVar[str] = 'openai'
# the name must not be changed
name: Final[str] = 'openai'
# all members have default values, so we can easily create
# a default configuration
ID: str = 'default'
@ -76,6 +65,9 @@ class OpenAIConfig(AIConfig):
res.ID = source['ID']
return res
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
@ -147,7 +139,4 @@ class Config:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]:
res = asdict(self)
for ID, conf in res['ais'].items():
conf.update({'name': self.ais[ID].name})
return res
return asdict(self)

View File

@ -4,7 +4,7 @@ import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
@ -86,54 +86,3 @@ class TestConfig(unittest.TestCase):
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
db='./test_db/',
ais={
'default': OpenAIConfig(
ID='default',
system='Custom system',
api_key='9876543210',
model='custom_model',
max_tokens=5000,
temperature=0.5,
top_p=0.8,
frequency_penalty=0.7,
presence_penalty=0.2
)
}
)
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')