fixed 'hist' command and simplified reading the config file

This commit is contained in:
juk0de 2023-08-12 08:13:31 +02:00
parent 6406d2f5b5
commit f90e7bcd47
2 changed files with 25 additions and 18 deletions

View File

@ -7,7 +7,7 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, process_tags, print_chat_hist, display_source_code, print_tags_frequency from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from itertools import zip_longest from itertools import zip_longest
@ -27,9 +27,9 @@ def read_config(path: str):
return config return config
def create_question_and_chat(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
config: dict, config: dict,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[list[dict[str, str]], str, list[str]]:
""" """
Creates the "SI request", including the question and chat history as determined Creates the "SI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -39,7 +39,7 @@ def create_question_and_chat(args: argparse.Namespace,
otags = args.output_tags or [] otags = args.output_tags or []
if not args.only_source_code: if not args.only_source_code:
process_tags(tags, extags, otags) print_tag_args(tags, extags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -62,16 +62,15 @@ def create_question_and_chat(args: argparse.Namespace,
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace) -> None: def tag_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
config = read_config(args.config)
if args.list: if args.list:
print_tags_frequency(get_tags(config, None), args.dump) print_tags_frequency(get_tags(config, None), args.dump)
def model_cmd(args: argparse.Namespace) -> None: def model_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'model' command. Handler for the 'model' command.
""" """
@ -79,18 +78,17 @@ def model_cmd(args: argparse.Namespace) -> None:
print_models() print_models()
def ask_cmd(args: argparse.Namespace) -> None: def ask_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
config = read_config(args.config)
if args.max_tokens: if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens config['openai']['max_tokens'] = args.max_tokens
if args.temperature: if args.temperature:
config['openai']['temperature'] = args.temperature config['openai']['temperature'] = args.temperature
if args.model: if args.model:
config['openai']['model'] = args.model config['openai']['model'] = args.model
chat, question, tags = create_question_and_chat(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.number)
@ -99,16 +97,21 @@ def ask_cmd(args: argparse.Namespace) -> None:
print(f"Usage: {usage}") print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace) -> None: def hist_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
config = read_config(args.config) tags = args.tags or []
chat, q, t = create_question_and_chat(args, config) extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace) -> None: def print_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
@ -209,9 +212,10 @@ def main() -> int:
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
openai_api_key(read_config(args.config)['openai']['api_key']) config = read_config(args.config)
openai_api_key(config['openai']['api_key'])
command.func(command) command.func(command, config)
return 0 return 0

View File

@ -11,7 +11,10 @@ def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None: def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
"""
Prints the tags specified in the given args.
"""
printed_messages = [] printed_messages = []
if tags: if tags: