Refactor process_tags
This commit is contained in:
parent
b23a9f663f
commit
4ee777118d
@ -6,8 +6,8 @@ import yaml
|
|||||||
import sys
|
import sys
|
||||||
import argcomplete
|
import argcomplete
|
||||||
import argparse
|
import argparse
|
||||||
from .utils import terminal_width, pp, tags_completer, process_tags, display_chat
|
from .utils import terminal_width, pp, process_tags, display_chat
|
||||||
from .storage import save_answers, create_chat
|
from .storage import save_answers, create_chat, get_tags
|
||||||
from .api_client import ai, openai_api_key
|
from .api_client import ai, openai_api_key
|
||||||
|
|
||||||
|
|
||||||
@ -23,7 +23,8 @@ def process_and_display_chat(args: argparse.Namespace,
|
|||||||
) -> tuple[list[dict[str, str]], str, list[str]]:
|
) -> tuple[list[dict[str, str]], str, list[str]]:
|
||||||
tags = args.tags or []
|
tags = args.tags or []
|
||||||
extags = args.extags or []
|
extags = args.extags or []
|
||||||
process_tags(config, tags, extags)
|
otags = args.output_tags or []
|
||||||
|
process_tags(tags, extags, otags)
|
||||||
|
|
||||||
question_parts = []
|
question_parts = []
|
||||||
question_list = args.question if args.question is not None else []
|
question_list = args.question if args.question is not None else []
|
||||||
@ -60,6 +61,12 @@ def handle_question(args: argparse.Namespace,
|
|||||||
print(f"Usage: {usage}")
|
print(f"Usage: {usage}")
|
||||||
|
|
||||||
|
|
||||||
|
def tags_completer(prefix, parsed_args, **kwargs):
|
||||||
|
with open(parsed_args.config, 'r') as f:
|
||||||
|
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
return get_tags(config, prefix)
|
||||||
|
|
||||||
|
|
||||||
def create_parser() -> argparse.ArgumentParser:
|
def create_parser() -> argparse.ArgumentParser:
|
||||||
default_config = '.config.yaml'
|
default_config = '.config.yaml'
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@ -55,3 +55,18 @@ def create_chat(question: Optional[str],
|
|||||||
if question:
|
if question:
|
||||||
append_message(chat, 'user', question)
|
append_message(chat, 'user', question)
|
||||||
return chat
|
return chat
|
||||||
|
|
||||||
|
|
||||||
|
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
||||||
|
result = []
|
||||||
|
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||||
|
if file.suffix == '.yaml':
|
||||||
|
with open(file, 'r') as f:
|
||||||
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
for tag in data.get('tags', []):
|
||||||
|
if prefix and len(prefix) > 0:
|
||||||
|
if tag.startswith(prefix):
|
||||||
|
result.append(tag)
|
||||||
|
else:
|
||||||
|
result.append(tag)
|
||||||
|
return list(set(result))
|
||||||
|
|||||||
@ -1,6 +1,4 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import yaml
|
|
||||||
import pathlib
|
|
||||||
from pprint import PrettyPrinter
|
from pprint import PrettyPrinter
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
|
||||||
@ -13,11 +11,19 @@ def pp(*args, **kwargs) -> None:
|
|||||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def process_tags(config: dict, tags: list, extags: list) -> None:
|
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
|
||||||
print(f"Tags: {', '.join(tags)}")
|
printed_messages = []
|
||||||
if len(extags) > 0:
|
|
||||||
print(f"Excluding tags: {', '.join(extags)}")
|
if tags:
|
||||||
print()
|
printed_messages.append(f"Tags: {', '.join(tags)}")
|
||||||
|
if extags:
|
||||||
|
printed_messages.append(f"Excluding tags: {', '.join(extags)}")
|
||||||
|
if otags:
|
||||||
|
printed_messages.append(f"Output tags: {', '.join(otags)}")
|
||||||
|
|
||||||
|
if printed_messages:
|
||||||
|
print("\n".join(printed_messages))
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
def append_message(chat: List[Dict[str, str]],
|
def append_message(chat: List[Dict[str, str]],
|
||||||
@ -46,20 +52,3 @@ def display_chat(chat, dump=False) -> None:
|
|||||||
print(message['content'])
|
print(message['content'])
|
||||||
else:
|
else:
|
||||||
print(f"{message['role'].upper()}: {message['content']}")
|
print(f"{message['role'].upper()}: {message['content']}")
|
||||||
|
|
||||||
|
|
||||||
def tags_completer(prefix, parsed_args, **kwargs):
|
|
||||||
with open(parsed_args.config, 'r') as f:
|
|
||||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
result = []
|
|
||||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
|
||||||
if file.suffix == '.yaml':
|
|
||||||
with open(file, 'r') as f:
|
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
for tag in data.get('tags', []):
|
|
||||||
if prefix and len(prefix) > 0:
|
|
||||||
if tag.startswith(prefix):
|
|
||||||
result.append(tag)
|
|
||||||
else:
|
|
||||||
result.append(tag)
|
|
||||||
return list(set(result))
|
|
||||||
|
|||||||
@ -113,9 +113,9 @@ class TestHandleQuestion(unittest.TestCase):
|
|||||||
open_mock = MagicMock()
|
open_mock = MagicMock()
|
||||||
with patch("chatmastermind.storage.open", open_mock):
|
with patch("chatmastermind.storage.open", open_mock):
|
||||||
handle_question(self.args, self.config, True)
|
handle_question(self.args, self.config, True)
|
||||||
mock_process_tags.assert_called_once_with(self.config,
|
mock_process_tags.assert_called_once_with(self.args.tags,
|
||||||
self.args.tags,
|
self.args.extags,
|
||||||
self.args.extags)
|
[])
|
||||||
mock_create_chat.assert_called_once_with(self.question,
|
mock_create_chat.assert_called_once_with(self.question,
|
||||||
self.args.tags,
|
self.args.tags,
|
||||||
self.args.extags,
|
self.args.extags,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user