From 4303fb414f6f12702264549da625acfd6a53b5b7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 15 Aug 2023 23:36:45 +0200 Subject: [PATCH] added typ hints for all functions in 'main.py', 'utils.py', 'storage.py' and 'api_client.py' --- chatmastermind/api_client.py | 15 +++++++++++++-- chatmastermind/main.py | 18 +++++++++--------- chatmastermind/storage.py | 22 +++++++++++----------- chatmastermind/utils.py | 10 ++++++---- 4 files changed, 39 insertions(+), 26 deletions(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index 8eaf695..d3282eb 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -1,11 +1,16 @@ import openai +from .utils import ConfigType, ChatType + def openai_api_key(api_key: str) -> None: openai.api_key = api_key def print_models() -> None: + """ + Print all models supported by the current AI. + """ not_ready = [] for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): if engine['ready']: @@ -16,10 +21,16 @@ def print_models() -> None: print('\nNot ready: ' + ', '.join(not_ready)) -def ai(chat: list[dict[str, str]], - config: dict, +def ai(chat: ChatType, + config: ConfigType, number: int ) -> tuple[list[str], dict[str, int]]: + """ + Make AI request with the given chat history and configuration. + Return AI response and tokens used. + """ + if not isinstance(config['openai'], dict): + raise RuntimeError('Configuration openai is not a dict.') response = openai.ChatCompletion.create( model=config['openai']['model'], messages=chat, diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 15e8208..ec33cb3 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,15 +7,16 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType, ChatType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest +from typing import Any default_config = '.config.yaml' -def tags_completer(prefix, parsed_args, **kwargs): +def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: with open(parsed_args.config, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return get_tags_unique(config, prefix) @@ -23,7 +24,7 @@ def tags_completer(prefix, parsed_args, **kwargs): def create_question_with_hist(args: argparse.Namespace, config: ConfigType, - ) -> tuple[list[dict[str, str]], str, list[str]]: + ) -> tuple[ChatType, str, list[str]]: """ Creates the "AI request", including the question and chat history as determined by the specified tags. @@ -67,7 +68,7 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'config' command. """ - if type(config['openai']) is not dict: + if not isinstance(config['openai'], dict): raise RuntimeError('Configuration openai is not a dict.') if args.list_models: @@ -83,15 +84,14 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'ask' command. """ - if type(config['openai']) is not dict: + if not isinstance(config['openai'], dict): raise RuntimeError('Configuration openai is not a dict.') - config_openai = config['openai'] if args.max_tokens: - config_openai['max_tokens'] = args.max_tokens + config['openai']['max_tokens'] = args.max_tokens if args.temperature: - config_openai['temperature'] = args.temperature + config['openai']['temperature'] = args.temperature if args.model: - config_openai['model'] = args.model + config['openai']['model'] = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index d90598b..fa3fb14 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,11 +1,11 @@ import yaml import io import pathlib -from .utils import terminal_width, append_message, message_to_chat, ConfigType -from typing import List, Dict, Any, Optional +from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType +from typing import Any, Optional -def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: +def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: with open(fname, "r") as fd: tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() # also support tags separated by ',' (old format) @@ -33,7 +33,7 @@ def write_config(path: str, config: ConfigType) -> None: yaml.dump(config, f) -def dump_data(data: Dict[str, Any]) -> str: +def dump_data(data: dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -41,7 +41,7 @@ def dump_data(data: Dict[str, Any]) -> str: return fd.getvalue() -def write_file(fname: str, data: Dict[str, Any]) -> None: +def write_file(fname: str, data: dict[str, Any]) -> None: with open(fname, "w") as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -75,14 +75,14 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], - tags: Optional[List[str]], - extags: Optional[List[str]], + tags: Optional[list[str]], + extags: Optional[list[str]], config: ConfigType, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False - ) -> List[Dict[str, str]]: - chat: List[Dict[str, str]] = [] + ) -> ChatType: + chat: ChatType = [] append_message(chat, 'system', str(config['system']).strip()) for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -108,7 +108,7 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: result = [] for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -127,5 +127,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: return result -def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index fba8296..c6d527c 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,14 +1,16 @@ import shutil from pprint import PrettyPrinter +from typing import Any ConfigType = dict[str, str | dict[str, str | int | float]] +ChatType = list[dict[str, str]] def terminal_width() -> int: return shutil.get_terminal_size().columns -def pp(*args, **kwargs) -> None: +def pp(*args: Any, **kwargs: Any) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) @@ -30,7 +32,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None print() -def append_message(chat: list[dict[str, str]], +def append_message(chat: ChatType, role: str, content: str ) -> None: @@ -38,7 +40,7 @@ def append_message(chat: list[dict[str, str]], def message_to_chat(message: dict[str, str], - chat: list[dict[str, str]], + chat: ChatType, with_tags: bool = False, with_file: bool = False ) -> None: @@ -61,7 +63,7 @@ def display_source_code(content: str) -> None: pass -def print_chat_hist(chat, dump=False, source_code=False) -> None: +def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: if dump: pp(chat) return