From 57caba5360750ceb9ca134a34c3c82c0110ad998 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sun, 7 May 2023 14:02:19 +0200 Subject: [PATCH 1/5] Ignore noweb for now. --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0f48159..4ade1df 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,5 @@ dmypy.json .pyre/ .config.yaml -db \ No newline at end of file +db +noweb \ No newline at end of file From c5fd466ddaf27422ac8fe223f9a70eef08031512 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sun, 7 May 2023 14:31:17 +0200 Subject: [PATCH 2/5] Change storage to use text format per default, instead of yaml, but still support yaml. --- chatmastermind/storage.py | 75 ++++++++++++++++++++++++--------------- tests/test_main.py | 22 ++++++------ 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index fb7bd8d..be2bd88 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -5,6 +5,32 @@ from .utils import terminal_width, append_message, message_to_chat from typing import List, Dict, Any, Optional +def read_file(fname: str) -> Dict[str, Any]: + with open(fname, "r") as fd: + text = fd.read().split('\n') + tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')] + question_idx = text.index("=== QUESTION ===") + 1 + answer_idx = text.index("==== ANSWER ====") + question = "\n".join(text[question_idx:answer_idx]).strip() + answer = "\n".join(text[answer_idx + 1:]).strip() + return {"question": question, "answer": answer, "tags": tags} + + +def dump_data(data: Dict[str, Any]) -> str: + with io.StringIO() as fd: + fd.write(f'TAGS: {", ".join(data["tags"])}\n') + fd.write(f'=== QUESTION ===\n{data["question"]}\n') + fd.write(f'==== ANSWER ====\n{data["answer"]}\n') + return fd.getvalue() + + +def write_file(fname: str, data: Dict[str, Any]) -> None: + with open(fname, "w") as fd: + fd.write(f'TAGS: {", 
".join(data["tags"])}\n') + fd.write(f'=== QUESTION ===\n{data["question"]}\n') + fd.write(f'==== ANSWER ====\n{data["answer"]}\n') + + def save_answers(question: str, answers: list[str], tags: list[str], @@ -26,22 +52,7 @@ def save_answers(question: str, title_end = '-' * (terminal_width() - len(title)) print(f'{title}{title_end}') print(answer) - with open(f"{num:04d}.yaml", "w") as fd: - with io.StringIO() as f: - yaml.dump({'question': question}, - f, - default_style="|", - default_flow_style=False) - fd.write(f.getvalue().replace('"question":', "question:", 1)) - with io.StringIO() as f: - yaml.dump({'answer': answer}, - f, - default_style="|", - default_flow_style=False) - fd.write(f.getvalue().replace('"answer":', "answer:", 1)) - yaml.dump({'tags': wtags}, - fd, - default_flow_style=False) + write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) with open(next_fname, 'w') as f: f.write(f'{num}') @@ -57,13 +68,17 @@ def create_chat(question: Optional[str], if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) - data_tags = set(data.get('tags', [])) - tags_match = \ - not tags or data_tags.intersection(tags) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat) + elif file.suffix == '.txt': + data = read_file(file) + else: + continue + data_tags = set(data.get('tags', [])) + tags_match = \ + not tags or data_tags.intersection(tags) + extags_do_not_match = \ + not extags or not data_tags.intersection(extags) + if tags_match and extags_do_not_match: + message_to_chat(data, chat) if question: append_message(chat, 'user', question) return chat @@ -75,10 +90,14 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) - for tag in data.get('tags', []): - if prefix and len(prefix) > 
0: - if tag.startswith(prefix): - result.append(tag) - else: + elif file.suffix == '.txt': + data = read_file(file) + else: + continue + for tag in data.get('tags', []): + if prefix and len(prefix) > 0: + if tag.startswith(prefix): result.append(tag) + else: + result.append(tag) return list(set(result)) diff --git a/tests/test_main.py b/tests/test_main.py index 5b5d178..9fe4a6b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,12 +1,11 @@ import unittest import io import pathlib -import yaml import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, handle_question from chatmastermind.api_client import ai -from chatmastermind.storage import create_chat, save_answers +from chatmastermind.storage import create_chat, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock @@ -24,8 +23,8 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('builtins.open') def test_create_chat_with_tags(self, open_mock, listdir_mock): - listdir_mock.return_value = ['testfile.yaml'] - open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump( + listdir_mock.return_value = ['testfile.txt'] + open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( {'question': 'test_content', 'answer': 'some answer', 'tags': ['test_tag']})) @@ -44,8 +43,8 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('builtins.open') def test_create_chat_with_other_tags(self, open_mock, listdir_mock): - listdir_mock.return_value = ['testfile.yaml'] - open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump( + listdir_mock.return_value = ['testfile.txt'] + open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( {'question': 'test_content', 'answer': 'some answer', 'tags': ['other_tag']})) @@ -60,12 +59,12 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('builtins.open') def 
test_create_chat_without_tags(self, open_mock, listdir_mock): - listdir_mock.return_value = ['testfile.yaml', 'testfile2.yaml'] + listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] open_mock.side_effect = ( - io.StringIO(yaml.dump({'question': 'test_content', + io.StringIO(dump_data({'question': 'test_content', 'answer': 'some answer', 'tags': ['test_tag']})), - io.StringIO(yaml.dump({'question': 'test_content2', + io.StringIO(dump_data({'question': 'test_content2', 'answer': 'some answer2', 'tags': ['test_tag2']})), ) @@ -109,8 +108,7 @@ class TestHandleQuestion(unittest.TestCase): @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") - @patch("chatmastermind.storage.yaml.dump") - def test_handle_question(self, _, mock_print, mock_pp, mock_ai, + def test_handle_question(self, mock_print, mock_pp, mock_ai, mock_process_tags, mock_create_chat): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): @@ -135,7 +133,7 @@ class TestHandleQuestion(unittest.TestCase): expected_calls.append((("-" * terminal_width(),),)) expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)]) + open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) open_mock.assert_has_calls(open_expected_calls, any_order=True) From 63a202376d0ad391acd0cee5538a4448a97afd83 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sun, 7 May 2023 14:43:56 +0200 Subject: [PATCH 3/5] Optimize tags retrieval in TXT files. 
--- chatmastermind/storage.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index be2bd88..0ce6f9e 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -5,9 +5,11 @@ from .utils import terminal_width, append_message, message_to_chat from typing import List, Dict, Any, Optional -def read_file(fname: str) -> Dict[str, Any]: +def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]: with open(fname, "r") as fd: - text = fd.read().split('\n') + if tags_only: + return {"tags": [x.strip() for x in fd.readline().strip().split(':')[1].strip().split(',')]} + text = fd.read().strip().split('\n') tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')] question_idx = text.index("=== QUESTION ===") + 1 answer_idx = text.index("==== ANSWER ====") @@ -91,7 +93,7 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) elif file.suffix == '.txt': - data = read_file(file) + data = read_file(file, tags_only=True) else: continue for tag in data.get('tags', []): From e0fac306cbe6ab423010612fa7fb8d106252c593 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sun, 7 May 2023 15:01:48 +0200 Subject: [PATCH 4/5] Change default number of answers to 1 --- chatmastermind/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1c937d4..ad68cba 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -82,7 +82,7 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) parser.add_argument('-M', '--model', help='Model to use') - parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3) + 
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS') From 6ae0e7d084058d7df14327cabc5745dcdbc5de0b Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sun, 7 May 2023 16:32:24 +0200 Subject: [PATCH 5/5] Improve handling of printing source code only output. --- chatmastermind/main.py | 23 +++++++++++++++++------ chatmastermind/utils.py | 34 +++++++++++++++++++--------------- tests/test_main.py | 4 ++-- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ad68cba..1a04b94 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,15 +6,26 @@ import yaml import sys import argcomplete import argparse -from .utils import terminal_width, pp, process_tags, display_chat -from .storage import save_answers, create_chat, get_tags +import pathlib +from .utils import terminal_width, process_tags, display_chat, display_source_code +from .storage import save_answers, create_chat, get_tags, read_file, dump_data from .api_client import ai, openai_api_key def run_print_command(args: argparse.Namespace, config: dict) -> None: - with open(args.print, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - pp(data) + fname = pathlib.Path(args.print) + if fname.suffix == '.yaml': + with open(args.print, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + elif fname.suffix == '.txt': + data = read_file(fname) + else: + print(f"Unknown file type: {args.print}") + sys.exit(1) + if args.only_source_code: + display_source_code(data['answer']) + else: + print(dump_data(data).strip()) def process_and_display_chat(args: argparse.Namespace, @@ -74,7 +85,7 @@ def 
create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") group = parser.add_mutually_exclusive_group(required=True) - group.add_argument('-p', '--print', help='YAML file to print') + group.add_argument('-p', '--print', help='File to print') group.add_argument('-q', '--question', nargs='*', help='Question to ask') group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true') diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index d0d05ae..5f2af92 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -40,24 +40,28 @@ def message_to_chat(message: Dict[str, str], append_message(chat, 'assistant', message['answer']) +def display_source_code(content: str) -> None: + code_block_count = 0 + for line in content.splitlines(): + if line.strip().startswith('```'): + code_block_count += 1 + elif code_block_count == 1: + print(line) + + def display_chat(chat, dump=False, source_code=False) -> None: if dump: pp(chat) return for message in chat: - if message['role'] == 'user' and not source_code: - print('-' * (terminal_width())) - if len(message['content']) > terminal_width() - len(message['role']) - 2: - if not source_code: - print(f"{message['role'].upper()}:") - if source_code: - out = 0 - for line in message['content'].splitlines(): - if line.strip().startswith('```'): - out += 1 - elif out == 1: - print(f"{line}") - else: - print(message['content']) - elif not source_code: + text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 + if source_code: + display_source_code(message['content']) + continue + if message['role'] == 'user': + print('-' * terminal_width()) + if text_too_long: + print(f"{message['role'].upper()}:") + print(message['content']) + else: 
print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py index 9fe4a6b..eca160f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -201,9 +201,9 @@ class TestCreateParser(unittest.TestCase): parser = create_parser() self.assertIsInstance(parser, argparse.ArgumentParser) mock_add_mutually_exclusive_group.assert_called_once_with(required=True) - mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print') + mock_group.add_argument.assert_any_call('-p', '--print', help='File to print') mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true') self.assertTrue('.config.yaml' in parser.get_default('config')) - self.assertEqual(parser.get_default('number'), 3) + self.assertEqual(parser.get_default('number'), 1)