Merge branch 'main' of fstage:ok/ChatMastermind
This commit is contained in:
commit
648d7b35cc
1
.gitignore
vendored
1
.gitignore
vendored
@ -130,3 +130,4 @@ dmypy.json
|
|||||||
|
|
||||||
.config.yaml
|
.config.yaml
|
||||||
db
|
db
|
||||||
|
noweb
|
||||||
@ -6,16 +6,27 @@ import yaml
|
|||||||
import sys
|
import sys
|
||||||
import argcomplete
|
import argcomplete
|
||||||
import argparse
|
import argparse
|
||||||
from .utils import terminal_width, pp, process_tags, display_chat
|
import pathlib
|
||||||
from .storage import save_answers, create_chat, get_tags
|
from .utils import terminal_width, process_tags, display_chat, display_source_code
|
||||||
|
from .storage import save_answers, create_chat, get_tags, read_file, dump_data
|
||||||
from .api_client import ai, openai_api_key
|
from .api_client import ai, openai_api_key
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
|
||||||
|
|
||||||
def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
||||||
|
fname = pathlib.Path(args.print)
|
||||||
|
if fname.suffix == '.yaml':
|
||||||
with open(args.print, 'r') as f:
|
with open(args.print, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
pp(data)
|
elif fname.suffix == '.txt':
|
||||||
|
data = read_file(fname)
|
||||||
|
else:
|
||||||
|
print(f"Unknown file type: {args.print}")
|
||||||
|
sys.exit(1)
|
||||||
|
if args.only_source_code:
|
||||||
|
display_source_code(data['answer'])
|
||||||
|
else:
|
||||||
|
print(dump_data(data).strip())
|
||||||
|
|
||||||
|
|
||||||
def process_and_display_chat(args: argparse.Namespace,
|
def process_and_display_chat(args: argparse.Namespace,
|
||||||
@ -72,7 +83,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||||
group = parser.add_mutually_exclusive_group(required=True)
|
group = parser.add_mutually_exclusive_group(required=True)
|
||||||
group.add_argument('-p', '--print', help='YAML file to print')
|
group.add_argument('-p', '--print', help='File to print')
|
||||||
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
|
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
|
||||||
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
||||||
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
|
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
|
||||||
@ -80,7 +91,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
||||||
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
||||||
parser.add_argument('-M', '--model', help='Model to use')
|
parser.add_argument('-M', '--model', help='Model to use')
|
||||||
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3)
|
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1)
|
||||||
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
|
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
|
||||||
parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
|
parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
|
||||||
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
|
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
|
||||||
|
|||||||
@ -5,6 +5,34 @@ from .utils import terminal_width, append_message, message_to_chat
|
|||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]:
|
||||||
|
with open(fname, "r") as fd:
|
||||||
|
if tags_only:
|
||||||
|
return {"tags": [x.strip() for x in fd.readline().strip().split(':')[1].strip().split(',')]}
|
||||||
|
text = fd.read().strip().split('\n')
|
||||||
|
tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')]
|
||||||
|
question_idx = text.index("=== QUESTION ===") + 1
|
||||||
|
answer_idx = text.index("==== ANSWER ====")
|
||||||
|
question = "\n".join(text[question_idx:answer_idx]).strip()
|
||||||
|
answer = "\n".join(text[answer_idx + 1:]).strip()
|
||||||
|
return {"question": question, "answer": answer, "tags": tags}
|
||||||
|
|
||||||
|
|
||||||
|
def dump_data(data: Dict[str, Any]) -> str:
|
||||||
|
with io.StringIO() as fd:
|
||||||
|
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||||
|
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||||
|
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||||
|
return fd.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(fname: str, data: Dict[str, Any]) -> None:
|
||||||
|
with open(fname, "w") as fd:
|
||||||
|
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||||
|
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||||
|
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||||
|
|
||||||
|
|
||||||
def save_answers(question: str,
|
def save_answers(question: str,
|
||||||
answers: list[str],
|
answers: list[str],
|
||||||
tags: list[str],
|
tags: list[str],
|
||||||
@ -26,22 +54,7 @@ def save_answers(question: str,
|
|||||||
title_end = '-' * (terminal_width() - len(title))
|
title_end = '-' * (terminal_width() - len(title))
|
||||||
print(f'{title}{title_end}')
|
print(f'{title}{title_end}')
|
||||||
print(answer)
|
print(answer)
|
||||||
with open(f"{num:04d}.yaml", "w") as fd:
|
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
|
||||||
with io.StringIO() as f:
|
|
||||||
yaml.dump({'question': question},
|
|
||||||
f,
|
|
||||||
default_style="|",
|
|
||||||
default_flow_style=False)
|
|
||||||
fd.write(f.getvalue().replace('"question":', "question:", 1))
|
|
||||||
with io.StringIO() as f:
|
|
||||||
yaml.dump({'answer': answer},
|
|
||||||
f,
|
|
||||||
default_style="|",
|
|
||||||
default_flow_style=False)
|
|
||||||
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
|
|
||||||
yaml.dump({'tags': wtags},
|
|
||||||
fd,
|
|
||||||
default_flow_style=False)
|
|
||||||
with open(next_fname, 'w') as f:
|
with open(next_fname, 'w') as f:
|
||||||
f.write(f'{num}')
|
f.write(f'{num}')
|
||||||
|
|
||||||
@ -57,6 +70,10 @@ def create_chat(question: Optional[str],
|
|||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
elif file.suffix == '.txt':
|
||||||
|
data = read_file(file)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
data_tags = set(data.get('tags', []))
|
data_tags = set(data.get('tags', []))
|
||||||
tags_match = \
|
tags_match = \
|
||||||
not tags or data_tags.intersection(tags)
|
not tags or data_tags.intersection(tags)
|
||||||
@ -75,6 +92,10 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
|||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
elif file.suffix == '.txt':
|
||||||
|
data = read_file(file, tags_only=True)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
for tag in data.get('tags', []):
|
for tag in data.get('tags', []):
|
||||||
if prefix and len(prefix) > 0:
|
if prefix and len(prefix) > 0:
|
||||||
if tag.startswith(prefix):
|
if tag.startswith(prefix):
|
||||||
|
|||||||
@ -40,25 +40,29 @@ def message_to_chat(message: Dict[str, str],
|
|||||||
append_message(chat, 'assistant', message['answer'])
|
append_message(chat, 'assistant', message['answer'])
|
||||||
|
|
||||||
|
|
||||||
|
def display_source_code(content: str) -> None:
|
||||||
|
try:
|
||||||
|
content_start = content.index('```')
|
||||||
|
content_end = content.rindex('```')
|
||||||
|
if content_start + 3 < content_end:
|
||||||
|
print(content[content_start + 3:content_end].strip())
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def display_chat(chat, dump=False, source_code=False) -> None:
|
def display_chat(chat, dump=False, source_code=False) -> None:
|
||||||
if dump:
|
if dump:
|
||||||
pp(chat)
|
pp(chat)
|
||||||
return
|
return
|
||||||
for message in chat:
|
for message in chat:
|
||||||
if message['role'] == 'user' and not source_code:
|
text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
|
||||||
print('-' * (terminal_width()))
|
|
||||||
if len(message['content']) > terminal_width() - len(message['role']) - 2:
|
|
||||||
if not source_code:
|
|
||||||
print(f"{message['role'].upper()}:")
|
|
||||||
if source_code:
|
if source_code:
|
||||||
try:
|
display_source_code(message['content'])
|
||||||
content_start = message['content'].index('```')
|
continue
|
||||||
content_end = message['content'].rindex('```')
|
if message['role'] == 'user':
|
||||||
if content_start + 3 < content_end:
|
print('-' * terminal_width())
|
||||||
print(message['content'][content_start+3:content_end].strip())
|
if text_too_long:
|
||||||
except ValueError:
|
print(f"{message['role'].upper()}:")
|
||||||
pass
|
|
||||||
else:
|
|
||||||
print(message['content'])
|
print(message['content'])
|
||||||
elif not source_code:
|
else:
|
||||||
print(f"{message['role'].upper()}: {message['content']}")
|
print(f"{message['role'].upper()}: {message['content']}")
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import io
|
import io
|
||||||
import pathlib
|
import pathlib
|
||||||
import yaml
|
|
||||||
import argparse
|
import argparse
|
||||||
from chatmastermind.utils import terminal_width
|
from chatmastermind.utils import terminal_width
|
||||||
from chatmastermind.main import create_parser, handle_question
|
from chatmastermind.main import create_parser, handle_question
|
||||||
from chatmastermind.api_client import ai
|
from chatmastermind.api_client import ai
|
||||||
from chatmastermind.storage import create_chat, save_answers
|
from chatmastermind.storage import create_chat, save_answers, dump_data
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch, MagicMock, Mock
|
from unittest.mock import patch, MagicMock, Mock
|
||||||
|
|
||||||
@ -24,8 +23,8 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_with_tags(self, open_mock, listdir_mock):
|
def test_create_chat_with_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml']
|
listdir_mock.return_value = ['testfile.txt']
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
|
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
{'question': 'test_content', 'answer': 'some answer',
|
||||||
'tags': ['test_tag']}))
|
'tags': ['test_tag']}))
|
||||||
|
|
||||||
@ -44,8 +43,8 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
|
def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml']
|
listdir_mock.return_value = ['testfile.txt']
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
|
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
{'question': 'test_content', 'answer': 'some answer',
|
||||||
'tags': ['other_tag']}))
|
'tags': ['other_tag']}))
|
||||||
|
|
||||||
@ -60,12 +59,12 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_without_tags(self, open_mock, listdir_mock):
|
def test_create_chat_without_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml', 'testfile2.yaml']
|
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
|
||||||
open_mock.side_effect = (
|
open_mock.side_effect = (
|
||||||
io.StringIO(yaml.dump({'question': 'test_content',
|
io.StringIO(dump_data({'question': 'test_content',
|
||||||
'answer': 'some answer',
|
'answer': 'some answer',
|
||||||
'tags': ['test_tag']})),
|
'tags': ['test_tag']})),
|
||||||
io.StringIO(yaml.dump({'question': 'test_content2',
|
io.StringIO(dump_data({'question': 'test_content2',
|
||||||
'answer': 'some answer2',
|
'answer': 'some answer2',
|
||||||
'tags': ['test_tag2']})),
|
'tags': ['test_tag2']})),
|
||||||
)
|
)
|
||||||
@ -109,8 +108,7 @@ class TestHandleQuestion(unittest.TestCase):
|
|||||||
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
||||||
@patch("chatmastermind.utils.pp")
|
@patch("chatmastermind.utils.pp")
|
||||||
@patch("builtins.print")
|
@patch("builtins.print")
|
||||||
@patch("chatmastermind.storage.yaml.dump")
|
def test_handle_question(self, mock_print, mock_pp, mock_ai,
|
||||||
def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
|
|
||||||
mock_process_tags, mock_create_chat):
|
mock_process_tags, mock_create_chat):
|
||||||
open_mock = MagicMock()
|
open_mock = MagicMock()
|
||||||
with patch("chatmastermind.storage.open", open_mock):
|
with patch("chatmastermind.storage.open", open_mock):
|
||||||
@ -135,7 +133,7 @@ class TestHandleQuestion(unittest.TestCase):
|
|||||||
expected_calls.append((("-" * terminal_width(),),))
|
expected_calls.append((("-" * terminal_width(),),))
|
||||||
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
||||||
self.assertEqual(mock_print.call_args_list, expected_calls)
|
self.assertEqual(mock_print.call_args_list, expected_calls)
|
||||||
open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)])
|
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
|
||||||
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
||||||
|
|
||||||
|
|
||||||
@ -203,9 +201,9 @@ class TestCreateParser(unittest.TestCase):
|
|||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
self.assertIsInstance(parser, argparse.ArgumentParser)
|
self.assertIsInstance(parser, argparse.ArgumentParser)
|
||||||
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
|
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
|
||||||
mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print')
|
mock_group.add_argument.assert_any_call('-p', '--print', help='File to print')
|
||||||
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
|
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
|
||||||
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
||||||
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
|
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
|
||||||
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
||||||
self.assertEqual(parser.get_default('number'), 3)
|
self.assertEqual(parser.get_default('number'), 1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user