Compare commits

...

4 Commits

5 changed files with 88 additions and 58 deletions

View File

@ -6,11 +6,13 @@ import yaml
import sys import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib from pathlib import Path
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .storage import save_answers, create_chat_hist, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter
from itertools import zip_longest from itertools import zip_longest
from typing import Any from typing import Any
@ -18,9 +20,8 @@ default_config = '.config.yaml'
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
with open(parsed_args.config, 'r') as f: config = Config.from_file(parsed_args.config)
config = yaml.load(f, Loader=yaml.FullLoader) return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
return get_tags_unique(config, prefix)
def create_question_with_hist(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
@ -31,11 +32,11 @@ def create_question_with_hist(args: argparse.Namespace,
by the specified tags. by the specified tags.
""" """
tags = args.tags or [] tags = args.tags or []
extags = args.extags or [] etags = args.etags or []
otags = args.output_tags or [] otags = args.output_tags or []
if not args.only_source_code: if not args.source_code_only:
print_tag_args(tags, extags, otags) print_tag_args(tags, etags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -52,17 +53,23 @@ def create_question_with_hist(args: argparse.Namespace,
question_parts.append(f"```\n{r.read().strip()}\n```") question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config, chat = create_chat_hist(full_question, tags, etags, config,
args.match_all_tags, False, False) match_all_tags=True if args.atags else False, # FIXME
with_tags=False,
with_file=False)
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: Config) -> None: def tags_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tags' command.
""" """
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list: if args.list:
print_tags_frequency(get_tags(config, None)) tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
def config_cmd(args: argparse.Namespace, config: Config) -> None: def config_cmd(args: argparse.Namespace, config: Config) -> None:
@ -89,7 +96,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
if args.model: if args.model:
config.openai.model = args.model config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code) print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config) save_answers(question, answers, tags, otags, config)
@ -101,21 +108,25 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
tags = args.tags or []
extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config, mfilter = MessageFilter(tags_or=args.tags,
args.match_all_tags, tags_and=args.atags,
args.with_tags, tags_not=args.etags,
args.with_files) question_contains=args.question,
print_chat_hist(chat, args.dump, args.only_source_code) answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
def print_cmd(args: argparse.Namespace, config: Config) -> None: def print_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
fname = pathlib.Path(args.file) fname = Path(args.file)
if fname.suffix == '.yaml': if fname.suffix == '.yaml':
with open(args.file, 'r') as f: with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader) data = yaml.load(f, Loader=yaml.FullLoader)
@ -124,7 +135,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
else: else:
print(f"Unknown file type: {args.file}") print(f"Unknown file type: {args.file}")
sys.exit(1) sys.exit(1)
if args.only_source_code: if args.source_code_only:
display_source_code(data['answer']) display_source_code(data['answer'])
else: else:
print(dump_data(data).strip()) print(dump_data(data).strip())
@ -144,18 +155,17 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tag names', metavar='TAGS') help='List of tag names (one must match)', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+',
help='List of tag names to exclude', metavar='EXTAGS') help='List of tag names (all must match)', metavar='TAGS')
extag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+',
help='List of tag names to exclude', metavar='ETAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OTAGS') help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore otag_arg.completer = tags_completer # type: ignore
tag_parser.add_argument('-a', '--match-all-tags',
help="All given tags must match when selecting chat history entries",
action='store_true')
# enable autocompletion for tags
# 'ask' command parser # 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
@ -170,7 +180,7 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1) default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true') action='store_true')
# 'hist' command parser # 'hist' command parser
@ -178,23 +188,25 @@ def create_parser() -> argparse.ArgumentParser:
help="Print chat history.", help="Print chat history.",
aliases=['h']) aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
action='store_true')
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true') action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tag' command parser # 'tags' command parser
tag_cmd_parser = cmdparser.add_parser('tag', tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.", help="Manage tags.",
aliases=['t']) aliases=['t'])
tag_cmd_parser.set_defaults(func=tag_cmd) tags_cmd_parser.set_defaults(func=tags_cmd)
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tag_group.add_argument('-l', '--list', help="List all tags and their frequency", tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',
@ -214,7 +226,7 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['p']) aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true') action='store_true')
argcomplete.autocomplete(parser) argcomplete.autocomplete(parser)

View File

@ -5,7 +5,7 @@ import pathlib
import yaml import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
QuestionInst = TypeVar('QuestionInst', bound='Question') QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer') AnswerInst = TypeVar('AnswerInst', bound='Answer')
@ -499,6 +499,14 @@ class Message():
return False return False
return True return True
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
"""
Renames the given tags. The first tuple element is the old name,
the second one is the new name.
"""
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.

View File

@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
print(message['content']) print(message['content'])
else: else:
print(f"{message['role'].upper()}: {message['content']}") print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")

View File

@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase):
self.question = "test question" self.question = "test question"
self.args = argparse.Namespace( self.args = argparse.Namespace(
tags=['tag1'], tags=['tag1'],
extags=['extag1'], atags=None,
etags=['etag1'],
output_tags=None, output_tags=None,
question=[self.question], question=[self.question],
source=None, source=None,
only_source_code=False, source_code_only=False,
number=3, number=3,
max_tokens=None, max_tokens=None,
temperature=None, temperature=None,
@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase):
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags, mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.extags, self.args.etags,
[]) [])
mock_create_chat_hist.assert_called_once_with(self.question, mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags, self.args.tags,
self.args.extags, self.args.etags,
self.config, self.config,
False, False, False) match_all_tags=False,
with_tags=False,
with_file=False)
mock_print_chat_hist.assert_called_once_with('test_chat', mock_print_chat_hist.assert_called_once_with('test_chat',
False, False,
self.args.only_source_code) self.args.source_code_only)
mock_ai.assert_called_with("test_chat", mock_ai.assert_called_with("test_chat",
self.config, self.config,
self.args.number) self.args.number)
@ -227,7 +230,7 @@ class TestCreateParser(CmmTestCase):
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))

View File

@ -792,3 +792,15 @@ class MessageInTestCase(CmmTestCase):
def test_message_in(self) -> None: def test_message_in(self) -> None:
self.assertTrue(message_in(self.message1, [self.message1])) self.assertTrue(message_in(self.message1, [self.message1]))
self.assertFalse(message_in(self.message1, [self.message2])) self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(CmmTestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_rename_tags(self) -> None:
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
self.assertIsNotNone(self.message.tags)
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]