diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 083b91e..41d12b9 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] -valid_sources = Literal['mem', 'disk', 'cache', 'db', 'all'] +msg_place = Literal['mem', 'disk', 'cache', 'db', 'all'] class ChatError(Exception): @@ -194,8 +194,9 @@ class Chat: def msg_find(self, msg_names: list[str]) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames - (with or without suffix) or full paths. Messages that can't be found are ignored - (i. e. the caller should check the result if they require all messages). + (with or without suffix), full paths or Message.msg_id(). Messages that can't be + found are ignored (i. e. the caller should check the result if they require all + messages). """ return [m for m in self.messages if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] @@ -203,7 +204,7 @@ class Chat: def msg_remove(self, msg_names: list[str]) -> None: """ Remove the messages with the given names. Names can either be filenames - (with or without suffix) or full paths. + (with or without suffix), full paths or Message.msg_id(). """ self.messages = [m for m in self.messages if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] @@ -352,7 +353,7 @@ class ChatDB(Chat): self.msg_write(messages) def msg_gather(self, - source: valid_sources, + source: msg_place, require_file_path: bool = False, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -385,13 +386,14 @@ class ChatDB(Chat): def msg_find(self, msg_names: list[str], - source: valid_sources = 'mem', + source: msg_place = 'mem', ) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames - (with or without suffix) or full paths. Messages that can't be found are ignored - (i. e. the caller should check the result if they require all messages). - Searches one of the following sources: + (with or without suffix), full paths or Message.msg_id(). Messages that can't be + found are ignored (i. e. the caller should check the result if they require all + messages). + Searches one of the following places: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory @@ -405,8 +407,8 @@ class ChatDB(Chat): def msg_remove(self, msg_names: list[str]) -> None: """ Remove the messages with the given names. Names can either be filenames - (with or without suffix) or full paths. Also deletes the files of all given - messages with a valid file_path. + (with or without suffix), full paths or Message.msg_id(). Also deletes the + files of all given messages with a valid file_path. """ # delete the message files first rm_messages = self.msg_find(msg_names, source='all') @@ -418,11 +420,11 @@ class ChatDB(Chat): def msg_latest(self, mfilter: Optional[MessageFilter] = None, - source: valid_sources = 'mem') -> Optional[Message]: + source: msg_place = 'mem') -> Optional[Message]: """ Return the last added message (according to the file ID) that matches the given filter. Only consider messages with a valid file_path (except if source is 'mem'). - Searches one of the following sources: + Searches one of the following places: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 8b32ae9..63f408c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -543,10 +543,11 @@ class Message(): def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. - Currently this is the file name. The ID is also used for sorting messages. + Currently this is the file name without suffix. The ID is also used for sorting + messages. """ if self.file_path: - return self.file_path.name + return self.file_path.stem else: raise MessageError("Can't create file ID without a file path") diff --git a/tests/test_chat.py b/tests/test_chat.py index ab37a6b..4962688 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -516,18 +516,22 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name)) # search for a DB file in memory self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) # and on disk self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) # now search the cache -> expect empty result self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) + self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) # search for multiple messages - search_names = ['0001', '0002.yaml', str(self.message3.file_path)] + # -> search one twice, expect result to be unique + search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] expected_result = [self.message1, self.message2, self.message3] result = chat_db.msg_find(search_names, source='all') self.assertSequenceEqual(result, expected_result) diff --git a/tests/test_message.py b/tests/test_message.py index 1f440df..5c7997f 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -730,7 +730,7 @@ class MessageIDTestCase(unittest.TestCase): self.file_path.unlink() def test_msg_id_txt(self) -> None: - self.assertEqual(self.message.msg_id(), self.file_path.name) + self.assertEqual(self.message.msg_id(), self.file_path.stem) def test_msg_id_txt_exception(self) -> None: with self.assertRaises(MessageError):