diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index d3282eb..d8634bd 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -1,6 +1,7 @@ import openai -from .utils import ConfigType, ChatType +from .utils import ChatType +from .configuration import Config def openai_api_key(api_key: str) -> None: @@ -22,7 +23,7 @@ def print_models() -> None: def ai(chat: ChatType, - config: ConfigType, + config: Config, number: int ) -> tuple[list[str], dict[str, int]]: """ diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py new file mode 100644 index 0000000..2917865 --- /dev/null +++ b/chatmastermind/configuration.py @@ -0,0 +1,23 @@ +from typing import TypedDict + + +class OpenAIConfig(TypedDict): + """ + The OpenAI section of the configuration file. + """ + api_key: str + model: str + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + + +class Config(TypedDict): + """ + The configuration file structure. + """ + system: str + db: str + openai: OpenAIConfig diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ec33cb3..7c6df33 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,9 +7,10 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType, ChatType +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models +from .configuration import Config from itertools import zip_longest from typing import Any @@ -23,7 +24,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def create_question_with_hist(args: argparse.Namespace, - config: ConfigType, + config: Config, ) -> tuple[ChatType, str, list[str]]: """ Creates the "AI request", including the question and chat history as determined @@ -56,7 +57,7 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def tag_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tag' command. """ @@ -64,13 +65,10 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_tags_frequency(get_tags(config, None)) -def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def config_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'config' command. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') - if args.list_models: print_models() elif args.print_model: @@ -80,12 +78,10 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: write_config(args.config, config) -def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') if args.max_tokens: config['openai']['max_tokens'] = args.max_tokens if args.temperature: @@ -101,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: print(f"Usage: {usage}") -def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. """ @@ -115,7 +111,7 @@ def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_chat_hist(chat, args.dump, args.only_source_code) -def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ @@ -231,10 +227,7 @@ def main() -> int: command = parser.parse_args() config = read_config(args.config) - if type(config['openai']) is dict and type(config['openai']['api_key']) is str: - openai_api_key(config['openai']['api_key']) - else: - raise RuntimeError("Configuration openai.api_key is wrong.") + openai_api_key(config['openai']['api_key']) command.func(command, config) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index fa3fb14..ca8ae32 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,7 +1,8 @@ import yaml import io import pathlib -from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType +from .utils import terminal_width, append_message, message_to_chat, ChatType +from .configuration import Config from typing import Any, Optional @@ -22,13 +23,13 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: "file": fname.name} -def read_config(path: str) -> ConfigType: +def read_config(path: str) -> Config: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return config -def write_config(path: str, config: ConfigType) -> None: +def write_config(path: str, config: Config) -> None: with open(path, 'w') as f: yaml.dump(config, f) @@ -52,7 +53,7 @@ def save_answers(question: str, answers: list[str], tags: list[str], otags: Optional[list[str]], - config: ConfigType + config: Config ) -> None: wtags = otags or tags num, inum = 0, 0 @@ -77,7 +78,7 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], tags: Optional[list[str]], extags: Optional[list[str]], - config: ConfigType, + config: Config, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False @@ -108,7 +109,7 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: +def get_tags(config: Config, prefix: Optional[str]) -> list[str]: result = [] for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -127,5 +128,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: return result -def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]: +def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index c6d527c..bd80e4f 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -2,7 +2,6 @@ import shutil from pprint import PrettyPrinter from typing import Any -ConfigType = dict[str, str | dict[str, str | int | float]] ChatType = list[dict[str, str]]