diff --git a/chatmastermind/main.py b/chatmastermind/main.py index e3ddda6..4c00309 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,8 +6,8 @@ import yaml import sys import argcomplete import argparse -from .utils import terminal_width, pp, tags_completer, process_tags, display_chat -from .storage import save_answers, create_chat +from .utils import terminal_width, pp, process_tags, display_chat +from .storage import save_answers, create_chat, get_tags from .api_client import ai, openai_api_key @@ -23,7 +23,8 @@ def process_and_display_chat(args: argparse.Namespace, ) -> tuple[list[dict[str, str]], str, list[str]]: tags = args.tags or [] extags = args.extags or [] - process_tags(config, tags, extags) + otags = args.output_tags or [] + process_tags(tags, extags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -60,6 +61,12 @@ def handle_question(args: argparse.Namespace, print(f"Usage: {usage}") +def tags_completer(prefix, parsed_args, **kwargs): + with open(parsed_args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return get_tags(config, prefix) + + def create_parser() -> argparse.ArgumentParser: default_config = '.config.yaml' parser = argparse.ArgumentParser( diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 7b7e17b..2d1d373 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -55,3 +55,18 @@ def create_chat(question: Optional[str], if question: append_message(chat, 'user', question) return chat + + +def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: + result = [] + for file in sorted(pathlib.Path(config['db']).iterdir()): + if file.suffix == '.yaml': + with open(file, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + for tag in data.get('tags', []): + if prefix and len(prefix) > 0: + if tag.startswith(prefix): + result.append(tag) + else: + result.append(tag) + return list(set(result)) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 3db408b..2c07ce3 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,6 +1,4 @@ import shutil -import yaml -import pathlib from pprint import PrettyPrinter from typing import List, Dict @@ -13,11 +11,19 @@ def pp(*args, **kwargs) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) -def process_tags(config: dict, tags: list, extags: list) -> None: - print(f"Tags: {', '.join(tags)}") - if len(extags) > 0: - print(f"Excluding tags: {', '.join(extags)}") - print() +def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None: + printed_messages = [] + + if tags: + printed_messages.append(f"Tags: {', '.join(tags)}") + if extags: + printed_messages.append(f"Excluding tags: {', '.join(extags)}") + if otags: + printed_messages.append(f"Output tags: {', '.join(otags)}") + + if printed_messages: + print("\n".join(printed_messages)) + print() def append_message(chat: List[Dict[str, str]], @@ -46,20 +52,3 @@ def display_chat(chat, dump=False) -> None: print(message['content']) else: print(f"{message['role'].upper()}: {message['content']}") - - -def tags_completer(prefix, parsed_args, **kwargs): - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - result = [] - for file in sorted(pathlib.Path(config['db']).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return list(set(result)) diff --git a/tests/test_main.py b/tests/test_main.py index 19386b2..6c35fe9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -113,9 +113,9 @@ class TestHandleQuestion(unittest.TestCase): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): handle_question(self.args, self.config, True) - mock_process_tags.assert_called_once_with(self.config, - self.args.tags, - self.args.extags) + mock_process_tags.assert_called_once_with(self.args.tags, + self.args.extags, + []) mock_create_chat.assert_called_once_with(self.question, self.args.tags, self.args.extags,