diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index f030c5e..06895ac 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -401,7 +401,11 @@ class ChatDB(Chat): for m in loc_messages: if not message_in(m, unique_messages): unique_messages.append(m) - unique_messages.sort(key=lambda m: m.msg_id()) + try: + unique_messages.sort(key=lambda m: m.msg_id()) + # messages in 'mem' can have an empty file_path + except MessageError: + pass return unique_messages def msg_find(self, @@ -541,6 +545,22 @@ class ChatDB(Chat): # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + def cache_move(self, message: Message) -> None: + """ + Moves the given messages to the cache directory. + """ + # remember the old path (if any) + old_path: Optional[Path] = None + if message.file_path: + old_path = message.file_path + # write message to the new destination + self.cache_write([message]) + # remove the old one (if any) + if old_path: + self.msg_remove([str(old_path)], loc='db') + # (re)add it to the internal list + self.msg_add([message]) + def db_read(self) -> None: """ Read messages from the DB directory. New ones are added to the internal list, @@ -583,3 +603,19 @@ class ChatDB(Chat): m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) self.messages += messages self.msg_sort() + + def db_move(self, message: Message) -> None: + """ + Moves the given messages to the db directory. + """ + # remember the old path (if any) + old_path: Optional[Path] = None + if message.file_path: + old_path = message.file_path + # write message to the new destination + self.db_write([message]) + # remove the old one (if any) + if old_path: + self.msg_remove([str(old_path)], loc='cache') + # (re)add it to the internal list + self.msg_add([message]) diff --git a/tests/test_chat.py b/tests/test_chat.py index 3421852..eea2923 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -585,3 +585,52 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.msg_latest(loc='all'), new_message) # the DB does not contain the new message self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) + + def test_msg_gather(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + all_messages = [self.message1, self.message2, self.message3, self.message4] + self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + # add a new message, but only to the internal list + new_message = Message(Question("What?")) + all_messages_mem = all_messages + [new_message] + chat_db.msg_add([new_message]) + self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem) + self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem) + # the nr. of messages on disk did not change -> expect old result + self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + # test with MessageFilter + self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), + [self.message1]) + self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), + [self.message2]) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), + []) + self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), + [new_message]) + + def test_msg_move_and_gather(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + all_messages = [self.message1, self.message2, self.message3, self.message4] + self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + # move first message to the cache + chat_db.cache_move(self.message1) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1]) + self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) + self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) + self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) + # now move first message back to the DB + chat_db.db_move(self.message1) + self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages)