chat: new possibilites for adding messages and better tests
This commit is contained in:
parent
44cd1fab45
commit
0a3e96056d
@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path,
|
||||
return messages
|
||||
|
||||
|
||||
def make_file_path(dir_path: pathlib.Path,
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> pathlib.Path:
|
||||
"""
|
||||
Create a file_path for the given directory using the
|
||||
given file_suffix and ID generator function.
|
||||
"""
|
||||
return dir_path / f"{next_fid():04d}{file_suffix}"
|
||||
|
||||
|
||||
def write_dir(dir_path: pathlib.Path,
|
||||
messages: list[Message],
|
||||
file_suffix: str,
|
||||
@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path,
|
||||
file_path = message.file_path
|
||||
# message has no file_path: create one
|
||||
if not file_path:
|
||||
fid = next_fid()
|
||||
fname = f"{fid:04d}{file_suffix}"
|
||||
file_path = dir_path / fname
|
||||
file_path = make_file_path(dir_path, file_suffix, next_fid)
|
||||
# file_path does not point to given directory: modify it
|
||||
elif not file_path.parent.samefile(dir_path):
|
||||
file_path = dir_path / file_path.name
|
||||
@ -124,11 +132,11 @@ class Chat:
|
||||
"""
|
||||
self.messages = []
|
||||
|
||||
def add_msgs(self, msgs: list[Message]) -> None:
|
||||
def add_messages(self, messages: list[Message]) -> None:
|
||||
"""
|
||||
Add new messages and sort them if possible.
|
||||
"""
|
||||
self.messages += msgs
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
@ -279,25 +287,25 @@ class ChatDB(Chat):
|
||||
self.messages += new_messages
|
||||
self.sort()
|
||||
|
||||
def write_db(self, msgs: Optional[list[Message]] = None) -> None:
|
||||
def write_db(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write messages to the DB directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point
|
||||
to the DB directory.
|
||||
"""
|
||||
write_dir(self.db_path,
|
||||
msgs if msgs else self.messages,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
def write_cache(self, msgs: Optional[list[Message]] = None) -> None:
|
||||
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write messages to the cache directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point to
|
||||
the cache directory.
|
||||
"""
|
||||
write_dir(self.cache_path,
|
||||
msgs if msgs else self.messages,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
@ -309,3 +317,33 @@ class ChatDB(Chat):
|
||||
clear_dir(self.cache_path, self.glob)
|
||||
# only keep messages from DB dir (or those that have not yet been written)
|
||||
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
|
||||
|
||||
def add_to_db(self, messages: list[Message], do_write: bool = True) -> None:
|
||||
"""
|
||||
Adds the given messages and sets the file_path to the DB directory.
|
||||
"""
|
||||
if do_write:
|
||||
write_dir(self.db_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def add_to_cache(self, messages: list[Message], do_write: bool = True) -> None:
|
||||
"""
|
||||
Adds the given messages and sets the file_path to the cache directory.
|
||||
"""
|
||||
if do_write:
|
||||
write_dir(self.cache_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
@ -22,14 +22,14 @@ class TestChat(CmmTestCase):
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
|
||||
def test_filter(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
|
||||
|
||||
self.assertEqual(len(self.chat.messages), 1)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
|
||||
def test_sort(self) -> None:
|
||||
self.chat.add_msgs([self.message2, self.message1])
|
||||
self.chat.add_messages([self.message2, self.message1])
|
||||
self.chat.sort()
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
@ -38,18 +38,18 @@ class TestChat(CmmTestCase):
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 1')
|
||||
|
||||
def test_clear(self) -> None:
|
||||
self.chat.add_msgs([self.message1])
|
||||
self.chat.add_messages([self.message1])
|
||||
self.chat.clear()
|
||||
self.assertEqual(len(self.chat.messages), 0)
|
||||
|
||||
def test_add_msgs(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
def test_add_messages(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.assertEqual(len(self.chat.messages), 2)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
|
||||
def test_tags(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
tags_all = self.chat.tags()
|
||||
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
|
||||
tags_pref = self.chat.tags(prefix='a')
|
||||
@ -58,13 +58,13 @@ class TestChat(CmmTestCase):
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
def test_tags_frequency(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
tags_freq = self.chat.tags_frequency()
|
||||
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False)
|
||||
expected_output = f"""{'-'*terminal_width()}
|
||||
{Question.txt_header}
|
||||
@ -81,7 +81,7 @@ Answer 2
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False, with_tags=True, with_files=True)
|
||||
expected_output = f"""{'-'*terminal_width()}
|
||||
{TagLine.prefix} atag1 btag2
|
||||
@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase):
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
# make the next FID match the current state
|
||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write('4')
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
|
||||
"""
|
||||
List all Message files in the given TemporaryDirectory.
|
||||
"""
|
||||
# exclude '.next'
|
||||
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.db_path.cleanup()
|
||||
@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase):
|
||||
def test_chat_db_fids(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.get_next_fid(), 1)
|
||||
self.assertEqual(chat_db.get_next_fid(), 2)
|
||||
self.assertEqual(chat_db.get_next_fid(), 3)
|
||||
self.assertEqual(chat_db.get_next_fid(), 5)
|
||||
self.assertEqual(chat_db.get_next_fid(), 6)
|
||||
self.assertEqual(chat_db.get_next_fid(), 7)
|
||||
with open(chat_db.next_fname, 'r') as f:
|
||||
self.assertEqual(f.read(), '3')
|
||||
self.assertEqual(f.read(), '7')
|
||||
|
||||
def test_chat_db_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase):
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
||||
@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase):
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
||||
|
||||
# check the timestamp of the files in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
|
||||
# overwrite the messages in the db directory
|
||||
time.sleep(0.05)
|
||||
chat_db.write_db()
|
||||
# check if the written files are in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase):
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
|
||||
# now rewrite them to the DB dir and check for modified paths
|
||||
chat_db.write_db()
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
@ -333,15 +344,41 @@ class TestChatDB(CmmTestCase):
|
||||
message_cache = Message(question=Question("What the hell am I doing here?"),
|
||||
answer=Answer("You're a creep!"),
|
||||
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
|
||||
chat_db.add_msgs([message_empty, message_cache])
|
||||
chat_db.add_messages([message_empty, message_cache])
|
||||
|
||||
# clear the cache and check the cache dir
|
||||
chat_db.clear_cache()
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 0)
|
||||
# make sure that the DB messages (and the new message) are still there
|
||||
self.assertEqual(len(chat_db.messages), 5)
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
# but not the message with the cache dir path
|
||||
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
|
||||
|
||||
def test_chat_db_add(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
|
||||
# add new messages to the cache dir
|
||||
message1 = Message(question=Question("Question 1"),
|
||||
answer=Answer("Answer 1"))
|
||||
chat_db.add_to_cache([message1])
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 1)
|
||||
self.assertIsNotNone(chat_db.messages[4].file_path)
|
||||
self.assertEqual(chat_db.messages[4].file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
|
||||
|
||||
# add new messages to the DB dir
|
||||
message2 = Message(question=Question("Question 2"),
|
||||
answer=Answer("Answer 2"))
|
||||
chat_db.add_to_db([message2])
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 5)
|
||||
self.assertIsNotNone(chat_db.messages[5].file_path)
|
||||
self.assertEqual(chat_db.messages[5].file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user