Compare commits
6 Commits
822d582aaf
...
028024474a
| Author | SHA1 | Date | |
|---|---|---|---|
| 028024474a | |||
| 4d5316bf18 | |||
| 5771206946 | |||
| 2cb90aac5c | |||
| d889022021 | |||
| 0a3e96056d |
@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase):
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
# make the next FID match the current state
|
||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write('4')
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
|
||||
"""
|
||||
List all Message files in the given TemporaryDirectory.
|
||||
"""
|
||||
# exclude '.next'
|
||||
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.db_path.cleanup()
|
||||
@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase):
|
||||
def test_chat_db_fids(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.get_next_fid(), 1)
|
||||
self.assertEqual(chat_db.get_next_fid(), 2)
|
||||
self.assertEqual(chat_db.get_next_fid(), 3)
|
||||
self.assertEqual(chat_db.get_next_fid(), 5)
|
||||
self.assertEqual(chat_db.get_next_fid(), 6)
|
||||
self.assertEqual(chat_db.get_next_fid(), 7)
|
||||
with open(chat_db.next_fname, 'r') as f:
|
||||
self.assertEqual(f.read(), '3')
|
||||
self.assertEqual(f.read(), '7')
|
||||
|
||||
def test_chat_db_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase):
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
||||
@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase):
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
||||
|
||||
# check the timestamp of the files in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
|
||||
# overwrite the messages in the db directory
|
||||
time.sleep(0.05)
|
||||
chat_db.write_db()
|
||||
# check if the written files are in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase):
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
|
||||
# now rewrite them to the DB dir and check for modified paths
|
||||
chat_db.write_db()
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
@ -337,11 +348,11 @@ class TestChatDB(CmmTestCase):
|
||||
|
||||
# clear the cache and check the cache dir
|
||||
chat_db.clear_cache()
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 0)
|
||||
# make sure that the DB messages (and the new message) are still there
|
||||
self.assertEqual(len(chat_db.messages), 5)
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
# but not the message with the cache dir path
|
||||
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
|
||||
@ -350,19 +361,15 @@ class TestChatDB(CmmTestCase):
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# make the next FID match the current state
|
||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write('4')
|
||||
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
|
||||
# add new messages to the cache dir
|
||||
message1 = Message(question=Question("Question 1"),
|
||||
answer=Answer("Answer 1"))
|
||||
chat_db.add_to_cache([message1])
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 1)
|
||||
self.assertIsNotNone(chat_db.messages[4].file_path)
|
||||
self.assertEqual(chat_db.messages[4].file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
|
||||
@ -371,9 +378,7 @@ class TestChatDB(CmmTestCase):
|
||||
message2 = Message(question=Question("Question 2"),
|
||||
answer=Answer("Answer 2"))
|
||||
chat_db.add_to_db([message2])
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*'))
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 5)
|
||||
self.assertIsNotNone(chat_db.messages[5].file_path)
|
||||
self.assertEqual(chat_db.messages[5].file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
|
||||
|
||||
next_fname.unlink()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user