From f9d749cdd8f3f921b275c89302fedc8f844caa4a Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 09:19:47 +0200 Subject: [PATCH] chat: added clear_cache() function and test --- chatmastermind/chat.py | 20 +++++++++++++++++++ tests/test_chat.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index e4e8ab6..9fc0a27 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -82,6 +82,17 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) +def clear_dir(dir_path: pathlib.Path, + glob: Optional[str] = None) -> None: + """ + Deletes all Message files in the given directory. + """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + @dataclass class Chat: """ @@ -289,3 +300,12 @@ class ChatDB(Chat): msgs if msgs else self.messages, self.file_suffix, self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e1ad0d..9e74061 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -300,3 +300,48 @@ class TestChatDB(CmmTestCase): # check that they now have the DB path self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_msgs([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))