Compare commits
3 Commits
beb612e642
...
4eddc55197
| Author | SHA1 | Date | |
|---|---|---|---|
| 4eddc55197 | |||
| 2ac873fd95 | |||
| dc8b225a91 |
@ -1,6 +1,6 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Type, TypeVar, Any, Optional, Final
|
from typing import Type, TypeVar, Any, Optional, ClassVar
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||||
@ -22,8 +22,19 @@ class AIConfig:
|
|||||||
"""
|
"""
|
||||||
The base class of all AI configurations.
|
The base class of all AI configurations.
|
||||||
"""
|
"""
|
||||||
|
# the name of the AI the config class represents
|
||||||
|
# -> it's a class variable and thus not part of the
|
||||||
|
# dataclass constructor
|
||||||
|
name: ClassVar[str]
|
||||||
|
# a user-defined ID for an AI configuration entry
|
||||||
ID: str
|
ID: str
|
||||||
name: str
|
|
||||||
|
# the name must not be changed
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
if name == 'name':
|
||||||
|
raise AttributeError("'{name}' is not allowed to be changed")
|
||||||
|
else:
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -31,8 +42,8 @@ class OpenAIConfig(AIConfig):
|
|||||||
"""
|
"""
|
||||||
The OpenAI section of the configuration file.
|
The OpenAI section of the configuration file.
|
||||||
"""
|
"""
|
||||||
# the name must not be changed
|
name: ClassVar[str] = 'openai'
|
||||||
name: Final[str] = 'openai'
|
|
||||||
# all members have default values, so we can easily create
|
# all members have default values, so we can easily create
|
||||||
# a default configuration
|
# a default configuration
|
||||||
ID: str = 'default'
|
ID: str = 'default'
|
||||||
@ -65,9 +76,6 @@ class OpenAIConfig(AIConfig):
|
|||||||
res.ID = source['ID']
|
res.ID = source['ID']
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
|
||||||
return asdict(self)
|
|
||||||
|
|
||||||
|
|
||||||
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
||||||
"""
|
"""
|
||||||
@ -139,4 +147,7 @@ class Config:
|
|||||||
yaml.dump(data, f, sort_keys=False)
|
yaml.dump(data, f, sort_keys=False)
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
return asdict(self)
|
res = asdict(self)
|
||||||
|
for ID, conf in res['ais'].items():
|
||||||
|
conf.update({'name': self.ais[ID].name})
|
||||||
|
return res
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import yaml
|
|||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
|
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
|
||||||
|
|
||||||
|
|
||||||
class TestAIConfigInstance(unittest.TestCase):
|
class TestAIConfigInstance(unittest.TestCase):
|
||||||
@ -86,3 +86,54 @@ class TestConfig(unittest.TestCase):
|
|||||||
default_config = yaml.load(f, Loader=yaml.FullLoader)
|
default_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
config_reference = Config()
|
config_reference = Config()
|
||||||
self.assertEqual(default_config['db'], config_reference.db)
|
self.assertEqual(default_config['db'], config_reference.db)
|
||||||
|
|
||||||
|
def test_from_file_should_load_config_from_file(self) -> None:
|
||||||
|
source_dict = {
|
||||||
|
'db': './test_db/',
|
||||||
|
'ais': {
|
||||||
|
'default': {
|
||||||
|
'name': 'openai',
|
||||||
|
'system': 'Custom system',
|
||||||
|
'api_key': '9876543210',
|
||||||
|
'model': 'custom_model',
|
||||||
|
'max_tokens': 5000,
|
||||||
|
'temperature': 0.5,
|
||||||
|
'top_p': 0.8,
|
||||||
|
'frequency_penalty': 0.7,
|
||||||
|
'presence_penalty': 0.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with open(self.test_file.name, 'w') as f:
|
||||||
|
yaml.dump(source_dict, f)
|
||||||
|
|
||||||
|
config = Config.from_file(self.test_file.name)
|
||||||
|
self.assertIsInstance(config, Config)
|
||||||
|
self.assertEqual(config.db, './test_db/')
|
||||||
|
self.assertEqual(len(config.ais), 1)
|
||||||
|
self.assertIsInstance(config.ais['default'], AIConfig)
|
||||||
|
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
|
||||||
|
|
||||||
|
def test_to_file_should_save_config_to_file(self) -> None:
|
||||||
|
config = Config(
|
||||||
|
db='./test_db/',
|
||||||
|
ais={
|
||||||
|
'default': OpenAIConfig(
|
||||||
|
ID='default',
|
||||||
|
system='Custom system',
|
||||||
|
api_key='9876543210',
|
||||||
|
model='custom_model',
|
||||||
|
max_tokens=5000,
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.8,
|
||||||
|
frequency_penalty=0.7,
|
||||||
|
presence_penalty=0.2
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
config.to_file(Path(self.test_file.name))
|
||||||
|
with open(self.test_file.name, 'r') as f:
|
||||||
|
saved_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
self.assertEqual(saved_config['db'], './test_db/')
|
||||||
|
self.assertEqual(len(saved_config['ais']), 1)
|
||||||
|
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user