From 650601774369b27be937efdf41c0827b786ba708 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:37:06 +0200 Subject: [PATCH] configuration: improved config file format --- chatmastermind/configuration.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 398fa03..08f6cbe 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -17,6 +17,18 @@ class ConfigError(Exception): pass +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass class AIConfig: """ @@ -48,13 +60,13 @@ class OpenAIConfig(AIConfig): # a default configuration ID: str = 'default' api_key: str = '0123456789' - system: str = 'You are an assistant' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 max_tokens: int = 4000 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -62,14 +74,14 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. """ res = cls( - system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) # overwrite default ID if provided if 'ID' in source: @@ -148,6 +160,8 @@ class Config: def as_dict(self) -> dict[str, Any]: res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) for ID, conf in res['ais'].items(): - conf.update({'name': self.ais[ID].name}) + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} return res