Compare commits

..

6 Commits

View File

@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase):
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None:
self.db_path.cleanup()
@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase):
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 1)
self.assertEqual(chat_db.get_next_fid(), 2)
self.assertEqual(chat_db.get_next_fid(), 3)
self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '3')
self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase):
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase):
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.write_db()
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
@ -337,11 +348,11 @@ class TestChatDB(CmmTestCase):
# clear the cache and check the cache dir
chat_db.clear_cache()
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
self.assertEqual(len(chat_db.messages), 5)
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
@ -350,19 +361,15 @@ class TestChatDB(CmmTestCase):
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.add_to_cache([message1])
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIsNotNone(chat_db.messages[4].file_path)
self.assertEqual(chat_db.messages[4].file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@ -371,9 +378,7 @@ class TestChatDB(CmmTestCase):
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.add_to_db([message2])
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5)
self.assertIsNotNone(chat_db.messages[5].file_path)
self.assertEqual(chat_db.messages[5].file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
next_fname.unlink()