cmm: added 'question' command

This commit is contained in:
juk0de 2023-09-04 22:35:53 +02:00
parent 484e16de4d
commit 1e68617a46
2 changed files with 89 additions and 28 deletions

View File

@ -11,7 +11,8 @@ from .storage import save_answers, create_chat_hist
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config from .configuration import Config
from .chat import ChatDB from .chat import ChatDB
from .message import Message, MessageFilter, MessageError from .message import Message, MessageFilter, MessageError, Question
from .ai_factory import create_ai
from itertools import zip_longest from itertools import zip_longest
from typing import Any from typing import Any
@ -30,12 +31,12 @@ def create_question_with_hist(args: argparse.Namespace,
Creates the "AI request", including the question and chat history as determined Creates the "AI request", including the question and chat history as determined
by the specified tags. by the specified tags.
""" """
tags = args.tags or [] tags = args.or_tags or []
etags = args.etags or [] xtags = args.exclude_tags or []
otags = args.output_tags or [] otags = args.output_tags or []
if not args.source_code_only: if not args.source_code_only:
print_tag_args(tags, etags, otags) print_tag_args(tags, xtags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -52,8 +53,8 @@ def create_question_with_hist(args: argparse.Namespace,
question_parts.append(f"```\n{r.read().strip()}\n```") question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, etags, config, chat = create_chat_hist(full_question, tags, xtags, config,
match_all_tags=True if args.atags else False, # FIXME match_all_tags=True if args.and_tags else False, # FIXME
with_tags=False, with_tags=False,
with_file=False) with_file=False)
return chat, full_question, tags return chat, full_question, tags
@ -85,6 +86,44 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
config.to_file(args.config) config.to_file(args.config)
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = Message(question=Question(args.question),
tags=args.ouput_tags, # FIXME
ai=args.ai,
model=args.model)
chat.add_to_cache([message])
if args.create:
return
# create the correct AI instance
ai = create_ai(args, config)
if args.ask:
ai.request(message,
chat,
args.num_answers, # FIXME
args.otags) # FIXME
# TODO:
# * add answer to the message above (and create
# more messages for any additional answers)
pass
elif args.repeat:
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
elif args.process:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'
pass
def ask_cmd(args: argparse.Namespace, config: Config) -> None: def ask_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
@ -98,7 +137,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.source_code_only) print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.num_answers)
save_answers(question, answers, tags, otags, config) save_answers(question, answers, tags, otags, config)
print("-" * terminal_width()) print("-" * terminal_width())
print(f"Usage: {usage}") print(f"Usage: {usage}")
@ -109,9 +148,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
mfilter = MessageFilter(tags_or=args.tags, mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.atags, tags_and=args.and_tags,
tags_not=args.etags, tags_not=args.exclude_tags,
question_contains=args.question, question_contains=args.question,
answer_contains=args.answer) answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'), chat = ChatDB.from_dir(Path('.'),
@ -139,7 +178,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-c', '--config', help='Config file name.', default=default_config) parser.add_argument('-C', '--config', help='Config file name.', default=default_config)
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
@ -149,19 +188,41 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tag names (one must match)', metavar='TAGS') help='List of tag names (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tag names (all must match)', metavar='TAGS') help='List of tag names (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tag names to exclude', metavar='ETAGS') help='List of tag names to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OTAGS') help='List of output tag names, default is input', metavar='OUTTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
# 'question' command parser
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser],
help="ask, create and process questions.",
aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
question_cmd_parser.add_argument('-A', '--AI', help='AI to use')
question_cmd_parser.add_argument('-M', '--model', help='Model to use')
question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
default=1)
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true')
# 'ask' command parser # 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.", help="Ask a question.",
@ -172,7 +233,7 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use') ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
default=1) default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',

View File

@ -114,14 +114,14 @@ class TestHandleQuestion(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.question = "test question" self.question = "test question"
self.args = argparse.Namespace( self.args = argparse.Namespace(
tags=['tag1'], or_tags=['tag1'],
atags=None, and_tags=None,
etags=['etag1'], exclude_tags=['xtag1'],
output_tags=None, output_tags=None,
question=[self.question], question=[self.question],
source=None, source=None,
source_code_only=False, source_code_only=False,
number=3, num_answers=3,
max_tokens=None, max_tokens=None,
temperature=None, temperature=None,
model=None, model=None,
@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase):
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags, mock_print_tag_args.assert_called_once_with(self.args.or_tags,
self.args.etags, self.args.exclude_tags,
[]) [])
mock_create_chat_hist.assert_called_once_with(self.question, mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags, self.args.or_tags,
self.args.etags, self.args.exclude_tags,
self.config, self.config,
match_all_tags=False, match_all_tags=False,
with_tags=False, with_tags=False,
@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase):
self.args.source_code_only) self.args.source_code_only)
mock_ai.assert_called_with("test_chat", mock_ai.assert_called_with("test_chat",
self.config, self.config,
self.args.number) self.args.num_answers)
expected_calls = [] expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1): for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} ' title = f'-- ANSWER {num} '