Compare commits

...

5 Commits

Author SHA1 Message Date
056bf4c6b5 fixed almost all tests 2023-08-12 09:51:13 +02:00
93a8b0081a main: cleanup 2023-08-12 09:50:54 +02:00
5119b3a874 fixed 'ask' command 2023-08-12 08:28:07 +02:00
5a435c5f8f fixed 'tag' and 'hist' commands 2023-08-12 08:20:00 +02:00
f90e7bcd47 fixed 'hist' command and simplified reading the config file 2023-08-12 08:13:31 +02:00
3 changed files with 68 additions and 55 deletions

View File

@ -7,7 +7,7 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, process_tags, print_chat_hist, display_source_code, print_tags_frequency from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from itertools import zip_longest from itertools import zip_longest
@ -27,9 +27,9 @@ def read_config(path: str):
return config return config
def create_question_and_chat(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
config: dict, config: dict,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[list[dict[str, str]], str, list[str]]:
""" """
Creates the "SI request", including the question and chat history as determined Creates the "SI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -39,7 +39,7 @@ def create_question_and_chat(args: argparse.Namespace,
otags = args.output_tags or [] otags = args.output_tags or []
if not args.only_source_code: if not args.only_source_code:
process_tags(tags, extags, otags) print_tag_args(tags, extags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -57,21 +57,19 @@ def create_question_and_chat(args: argparse.Namespace,
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config, chat = create_chat_hist(full_question, tags, extags, config,
args.match_all_tags, args.with_tags, args.match_all_tags, False, False)
args.with_file)
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace) -> None: def tag_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
config = read_config(args.config)
if args.list: if args.list:
print_tags_frequency(get_tags(config, None), args.dump) print_tags_frequency(get_tags(config, None))
def model_cmd(args: argparse.Namespace) -> None: def model_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'model' command. Handler for the 'model' command.
""" """
@ -79,19 +77,12 @@ def model_cmd(args: argparse.Namespace) -> None:
print_models() print_models()
def ask_cmd(args: argparse.Namespace) -> None: def ask_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
config = read_config(args.config) chat, question, tags = create_question_with_hist(args, config)
if args.max_tokens: print_chat_hist(chat, False, args.only_source_code)
config['openai']['max_tokens'] = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
if args.model:
config['openai']['model'] = args.model
chat, question, tags = create_question_and_chat(args, config)
print_chat_hist(chat, args.dump, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config) save_answers(question, answers, tags, otags, config)
@ -99,27 +90,32 @@ def ask_cmd(args: argparse.Namespace) -> None:
print(f"Usage: {usage}") print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace) -> None: def hist_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
config = read_config(args.config) tags = args.tags or []
chat, q, t = create_question_and_chat(args, config) extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace) -> None: def print_cmd(args: argparse.Namespace, config: dict) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
fname = pathlib.Path(args.print) fname = pathlib.Path(args.file)
if fname.suffix == '.yaml': if fname.suffix == '.yaml':
with open(args.print, 'r') as f: with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader) data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt': elif fname.suffix == '.txt':
data = read_file(fname) data = read_file(fname)
else: else:
print(f"Unknown file type: {args.print}") print(f"Unknown file type: {args.file}")
sys.exit(1) sys.exit(1)
if args.only_source_code: if args.only_source_code:
display_source_code(data['answer']) display_source_code(data['answer'])
@ -135,8 +131,8 @@ def create_parser() -> argparse.ArgumentParser:
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
title='commands', title='commands',
description='supported commands') description='supported commands',
cmdparser.required = True required=True)
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
@ -159,12 +155,16 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.") help="Ask a question.")
ask_cmd_parser.set_defaults(func=ask_cmd) ask_cmd_parser.set_defaults(func=ask_cmd)
ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', required=True) ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask',
required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use') ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
@ -208,10 +208,18 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
config = read_config(args.config)
openai_api_key(read_config(args.config)['openai']['api_key']) # modify config according to args
openai_api_key(config['openai']['api_key'])
if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
if args.model:
config['openai']['model'] = args.model
command.func(command) command.func(command, config)
return 0 return 0

View File

@ -11,7 +11,10 @@ def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None: def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
"""
Prints the tags specified in the given args.
"""
printed_messages = [] printed_messages = []
if tags: if tags:
@ -75,9 +78,6 @@ def print_chat_hist(chat, dump=False, source_code=False) -> None:
print(f"{message['role'].upper()}: {message['content']}") print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: List[str], dump=False) -> None: def print_tags_frequency(tags: List[str]) -> None:
if dump:
pp(tags)
return
for tag in sorted(set(tags)): for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}") print(f"- {tag}: {tags.count(tag)}")

View File

@ -7,7 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock from unittest.mock import patch, MagicMock, Mock, ANY
class TestCreateChat(unittest.TestCase): class TestCreateChat(unittest.TestCase):
@ -113,23 +113,28 @@ class TestHandleQuestion(unittest.TestCase):
} }
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.process_tags") @patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.utils.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp") @patch("chatmastermind.utils.pp")
@patch("builtins.print") @patch("builtins.print")
def test_ask_cmd(self, mock_print, mock_pp, mock_ai, def test_ask_cmd(self, mock_print, mock_pp, mock_ai,
mock_process_tags, mock_create_chat_hist): mock_print_tag_args, mock_create_chat_hist,
mock_print_chat_hist):
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config, True) ask_cmd(self.args, self.config)
mock_process_tags.assert_called_once_with(self.args.tags, mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.extags, self.args.extags,
[]) [])
mock_create_chat_hist.assert_called_once_with(self.question, mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags, self.args.tags,
self.args.extags, self.args.extags,
self.config, self.config,
False, False, False) False, False, False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.only_source_code)
mock_pp.assert_called_once_with("test_chat") mock_pp.assert_called_once_with("test_chat")
mock_ai.assert_called_with("test_chat", mock_ai.assert_called_with("test_chat",
self.config, self.config,
@ -205,15 +210,15 @@ class TestAI(unittest.TestCase):
class TestCreateParser(unittest.TestCase): class TestCreateParser(unittest.TestCase):
def test_create_parser(self): def test_create_parser(self):
with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_group = Mock() mock_cmdparser = Mock()
mock_add_mutually_exclusive_group.return_value = mock_group mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser() parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser) self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_mutually_exclusive_group.assert_called_once_with(required=True) mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_group.add_argument.assert_any_call('-p', '--print', help='File to print') mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY)
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY)
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true') mock_cmdparser.add_parser.assert_any_call('tag', help=ANY)
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat history as readable text", action='store_true') mock_cmdparser.add_parser.assert_any_call('model', help=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))
self.assertEqual(parser.get_default('number'), 1)