Compare commits
3 Commits
4eddc55197
...
beb612e642
| Author | SHA1 | Date | |
|---|---|---|---|
| beb612e642 | |||
| 0f6a31940f | |||
| 50d41d573f |
@ -1,6 +1,6 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Type, TypeVar, Any, Optional, ClassVar
|
from typing import Type, TypeVar, Any, Optional, Final
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||||
@ -22,19 +22,8 @@ class AIConfig:
|
|||||||
"""
|
"""
|
||||||
The base class of all AI configurations.
|
The base class of all AI configurations.
|
||||||
"""
|
"""
|
||||||
# the name of the AI the config class represents
|
|
||||||
# -> it's a class variable and thus not part of the
|
|
||||||
# dataclass constructor
|
|
||||||
name: ClassVar[str]
|
|
||||||
# a user-defined ID for an AI configuration entry
|
|
||||||
ID: str
|
ID: str
|
||||||
|
name: str
|
||||||
# the name must not be changed
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
|
||||||
if name == 'name':
|
|
||||||
raise AttributeError("'{name}' is not allowed to be changed")
|
|
||||||
else:
|
|
||||||
super().__setattr__(name, value)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -42,8 +31,8 @@ class OpenAIConfig(AIConfig):
|
|||||||
"""
|
"""
|
||||||
The OpenAI section of the configuration file.
|
The OpenAI section of the configuration file.
|
||||||
"""
|
"""
|
||||||
name: ClassVar[str] = 'openai'
|
# the name must not be changed
|
||||||
|
name: Final[str] = 'openai'
|
||||||
# all members have default values, so we can easily create
|
# all members have default values, so we can easily create
|
||||||
# a default configuration
|
# a default configuration
|
||||||
ID: str = 'default'
|
ID: str = 'default'
|
||||||
@ -76,6 +65,9 @@ class OpenAIConfig(AIConfig):
|
|||||||
res.ID = source['ID']
|
res.ID = source['ID']
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def as_dict(self) -> dict[str, Any]:
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
||||||
"""
|
"""
|
||||||
@ -147,7 +139,4 @@ class Config:
|
|||||||
yaml.dump(data, f, sort_keys=False)
|
yaml.dump(data, f, sort_keys=False)
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
res = asdict(self)
|
return asdict(self)
|
||||||
for ID, conf in res['ais'].items():
|
|
||||||
conf.update({'name': self.ais[ID].name})
|
|
||||||
return res
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import yaml
|
|||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
|
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
|
||||||
|
|
||||||
|
|
||||||
class TestAIConfigInstance(unittest.TestCase):
|
class TestAIConfigInstance(unittest.TestCase):
|
||||||
@ -86,54 +86,3 @@ class TestConfig(unittest.TestCase):
|
|||||||
default_config = yaml.load(f, Loader=yaml.FullLoader)
|
default_config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
config_reference = Config()
|
config_reference = Config()
|
||||||
self.assertEqual(default_config['db'], config_reference.db)
|
self.assertEqual(default_config['db'], config_reference.db)
|
||||||
|
|
||||||
def test_from_file_should_load_config_from_file(self) -> None:
|
|
||||||
source_dict = {
|
|
||||||
'db': './test_db/',
|
|
||||||
'ais': {
|
|
||||||
'default': {
|
|
||||||
'name': 'openai',
|
|
||||||
'system': 'Custom system',
|
|
||||||
'api_key': '9876543210',
|
|
||||||
'model': 'custom_model',
|
|
||||||
'max_tokens': 5000,
|
|
||||||
'temperature': 0.5,
|
|
||||||
'top_p': 0.8,
|
|
||||||
'frequency_penalty': 0.7,
|
|
||||||
'presence_penalty': 0.2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
with open(self.test_file.name, 'w') as f:
|
|
||||||
yaml.dump(source_dict, f)
|
|
||||||
|
|
||||||
config = Config.from_file(self.test_file.name)
|
|
||||||
self.assertIsInstance(config, Config)
|
|
||||||
self.assertEqual(config.db, './test_db/')
|
|
||||||
self.assertEqual(len(config.ais), 1)
|
|
||||||
self.assertIsInstance(config.ais['default'], AIConfig)
|
|
||||||
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
|
|
||||||
|
|
||||||
def test_to_file_should_save_config_to_file(self) -> None:
|
|
||||||
config = Config(
|
|
||||||
db='./test_db/',
|
|
||||||
ais={
|
|
||||||
'default': OpenAIConfig(
|
|
||||||
ID='default',
|
|
||||||
system='Custom system',
|
|
||||||
api_key='9876543210',
|
|
||||||
model='custom_model',
|
|
||||||
max_tokens=5000,
|
|
||||||
temperature=0.5,
|
|
||||||
top_p=0.8,
|
|
||||||
frequency_penalty=0.7,
|
|
||||||
presence_penalty=0.2
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
config.to_file(Path(self.test_file.name))
|
|
||||||
with open(self.test_file.name, 'r') as f:
|
|
||||||
saved_config = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
self.assertEqual(saved_config['db'], './test_db/')
|
|
||||||
self.assertEqual(len(saved_config['ais']), 1)
|
|
||||||
self.assertEqual(saved_config['ais']['default']['system'], 'Custom system')
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user