diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 2917865..bc58574 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,4 +1,5 @@ -from typing import TypedDict +import pathlib +from typing import TypedDict, Any class OpenAIConfig(TypedDict): @@ -14,6 +15,25 @@ class OpenAIConfig(TypedDict): presence_penalty: float +def openai_config_valid(conf: dict[str, str | float | int]) -> bool: + """ + Checks if the given Open AI configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['api_key']) + str(conf['model']) + int(conf['max_tokens']) + float(conf['temperature']) + float(conf['top_p']) + float(conf['frequency_penalty']) + float(conf['presence_penalty']) + return True + except Exception as e: + print(f"OpenAI configuration is invalid: {e}") + return False + + class Config(TypedDict): """ The configuration file structure. @@ -21,3 +41,23 @@ class Config(TypedDict): system: str db: str openai: OpenAIConfig + + +def config_valid(conf: dict[str, Any]) -> bool: + """ + Checks if the given configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['system']) + pathlib.Path(str(conf['db'])) + return True + except Exception as e: + print(f"Configuration is invalid: {e}") + return False + if 'openai' in conf: + return openai_config_valid(conf['openai']) + else: + # required as long as we only support OpenAI + print("Section 'openai' is missing in the configuration!") + return False diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index ca8ae32..a4648b0 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,8 +1,9 @@ import yaml +import sys import io import pathlib from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config +from .configuration import Config, config_valid from typing import Any, Optional @@ -26,6 +27,8 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: def read_config(path: str) -> Config: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) + if not config_valid(config): + sys.exit(1) return config