Change storage to use text format per default, instead of yaml, but still support yaml.
This commit is contained in:
parent
57caba5360
commit
c5fd466dda
@ -5,6 +5,32 @@ from .utils import terminal_width, append_message, message_to_chat
|
|||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(fname: str) -> Dict[str, Any]:
|
||||||
|
with open(fname, "r") as fd:
|
||||||
|
text = fd.read().split('\n')
|
||||||
|
tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')]
|
||||||
|
question_idx = text.index("=== QUESTION ===") + 1
|
||||||
|
answer_idx = text.index("==== ANSWER ====")
|
||||||
|
question = "\n".join(text[question_idx:answer_idx]).strip()
|
||||||
|
answer = "\n".join(text[answer_idx + 1:]).strip()
|
||||||
|
return {"question": question, "answer": answer, "tags": tags}
|
||||||
|
|
||||||
|
|
||||||
|
def dump_data(data: Dict[str, Any]) -> str:
|
||||||
|
with io.StringIO() as fd:
|
||||||
|
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||||
|
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||||
|
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||||
|
return fd.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(fname: str, data: Dict[str, Any]) -> None:
|
||||||
|
with open(fname, "w") as fd:
|
||||||
|
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||||
|
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||||
|
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||||
|
|
||||||
|
|
||||||
def save_answers(question: str,
|
def save_answers(question: str,
|
||||||
answers: list[str],
|
answers: list[str],
|
||||||
tags: list[str],
|
tags: list[str],
|
||||||
@ -26,22 +52,7 @@ def save_answers(question: str,
|
|||||||
title_end = '-' * (terminal_width() - len(title))
|
title_end = '-' * (terminal_width() - len(title))
|
||||||
print(f'{title}{title_end}')
|
print(f'{title}{title_end}')
|
||||||
print(answer)
|
print(answer)
|
||||||
with open(f"{num:04d}.yaml", "w") as fd:
|
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
|
||||||
with io.StringIO() as f:
|
|
||||||
yaml.dump({'question': question},
|
|
||||||
f,
|
|
||||||
default_style="|",
|
|
||||||
default_flow_style=False)
|
|
||||||
fd.write(f.getvalue().replace('"question":', "question:", 1))
|
|
||||||
with io.StringIO() as f:
|
|
||||||
yaml.dump({'answer': answer},
|
|
||||||
f,
|
|
||||||
default_style="|",
|
|
||||||
default_flow_style=False)
|
|
||||||
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
|
|
||||||
yaml.dump({'tags': wtags},
|
|
||||||
fd,
|
|
||||||
default_flow_style=False)
|
|
||||||
with open(next_fname, 'w') as f:
|
with open(next_fname, 'w') as f:
|
||||||
f.write(f'{num}')
|
f.write(f'{num}')
|
||||||
|
|
||||||
@ -57,13 +68,17 @@ def create_chat(question: Optional[str],
|
|||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
data_tags = set(data.get('tags', []))
|
elif file.suffix == '.txt':
|
||||||
tags_match = \
|
data = read_file(file)
|
||||||
not tags or data_tags.intersection(tags)
|
else:
|
||||||
extags_do_not_match = \
|
continue
|
||||||
not extags or not data_tags.intersection(extags)
|
data_tags = set(data.get('tags', []))
|
||||||
if tags_match and extags_do_not_match:
|
tags_match = \
|
||||||
message_to_chat(data, chat)
|
not tags or data_tags.intersection(tags)
|
||||||
|
extags_do_not_match = \
|
||||||
|
not extags or not data_tags.intersection(extags)
|
||||||
|
if tags_match and extags_do_not_match:
|
||||||
|
message_to_chat(data, chat)
|
||||||
if question:
|
if question:
|
||||||
append_message(chat, 'user', question)
|
append_message(chat, 'user', question)
|
||||||
return chat
|
return chat
|
||||||
@ -75,10 +90,14 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
|||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
for tag in data.get('tags', []):
|
elif file.suffix == '.txt':
|
||||||
if prefix and len(prefix) > 0:
|
data = read_file(file)
|
||||||
if tag.startswith(prefix):
|
else:
|
||||||
result.append(tag)
|
continue
|
||||||
else:
|
for tag in data.get('tags', []):
|
||||||
|
if prefix and len(prefix) > 0:
|
||||||
|
if tag.startswith(prefix):
|
||||||
result.append(tag)
|
result.append(tag)
|
||||||
|
else:
|
||||||
|
result.append(tag)
|
||||||
return list(set(result))
|
return list(set(result))
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import io
|
import io
|
||||||
import pathlib
|
import pathlib
|
||||||
import yaml
|
|
||||||
import argparse
|
import argparse
|
||||||
from chatmastermind.utils import terminal_width
|
from chatmastermind.utils import terminal_width
|
||||||
from chatmastermind.main import create_parser, handle_question
|
from chatmastermind.main import create_parser, handle_question
|
||||||
from chatmastermind.api_client import ai
|
from chatmastermind.api_client import ai
|
||||||
from chatmastermind.storage import create_chat, save_answers
|
from chatmastermind.storage import create_chat, save_answers, dump_data
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch, MagicMock, Mock
|
from unittest.mock import patch, MagicMock, Mock
|
||||||
|
|
||||||
@ -24,8 +23,8 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_with_tags(self, open_mock, listdir_mock):
|
def test_create_chat_with_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml']
|
listdir_mock.return_value = ['testfile.txt']
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
|
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
{'question': 'test_content', 'answer': 'some answer',
|
||||||
'tags': ['test_tag']}))
|
'tags': ['test_tag']}))
|
||||||
|
|
||||||
@ -44,8 +43,8 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
|
def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml']
|
listdir_mock.return_value = ['testfile.txt']
|
||||||
open_mock.return_value.__enter__.return_value = io.StringIO(yaml.dump(
|
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
|
||||||
{'question': 'test_content', 'answer': 'some answer',
|
{'question': 'test_content', 'answer': 'some answer',
|
||||||
'tags': ['other_tag']}))
|
'tags': ['other_tag']}))
|
||||||
|
|
||||||
@ -60,12 +59,12 @@ class TestCreateChat(unittest.TestCase):
|
|||||||
@patch('os.listdir')
|
@patch('os.listdir')
|
||||||
@patch('builtins.open')
|
@patch('builtins.open')
|
||||||
def test_create_chat_without_tags(self, open_mock, listdir_mock):
|
def test_create_chat_without_tags(self, open_mock, listdir_mock):
|
||||||
listdir_mock.return_value = ['testfile.yaml', 'testfile2.yaml']
|
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
|
||||||
open_mock.side_effect = (
|
open_mock.side_effect = (
|
||||||
io.StringIO(yaml.dump({'question': 'test_content',
|
io.StringIO(dump_data({'question': 'test_content',
|
||||||
'answer': 'some answer',
|
'answer': 'some answer',
|
||||||
'tags': ['test_tag']})),
|
'tags': ['test_tag']})),
|
||||||
io.StringIO(yaml.dump({'question': 'test_content2',
|
io.StringIO(dump_data({'question': 'test_content2',
|
||||||
'answer': 'some answer2',
|
'answer': 'some answer2',
|
||||||
'tags': ['test_tag2']})),
|
'tags': ['test_tag2']})),
|
||||||
)
|
)
|
||||||
@ -109,8 +108,7 @@ class TestHandleQuestion(unittest.TestCase):
|
|||||||
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
||||||
@patch("chatmastermind.utils.pp")
|
@patch("chatmastermind.utils.pp")
|
||||||
@patch("builtins.print")
|
@patch("builtins.print")
|
||||||
@patch("chatmastermind.storage.yaml.dump")
|
def test_handle_question(self, mock_print, mock_pp, mock_ai,
|
||||||
def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
|
|
||||||
mock_process_tags, mock_create_chat):
|
mock_process_tags, mock_create_chat):
|
||||||
open_mock = MagicMock()
|
open_mock = MagicMock()
|
||||||
with patch("chatmastermind.storage.open", open_mock):
|
with patch("chatmastermind.storage.open", open_mock):
|
||||||
@ -135,7 +133,7 @@ class TestHandleQuestion(unittest.TestCase):
|
|||||||
expected_calls.append((("-" * terminal_width(),),))
|
expected_calls.append((("-" * terminal_width(),),))
|
||||||
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
||||||
self.assertEqual(mock_print.call_args_list, expected_calls)
|
self.assertEqual(mock_print.call_args_list, expected_calls)
|
||||||
open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)])
|
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
|
||||||
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user