added typ hints for all functions in 'main.py', 'utils.py', 'storage.py' and 'api_client.py'

This commit is contained in:
juk0de 2023-08-15 23:36:45 +02:00
parent ba41794f4e
commit 4303fb414f
4 changed files with 39 additions and 26 deletions

View File

@ -1,11 +1,16 @@
import openai import openai
from .utils import ConfigType, ChatType
def openai_api_key(api_key: str) -> None: def openai_api_key(api_key: str) -> None:
openai.api_key = api_key openai.api_key = api_key
def print_models() -> None: def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = [] not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']: if engine['ready']:
@ -16,10 +21,16 @@ def print_models() -> None:
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: list[dict[str, str]], def ai(chat: ChatType,
config: dict, config: ConfigType,
number: int number: int
) -> tuple[list[str], dict[str, int]]: ) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.')
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=config['openai']['model'], model=config['openai']['model'],
messages=chat, messages=chat,

View File

@ -7,15 +7,16 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from itertools import zip_longest from itertools import zip_longest
from typing import Any
default_config = '.config.yaml' default_config = '.config.yaml'
def tags_completer(prefix, parsed_args, **kwargs): def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
with open(parsed_args.config, 'r') as f: with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags_unique(config, prefix) return get_tags_unique(config, prefix)
@ -23,7 +24,7 @@ def tags_completer(prefix, parsed_args, **kwargs):
def create_question_with_hist(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
config: ConfigType, config: ConfigType,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[ChatType, str, list[str]]:
""" """
Creates the "AI request", including the question and chat history as determined Creates the "AI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -67,7 +68,7 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None:
""" """
Handler for the 'config' command. Handler for the 'config' command.
""" """
if type(config['openai']) is not dict: if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.') raise RuntimeError('Configuration openai is not a dict.')
if args.list_models: if args.list_models:
@ -83,15 +84,14 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
if type(config['openai']) is not dict: if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.') raise RuntimeError('Configuration openai is not a dict.')
config_openai = config['openai']
if args.max_tokens: if args.max_tokens:
config_openai['max_tokens'] = args.max_tokens config['openai']['max_tokens'] = args.max_tokens
if args.temperature: if args.temperature:
config_openai['temperature'] = args.temperature config['openai']['temperature'] = args.temperature
if args.model: if args.model:
config_openai['model'] = args.model config['openai']['model'] = args.model
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code) print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []

View File

@ -1,11 +1,11 @@
import yaml import yaml
import io import io
import pathlib import pathlib
from .utils import terminal_width, append_message, message_to_chat, ConfigType from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType
from typing import List, Dict, Any, Optional from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd: with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format) # also support tags separated by ',' (old format)
@ -33,7 +33,7 @@ def write_config(path: str, config: ConfigType) -> None:
yaml.dump(config, f) yaml.dump(config, f)
def dump_data(data: Dict[str, Any]) -> str: def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd: with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -41,7 +41,7 @@ def dump_data(data: Dict[str, Any]) -> str:
return fd.getvalue() return fd.getvalue()
def write_file(fname: str, data: Dict[str, Any]) -> None: def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd: with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -75,14 +75,14 @@ def save_answers(question: str,
def create_chat_hist(question: Optional[str], def create_chat_hist(question: Optional[str],
tags: Optional[List[str]], tags: Optional[list[str]],
extags: Optional[List[str]], extags: Optional[list[str]],
config: ConfigType, config: ConfigType,
match_all_tags: bool = False, match_all_tags: bool = False,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> List[Dict[str, str]]: ) -> ChatType:
chat: List[Dict[str, str]] = [] chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip()) append_message(chat, 'system', str(config['system']).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -108,7 +108,7 @@ def create_chat_hist(question: Optional[str],
return chat return chat
def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]:
result = [] result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -127,5 +127,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
return result return result
def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix))) return list(set(get_tags(config, prefix)))

View File

@ -1,14 +1,16 @@
import shutil import shutil
from pprint import PrettyPrinter from pprint import PrettyPrinter
from typing import Any
ConfigType = dict[str, str | dict[str, str | int | float]] ConfigType = dict[str, str | dict[str, str | int | float]]
ChatType = list[dict[str, str]]
def terminal_width() -> int: def terminal_width() -> int:
return shutil.get_terminal_size().columns return shutil.get_terminal_size().columns
def pp(*args, **kwargs) -> None: def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
@ -30,7 +32,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None
print() print()
def append_message(chat: list[dict[str, str]], def append_message(chat: ChatType,
role: str, role: str,
content: str content: str
) -> None: ) -> None:
@ -38,7 +40,7 @@ def append_message(chat: list[dict[str, str]],
def message_to_chat(message: dict[str, str], def message_to_chat(message: dict[str, str],
chat: list[dict[str, str]], chat: ChatType,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> None: ) -> None:
@ -61,7 +63,7 @@ def display_source_code(content: str) -> None:
pass pass
def print_chat_hist(chat, dump=False, source_code=False) -> None: def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
if dump: if dump:
pp(chat) pp(chat)
return return