Name files with autogenerated numbers
This commit is contained in:
parent
4375b6aafd
commit
16f059920b
7
.vscode/settings.json
vendored
Normal file
7
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"python.testing.pytestArgs": [
|
||||
"tests"
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
||||
@ -56,7 +56,7 @@ def handle_question(args: argparse.Namespace,
|
||||
chat, question, tags = process_and_display_chat(args, config, dump)
|
||||
otags = args.output_tags or []
|
||||
answers, usage = ai(chat, config, args.number)
|
||||
save_answers(question, answers, tags, otags)
|
||||
save_answers(question, answers, tags, otags, config)
|
||||
print("-" * terminal_width())
|
||||
print(f"Usage: {usage}")
|
||||
|
||||
|
||||
@ -8,15 +8,25 @@ from typing import List, Dict, Any, Optional
|
||||
def save_answers(question: str,
|
||||
answers: list[str],
|
||||
tags: list[str],
|
||||
otags: Optional[list[str]]
|
||||
otags: Optional[list[str]],
|
||||
config: Dict[str, Any]
|
||||
) -> None:
|
||||
wtags = otags or tags
|
||||
for num, answer in enumerate(answers, start=1):
|
||||
title = f'-- ANSWER {num} '
|
||||
num, inum = 0, 0
|
||||
next_fname = pathlib.Path(config['db']) / '.next'
|
||||
try:
|
||||
with open(next_fname, 'r') as f:
|
||||
num = int(f.read())
|
||||
except Exception:
|
||||
pass
|
||||
for answer in answers:
|
||||
num += 1
|
||||
inum += 1
|
||||
title = f'-- ANSWER {inum} '
|
||||
title_end = '-' * (terminal_width() - len(title))
|
||||
print(f'{title}{title_end}')
|
||||
print(answer)
|
||||
with open(f"{num:02d}.yaml", "w") as fd:
|
||||
with open(f"{num:04d}.yaml", "w") as fd:
|
||||
with io.StringIO() as f:
|
||||
yaml.dump({'question': question},
|
||||
f,
|
||||
@ -32,6 +42,8 @@ def save_answers(question: str,
|
||||
yaml.dump({'tags': wtags},
|
||||
fd,
|
||||
default_flow_style=False)
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write(f'{num}')
|
||||
|
||||
|
||||
def create_chat(question: Optional[str],
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import yaml
|
||||
import argparse
|
||||
from chatmastermind.utils import terminal_width
|
||||
@ -98,6 +98,7 @@ class TestHandleQuestion(unittest.TestCase):
|
||||
number=3
|
||||
)
|
||||
self.config = {
|
||||
'db': 'test_files',
|
||||
'setting1': 'value1',
|
||||
'setting2': 'value2'
|
||||
}
|
||||
@ -132,42 +133,33 @@ class TestHandleQuestion(unittest.TestCase):
|
||||
expected_calls.append(((answer,),))
|
||||
expected_calls.append((("-" * terminal_width(),),))
|
||||
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
||||
open_mock.assert_has_calls(
|
||||
[mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [
|
||||
mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3,
|
||||
any_order=True)
|
||||
self.assertEqual(mock_print.call_args_list, expected_calls)
|
||||
open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)])
|
||||
open_mock.assert_has_calls(open_expected_calls, any_order=True)
|
||||
|
||||
|
||||
class TestSaveAnswers(unittest.TestCase):
|
||||
@mock.patch('builtins.open')
|
||||
@mock.patch('chatmastermind.storage.print')
|
||||
def test_save_answers(self, print_mock, open_mock):
|
||||
question = "Test question?"
|
||||
answers = ["Answer 1", "Answer 2"]
|
||||
tags = ["tag1", "tag2"]
|
||||
otags = ["otag1", "otag2"]
|
||||
config = {'db': 'test_db'}
|
||||
|
||||
def setUp(self):
|
||||
self.question = "What is AI?"
|
||||
self.answers = ["AI is Artificial Intelligence",
|
||||
"AI is a simulation of human intelligence"]
|
||||
self.tags = ["ai", "definition"]
|
||||
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
|
||||
mock.patch('chatmastermind.storage.yaml.dump'), \
|
||||
mock.patch('io.StringIO') as stringio_mock:
|
||||
stringio_instance = stringio_mock.return_value
|
||||
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
|
||||
save_answers(question, answers, tags, otags, config)
|
||||
|
||||
@patch('sys.stdout', new_callable=io.StringIO)
|
||||
def assert_stdout(self, expected_output: str, mock_stdout: io.StringIO):
|
||||
save_answers(self.question, self.answers, self.tags, None)
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
def test_save_answers(self):
|
||||
try:
|
||||
self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n"
|
||||
"AI is Artificial Intelligence\n"
|
||||
f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n"
|
||||
"AI is a simulation of human intelligence\n")
|
||||
for idx, answer in enumerate(self.answers, start=1):
|
||||
with open(f"{idx:02d}.yaml", "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
self.assertEqual(data["question"], self.question)
|
||||
self.assertEqual(data["answer"], answer)
|
||||
self.assertEqual(data["tags"], self.tags)
|
||||
finally:
|
||||
for idx in range(1, len(self.answers) + 1):
|
||||
if os.path.exists(f"{idx:02d}.yaml"):
|
||||
os.remove(f"{idx:02d}.yaml")
|
||||
open_calls = [
|
||||
mock.call(pathlib.Path('test_db/.next'), 'r'),
|
||||
mock.call(pathlib.Path('test_db/.next'), 'w'),
|
||||
]
|
||||
open_mock.assert_has_calls(open_calls, any_order=True)
|
||||
|
||||
|
||||
class TestAI(unittest.TestCase):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user