Compare commits

..

No commits in common. "056bf4c6b574177c13e8260765f1508c7b220018" and "6406d2f5b5daf35a8cfe550290dabf8196653aba" have entirely different histories.

3 changed files with 55 additions and 68 deletions

View File

@ -7,7 +7,7 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency from .utils import terminal_width, process_tags, print_chat_hist, display_source_code, print_tags_frequency
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from itertools import zip_longest from itertools import zip_longest
@ -27,9 +27,9 @@ def read_config(path: str):
return config return config
def create_question_with_hist(args: argparse.Namespace, def create_question_and_chat(args: argparse.Namespace,
config: dict, config: dict,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[list[dict[str, str]], str, list[str]]:
""" """
Creates the "SI request", including the question and chat history as determined Creates the "SI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -39,7 +39,7 @@ def create_question_with_hist(args: argparse.Namespace,
otags = args.output_tags or [] otags = args.output_tags or []
if not args.only_source_code: if not args.only_source_code:
print_tag_args(tags, extags, otags) process_tags(tags, extags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -57,19 +57,21 @@ def create_question_with_hist(args: argparse.Namespace,
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config, chat = create_chat_hist(full_question, tags, extags, config,
args.match_all_tags, False, False) args.match_all_tags, args.with_tags,
args.with_file)
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: dict) -> None: def tag_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
config = read_config(args.config)
if args.list: if args.list:
print_tags_frequency(get_tags(config, None)) print_tags_frequency(get_tags(config, None), args.dump)
def model_cmd(args: argparse.Namespace, config: dict) -> None: def model_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'model' command. Handler for the 'model' command.
""" """
@ -77,12 +79,19 @@ def model_cmd(args: argparse.Namespace, config: dict) -> None:
print_models() print_models()
def ask_cmd(args: argparse.Namespace, config: dict) -> None: def ask_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
chat, question, tags = create_question_with_hist(args, config) config = read_config(args.config)
print_chat_hist(chat, False, args.only_source_code) if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
if args.model:
config['openai']['model'] = args.model
chat, question, tags = create_question_and_chat(args, config)
print_chat_hist(chat, args.dump, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config) save_answers(question, answers, tags, otags, config)
@ -90,32 +99,27 @@ def ask_cmd(args: argparse.Namespace, config: dict) -> None:
print(f"Usage: {usage}") print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: dict) -> None: def hist_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
tags = args.tags or [] config = read_config(args.config)
extags = args.extags or [] chat, q, t = create_question_and_chat(args, config)
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace, config: dict) -> None: def print_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
fname = pathlib.Path(args.file) fname = pathlib.Path(args.print)
if fname.suffix == '.yaml': if fname.suffix == '.yaml':
with open(args.file, 'r') as f: with open(args.print, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader) data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt': elif fname.suffix == '.txt':
data = read_file(fname) data = read_file(fname)
else: else:
print(f"Unknown file type: {args.file}") print(f"Unknown file type: {args.print}")
sys.exit(1) sys.exit(1)
if args.only_source_code: if args.only_source_code:
display_source_code(data['answer']) display_source_code(data['answer'])
@ -131,8 +135,8 @@ def create_parser() -> argparse.ArgumentParser:
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
title='commands', title='commands',
description='supported commands', description='supported commands')
required=True) cmdparser.required = True
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
@ -155,16 +159,12 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.") help="Ask a question.")
ask_cmd_parser.set_defaults(func=ask_cmd) ask_cmd_parser.set_defaults(func=ask_cmd)
ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', required=True)
required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use') ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1)
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@ -208,18 +208,10 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
config = read_config(args.config)
# modify config according to args openai_api_key(read_config(args.config)['openai']['api_key'])
openai_api_key(config['openai']['api_key'])
if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
if args.model:
config['openai']['model'] = args.model
command.func(command, config) command.func(command)
return 0 return 0

View File

@ -11,10 +11,7 @@ def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
"""
Prints the tags specified in the given args.
"""
printed_messages = [] printed_messages = []
if tags: if tags:
@ -78,6 +75,9 @@ def print_chat_hist(chat, dump=False, source_code=False) -> None:
print(f"{message['role'].upper()}: {message['content']}") print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: List[str]) -> None: def print_tags_frequency(tags: List[str], dump=False) -> None:
if dump:
pp(tags)
return
for tag in sorted(set(tags)): for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}") print(f"- {tag}: {tags.count(tag)}")

View File

@ -7,7 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock
class TestCreateChat(unittest.TestCase): class TestCreateChat(unittest.TestCase):
@ -113,28 +113,23 @@ class TestHandleQuestion(unittest.TestCase):
} }
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args") @patch("chatmastermind.main.process_tags")
@patch("chatmastermind.utils.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp") @patch("chatmastermind.utils.pp")
@patch("builtins.print") @patch("builtins.print")
def test_ask_cmd(self, mock_print, mock_pp, mock_ai, def test_ask_cmd(self, mock_print, mock_pp, mock_ai,
mock_print_tag_args, mock_create_chat_hist, mock_process_tags, mock_create_chat_hist):
mock_print_chat_hist):
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) ask_cmd(self.args, self.config, True)
mock_print_tag_args.assert_called_once_with(self.args.tags, mock_process_tags.assert_called_once_with(self.args.tags,
self.args.extags, self.args.extags,
[]) [])
mock_create_chat_hist.assert_called_once_with(self.question, mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags, self.args.tags,
self.args.extags, self.args.extags,
self.config, self.config,
False, False, False) False, False, False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.only_source_code)
mock_pp.assert_called_once_with("test_chat") mock_pp.assert_called_once_with("test_chat")
mock_ai.assert_called_with("test_chat", mock_ai.assert_called_with("test_chat",
self.config, self.config,
@ -210,15 +205,15 @@ class TestAI(unittest.TestCase):
class TestCreateParser(unittest.TestCase): class TestCreateParser(unittest.TestCase):
def test_create_parser(self): def test_create_parser(self):
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
mock_cmdparser = Mock() mock_group = Mock()
mock_add_subparsers.return_value = mock_cmdparser mock_add_mutually_exclusive_group.return_value = mock_group
parser = create_parser() parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser) self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY) mock_group.add_argument.assert_any_call('-p', '--print', help='File to print')
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY) mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY) mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true')
mock_cmdparser.add_parser.assert_any_call('model', help=ANY) mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat history as readable text", action='store_true')
mock_cmdparser.add_parser.assert_any_call('print', help=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))
self.assertEqual(parser.get_default('number'), 1)