Merge branch 'main' of fstage:ok/ChatMastermind

This commit is contained in:
Oleksandr Kozachuk 2023-05-08 13:59:42 +02:00
commit 648d7b35cc
5 changed files with 101 additions and 66 deletions

3
.gitignore vendored
View File

@ -129,4 +129,5 @@ dmypy.json
.pyre/
.config.yaml
db
db
noweb

View File

@ -6,16 +6,27 @@ import yaml
import sys
import argcomplete
import argparse
from .utils import terminal_width, pp, process_tags, display_chat
from .storage import save_answers, create_chat, get_tags
import pathlib
from .utils import terminal_width, process_tags, display_chat, display_source_code
from .storage import save_answers, create_chat, get_tags, read_file, dump_data
from .api_client import ai, openai_api_key
from itertools import zip_longest
def run_print_command(args: argparse.Namespace, config: dict) -> None:
with open(args.print, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
pp(data)
fname = pathlib.Path(args.print)
if fname.suffix == '.yaml':
with open(args.print, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt':
data = read_file(fname)
else:
print(f"Unknown file type: {args.print}")
sys.exit(1)
if args.only_source_code:
display_source_code(data['answer'])
else:
print(dump_data(data).strip())
def process_and_display_chat(args: argparse.Namespace,
@ -72,7 +83,7 @@ def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-p', '--print', help='YAML file to print')
group.add_argument('-p', '--print', help='File to print')
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
@ -80,7 +91,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
parser.add_argument('-M', '--model', help='Model to use')
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3)
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1)
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')

View File

@ -5,6 +5,34 @@ from .utils import terminal_width, append_message, message_to_chat
from typing import List, Dict, Any, Optional
def _parse_tags_line(line: str) -> List[str]:
    """Extract the tag names from a ``TAGS: a, b`` header line."""
    # Split on the first colon only, so a tag name may itself contain ':'.
    # A line without any colon simply yields no tags.
    _, _, raw_tags = line.partition(':')
    # Drop empty entries: an entry saved without tags has the header "TAGS: "
    # and must read back as [] rather than [''].
    return [tag.strip() for tag in raw_tags.split(',') if tag.strip()]


def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]:
    """Parse a chat entry stored in the plain-text ``.txt`` format.

    The expected layout is the one written by ``dump_data``/``write_file``::

        TAGS: tag1, tag2
        === QUESTION ===
        <question text>
        ==== ANSWER ====
        <answer text>

    Args:
        fname: Path of the file to read.
        tags_only: When true, only the first line is read and the result
            contains just the ``"tags"`` key.

    Returns:
        ``{"tags": [...]}`` when ``tags_only`` is set, otherwise a dict with
        the keys ``"question"``, ``"answer"`` and ``"tags"``.

    Raises:
        ValueError: If one of the section marker lines is missing
            (full read only).
    """
    with open(fname, "r") as fd:
        if tags_only:
            return {"tags": _parse_tags_line(fd.readline())}
        lines = fd.read().strip().split('\n')
    tags = _parse_tags_line(lines.pop(0))
    # Marker lines must match exactly; list.index raises ValueError otherwise.
    question_idx = lines.index("=== QUESTION ===") + 1
    answer_idx = lines.index("==== ANSWER ====")
    question = "\n".join(lines[question_idx:answer_idx]).strip()
    answer = "\n".join(lines[answer_idx + 1:]).strip()
    return {"question": question, "answer": answer, "tags": tags}
def dump_data(data: Dict[str, Any]) -> str:
    """Serialize a chat entry to the plain-text ``.txt`` representation.

    Args:
        data: Mapping with the keys ``"tags"`` (iterable of strings),
            ``"question"`` and ``"answer"``.

    Returns:
        The tag header line followed by the question and answer sections,
        each terminated by a newline.
    """
    tag_list = ", ".join(data["tags"])
    sections = (
        f'TAGS: {tag_list}\n',
        f'=== QUESTION ===\n{data["question"]}\n',
        f'==== ANSWER ====\n{data["answer"]}\n',
    )
    return "".join(sections)
def write_file(fname: str, data: Dict[str, Any]) -> None:
    """Write a chat entry to ``fname`` in the plain-text ``.txt`` format.

    Args:
        fname: Path of the file to create or overwrite.
        data: Mapping with the keys ``"tags"``, ``"question"`` and
            ``"answer"``, as accepted by ``dump_data``.

    Raises:
        KeyError: If ``data`` lacks one of the required keys. This is raised
            before the file is opened, so an existing file is left untouched.
    """
    # Serialize first: reusing dump_data keeps a single definition of the
    # on-disk format, and a malformed ``data`` fails before the target file
    # has been truncated.
    content = dump_data(data)
    with open(fname, "w") as fd:
        fd.write(content)
def save_answers(question: str,
answers: list[str],
tags: list[str],
@ -26,22 +54,7 @@ def save_answers(question: str,
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:04d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
with open(next_fname, 'w') as f:
f.write(f'{num}')
@ -57,13 +70,17 @@ def create_chat(question: Optional[str],
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat
@ -75,10 +92,14 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
elif file.suffix == '.txt':
data = read_file(file, tags_only=True)
else:
continue
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))

View File

@ -40,25 +40,29 @@ def message_to_chat(message: Dict[str, str],
append_message(chat, 'assistant', message['answer'])
def display_source_code(content: str) -> None:
    """Print the source code embedded in a Markdown-fenced answer.

    Everything between the first and the last triple-backtick fence in
    ``content`` is printed, stripped of surrounding whitespace. The rest of
    the opening fence line (the info string, e.g. a language name) is not
    code and is left out. Nothing is printed when ``content`` contains
    fewer than two fences or nothing lies between them.

    Args:
        content: Message text that may contain a fenced code block.
    """
    fence = '```'
    start = content.find(fence)
    end = content.rfind(fence)
    # With a single fence start == end, so this also covers "no closing fence".
    if start == -1 or start + len(fence) >= end:
        return
    body = content[start + len(fence):end]
    # Only treat the first line as an info string when the block is actually
    # multi-line; an inline block on one line is printed whole.
    _, newline, remainder = body.partition('\n')
    if newline:
        body = remainder
    print(body.strip())
def display_chat(chat, dump=False, source_code=False) -> None:
if dump:
pp(chat)
return
for message in chat:
if message['role'] == 'user' and not source_code:
print('-' * (terminal_width()))
if len(message['content']) > terminal_width() - len(message['role']) - 2:
if not source_code:
print(f"{message['role'].upper()}:")
if source_code:
try:
content_start = message['content'].index('```')
content_end = message['content'].rindex('```')
if content_start + 3 < content_end:
print(message['content'][content_start+3:content_end].strip())
except ValueError:
pass
else:
print(message['content'])
elif not source_code:
text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
if source_code:
display_source_code(message['content'])
continue
if message['role'] == 'user':
print('-' * terminal_width())
if text_too_long:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")

View File

@ -1,12 +1,11 @@
import unittest
import io
import pathlib
import yaml
import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, handle_question
from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat, save_answers
from chatmastermind.storage import create_chat, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock
@ -24,8 +23,8 @@ class TestCreateChat(unittest.TestCase):
@patch('os.listdir')
@patch('builtins.open')
def test_create_chat_with_tags(self, open_mock, listdir_mock):
listdir_mock.return_value = ['testfile.yaml']
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
listdir_mock.return_value = ['testfile.txt']
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']}))
@ -44,8 +43,8 @@ class TestCreateChat(unittest.TestCase):
@patch('os.listdir')
@patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
listdir_mock.return_value = ['testfile.yaml']
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
listdir_mock.return_value = ['testfile.txt']
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']}))
@ -60,12 +59,12 @@ class TestCreateChat(unittest.TestCase):
@patch('os.listdir')
@patch('builtins.open')
def test_create_chat_without_tags(self, open_mock, listdir_mock):
listdir_mock.return_value = ['testfile.yaml', 'testfile2.yaml']
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
open_mock.side_effect = (
io.StringIO(yaml.dump({'question': 'test_content',
io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer',
'tags': ['test_tag']})),
io.StringIO(yaml.dump({'question': 'test_content2',
io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2',
'tags': ['test_tag2']})),
)
@ -109,8 +108,7 @@ class TestHandleQuestion(unittest.TestCase):
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp")
@patch("builtins.print")
@patch("chatmastermind.storage.yaml.dump")
def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
def test_handle_question(self, mock_print, mock_pp, mock_ai,
mock_process_tags, mock_create_chat):
open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock):
@ -135,7 +133,7 @@ class TestHandleQuestion(unittest.TestCase):
expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)])
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True)
@ -203,9 +201,9 @@ class TestCreateParser(unittest.TestCase):
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print')
mock_group.add_argument.assert_any_call('-p', '--print', help='File to print')
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
self.assertTrue('.config.yaml' in parser.get_default('config'))
self.assertEqual(parser.get_default('number'), 3)
self.assertEqual(parser.get_default('number'), 1)