Compare commits

...

3 Commits

6 changed files with 75 additions and 44 deletions

View File

@ -1,11 +1,17 @@
import openai import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None: def openai_api_key(api_key: str) -> None:
openai.api_key = api_key openai.api_key = api_key
def print_models() -> None: def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = [] not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']: if engine['ready']:
@ -16,10 +22,16 @@ def print_models() -> None:
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: list[dict[str, str]], def ai(chat: ChatType,
config: dict, config: Config,
number: int number: int
) -> tuple[list[str], dict[str, int]]: ) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.')
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=config['openai']['model'], model=config['openai']['model'],
messages=chat, messages=chat,

View File

@ -0,0 +1,23 @@
from typing import TypedDict
class OpenAIConfig(TypedDict):
"""
The OpenAI section of the configuration file.
"""
api_key: str
model: str
temperature: float
max_tokens: int
top_p: float
frequency_penalty: float
presence_penalty: float
class Config(TypedDict):
"""
The configuration file structure.
"""
system: str
db: str
openai: OpenAIConfig

View File

@ -7,23 +7,25 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from itertools import zip_longest from itertools import zip_longest
from typing import Any
default_config = '.config.yaml' default_config = '.config.yaml'
def tags_completer(prefix, parsed_args, **kwargs): def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
with open(parsed_args.config, 'r') as f: with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags_unique(config, prefix) return get_tags_unique(config, prefix)
def create_question_with_hist(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
config: ConfigType, config: Config,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[ChatType, str, list[str]]:
""" """
Creates the "AI request", including the question and chat history as determined Creates the "AI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -55,7 +57,7 @@ def create_question_with_hist(args: argparse.Namespace,
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: def tag_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
@ -63,13 +65,10 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print_tags_frequency(get_tags(config, None)) print_tags_frequency(get_tags(config, None))
def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: def config_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'config' command. Handler for the 'config' command.
""" """
if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
if args.list_models: if args.list_models:
print_models() print_models()
elif args.print_model: elif args.print_model:
@ -79,19 +78,16 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None:
write_config(args.config, config) write_config(args.config, config)
def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: def ask_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
config_openai = config['openai']
if args.max_tokens: if args.max_tokens:
config_openai['max_tokens'] = args.max_tokens config['openai']['max_tokens'] = args.max_tokens
if args.temperature: if args.temperature:
config_openai['temperature'] = args.temperature config['openai']['temperature'] = args.temperature
if args.model: if args.model:
config_openai['model'] = args.model config['openai']['model'] = args.model
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code) print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
@ -101,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print(f"Usage: {usage}") print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: def hist_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
@ -115,7 +111,7 @@ def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: def print_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
@ -231,10 +227,7 @@ def main() -> int:
command = parser.parse_args() command = parser.parse_args()
config = read_config(args.config) config = read_config(args.config)
if type(config['openai']) is dict and type(config['openai']['api_key']) is str: openai_api_key(config['openai']['api_key'])
openai_api_key(config['openai']['api_key'])
else:
raise RuntimeError("Configuration openai.api_key is wrong.")
command.func(command, config) command.func(command, config)

View File

@ -1,11 +1,12 @@
import yaml import yaml
import io import io
import pathlib import pathlib
from .utils import terminal_width, append_message, message_to_chat, ConfigType from .utils import terminal_width, append_message, message_to_chat, ChatType
from typing import List, Dict, Any, Optional from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd: with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format) # also support tags separated by ',' (old format)
@ -22,18 +23,18 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
"file": fname.name} "file": fname.name}
def read_config(path: str) -> ConfigType: def read_config(path: str) -> Config:
with open(path, 'r') as f: with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
return config return config
def write_config(path: str, config: ConfigType) -> None: def write_config(path: str, config: Config) -> None:
with open(path, 'w') as f: with open(path, 'w') as f:
yaml.dump(config, f) yaml.dump(config, f)
def dump_data(data: Dict[str, Any]) -> str: def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd: with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -41,7 +42,7 @@ def dump_data(data: Dict[str, Any]) -> str:
return fd.getvalue() return fd.getvalue()
def write_file(fname: str, data: Dict[str, Any]) -> None: def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd: with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -52,7 +53,7 @@ def save_answers(question: str,
answers: list[str], answers: list[str],
tags: list[str], tags: list[str],
otags: Optional[list[str]], otags: Optional[list[str]],
config: ConfigType config: Config
) -> None: ) -> None:
wtags = otags or tags wtags = otags or tags
num, inum = 0, 0 num, inum = 0, 0
@ -75,14 +76,14 @@ def save_answers(question: str,
def create_chat_hist(question: Optional[str], def create_chat_hist(question: Optional[str],
tags: Optional[List[str]], tags: Optional[list[str]],
extags: Optional[List[str]], extags: Optional[list[str]],
config: ConfigType, config: Config,
match_all_tags: bool = False, match_all_tags: bool = False,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> List[Dict[str, str]]: ) -> ChatType:
chat: List[Dict[str, str]] = [] chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip()) append_message(chat, 'system', str(config['system']).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -108,7 +109,7 @@ def create_chat_hist(question: Optional[str],
return chat return chat
def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = [] result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -127,5 +128,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
return result return result
def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix))) return list(set(get_tags(config, prefix)))

View File

@ -1,14 +1,15 @@
import shutil import shutil
from pprint import PrettyPrinter from pprint import PrettyPrinter
from typing import Any
ConfigType = dict[str, str | dict[str, str | int | float]] ChatType = list[dict[str, str]]
def terminal_width() -> int: def terminal_width() -> int:
return shutil.get_terminal_size().columns return shutil.get_terminal_size().columns
def pp(*args, **kwargs) -> None: def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
@ -30,7 +31,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None
print() print()
def append_message(chat: list[dict[str, str]], def append_message(chat: ChatType,
role: str, role: str,
content: str content: str
) -> None: ) -> None:
@ -38,7 +39,7 @@ def append_message(chat: list[dict[str, str]],
def message_to_chat(message: dict[str, str], def message_to_chat(message: dict[str, str],
chat: list[dict[str, str]], chat: ChatType,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> None: ) -> None:
@ -61,7 +62,7 @@ def display_source_code(content: str) -> None:
pass pass
def print_chat_hist(chat, dump=False, source_code=False) -> None: def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
if dump: if dump:
pp(chat) pp(chat)
return return

View File

@ -5,3 +5,4 @@ strict_optional = True
warn_unused_ignores = False warn_unused_ignores = False
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_configs = True warn_unused_configs = True
disallow_untyped_defs = True