70 lines
1.9 KiB
Python
70 lines
1.9 KiB
Python
import yaml
|
|
from typing import Type, TypeVar, Any
|
|
from dataclasses import dataclass, asdict
|
|
|
|
# Type variables for the alternate constructors below: annotating ``cls`` with
# ``Type[...]`` and returning the same TypeVar lets ``from_dict`` / ``from_file``
# report the class they were actually called on (``cls(...)``), including
# subclasses. The bounds are string forward references because the classes
# are defined later in the module.
ConfigInst = TypeVar("ConfigInst", bound="Config")
OpenAIConfigInst = TypeVar("OpenAIConfigInst", bound="OpenAIConfig")
|
|
|
|
|
|
@dataclass
class OpenAIConfig:
    """
    Settings held in the ``openai`` section of the configuration file.

    The field declaration order below is also the positional argument order
    of the generated ``__init__``.
    """

    api_key: str
    model: str
    temperature: float
    max_tokens: int
    top_p: float
    frequency_penalty: float
    presence_penalty: float

    @classmethod
    def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
        """
        Build an instance from a mapping such as a parsed YAML section.

        Each required key is read from ``source`` and passed through the
        builtin that matches its field type (``str``, ``int`` or ``float``).
        Keys not listed here are ignored.

        Raises:
            KeyError: for the first required key, in the order listed below,
                that is missing from ``source``.
            ValueError, TypeError: whatever the coercing builtin raises for
                a value it cannot convert.
        """
        # (key, coercion) pairs. Lookups and conversions run in this order,
        # which fixes which key an error is reported for.
        coercions = (
            ('api_key', str),
            ('model', str),
            ('max_tokens', int),
            ('temperature', float),
            ('top_p', float),
            ('frequency_penalty', float),
            ('presence_penalty', float),
        )
        kwargs = {key: coerce(source[key]) for key, coerce in coercions}
        return cls(**kwargs)
|
|
|
|
|
|
@dataclass
class Config:
    """
    The configuration file structure.

    Mirrors the top-level YAML document: ``system`` and ``db`` strings plus
    a nested ``openai`` section.
    """

    system: str
    db: str
    openai: OpenAIConfig

    @classmethod
    def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
        """
        Create Config from a dict.

        ``system`` and ``db`` are coerced with ``str``; ``source['openai']``
        is handed to ``OpenAIConfig.from_dict``, whose errors propagate.

        Raises:
            KeyError: if ``system``, ``db`` or ``openai`` is missing.
        """
        return cls(
            system=str(source['system']),
            db=str(source['db']),
            openai=OpenAIConfig.from_dict(source['openai']),
        )

    @classmethod
    def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
        """
        Load a Config from the YAML file at ``path``.

        Raises:
            OSError: if the file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
            TypeError: if the top level of the document is not a mapping
                (an empty file parses to ``None``).
            KeyError: if a required key is missing (see ``from_dict``).
        """
        # YAML streams are UTF-8 by spec; don't depend on the locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            # A config only needs plain YAML types (mappings, strings,
            # numbers). safe_load restricts construction to those standard
            # tags. FullLoader also accepts Python-specific tags and has had
            # code-execution vulnerabilities in older PyYAML releases.
            source = yaml.safe_load(f)
        if not isinstance(source, dict):
            # Fail with a readable message instead of an opaque subscript
            # error inside from_dict (same exception type as before).
            raise TypeError(
                f"{path}: expected a YAML mapping at the top level, "
                f"got {type(source).__name__}"
            )
        return cls.from_dict(source)

    def to_file(self, path: str) -> None:
        """
        Write this config to ``path`` as YAML, preserving field order.

        The output can be read back with ``from_file``.
        """
        with open(path, 'w', encoding='utf-8') as f:
            # sort_keys=False keeps the dataclass field order in the file.
            yaml.dump(self.as_dict(), f, sort_keys=False)

    def as_dict(self) -> dict[str, Any]:
        """
        Return the config as plain nested dicts.

        ``dataclasses.asdict`` recurses, so ``openai`` also becomes a dict.
        """
        return asdict(self)
|