diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index fdfc3d3..cb4855e 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] -msg_place = Literal['mem', 'disk', 'cache', 'db', 'all'] +msg_location = Literal['mem', 'disk', 'cache', 'db', 'all'] class ChatError(Exception): @@ -373,11 +373,11 @@ class ChatDB(Chat): self.msg_write(messages) def msg_gather(self, - source: msg_place, + loc: msg_location, require_file_path: bool = False, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ - Gather and return messages from the given source: + Gather and return messages from the given locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory @@ -386,19 +386,19 @@ class ChatDB(Chat): If 'require_file_path' is True, return only files with a valid file_path. """ - source_messages: list[Message] = [] - if source in ['mem', 'all']: + loc_messages: list[Message] = [] + if loc in ['mem', 'all']: if require_file_path: - source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] + loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] else: - source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] - if source in ['cache', 'disk', 'all']: - source_messages += read_dir(self.cache_path, mfilter=mfilter) - if source in ['db', 'disk', 'all']: - source_messages += read_dir(self.db_path, mfilter=mfilter) + loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] + if loc in ['cache', 'disk', 'all']: + loc_messages += read_dir(self.cache_path, mfilter=mfilter) + if loc in ['db', 'disk', 'all']: + loc_messages += read_dir(self.db_path, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] - for m in source_messages: + for m in loc_messages: if not message_in(m, unique_messages): unique_messages.append(m) unique_messages.sort(key=lambda m: m.msg_id()) @@ -406,45 +406,52 @@ class ChatDB(Chat): def msg_find(self, msg_names: list[str], - source: msg_place = 'mem', + loc: msg_location = 'mem', ) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Messages that can't be found are ignored (i. e. the caller should check the result if they require all messages). - Searches one of the following places: + Searches one of the following locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') """ - source_messages = self.msg_gather(source, require_file_path=True) - return [m for m in source_messages + loc_messages = self.msg_gather(loc, require_file_path=True) + return [m for m in loc_messages if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] - def msg_remove(self, msg_names: list[str]) -> None: + def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None: """ Remove the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Also deletes the files of all given messages with a valid file_path. + Delete files from one of the following locations: + * 'mem' : messages currently in memory + * 'disk' : messages on disk (cache + DB directory), but not in memory + * 'cache': messages in the cache directory + * 'db' : messages in the DB directory + * 'all' : all messages ('mem' + 'disk') """ - # delete the message files first - rm_messages = self.msg_find(msg_names, source='all') - for m in rm_messages: - if (m.file_path): - m.file_path.unlink() + if loc != 'mem': + # delete the message files first + rm_messages = self.msg_find(msg_names, loc=loc) + for m in rm_messages: + if (m.file_path): + m.file_path.unlink() # then remove them from the internal list super().msg_remove(msg_names) def msg_latest(self, mfilter: Optional[MessageFilter] = None, - source: msg_place = 'mem') -> Optional[Message]: + loc: msg_location = 'mem') -> Optional[Message]: """ Return the last added message (according to the file ID) that matches the given filter. - Only consider messages with a valid file_path (except if source is 'mem'). - Searches one of the following places: + Only consider messages with a valid file_path (except if loc is 'mem'). + Searches one of the following locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory @@ -452,9 +459,9 @@ class ChatDB(Chat): * 'all' : all messages ('mem' + 'disk') """ # only consider messages with a valid file_path so they can be sorted - source_messages = self.msg_gather(source, require_file_path=True) - source_messages.sort(key=lambda m: m.msg_id(), reverse=True) - for m in source_messages: + loc_messages = self.msg_gather(loc, require_file_path=True) + loc_messages.sort(key=lambda m: m.msg_id(), reverse=True) + for m in loc_messages: if mfilter is None or m.match(mfilter): return m return None diff --git a/tests/test_chat.py b/tests/test_chat.py index 18cc4ef..7a0c94d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -526,43 +526,43 @@ class TestChatDB(unittest.TestCase): chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) # search for a DB file in memory - self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) - self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1]) # and on disk - self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) - self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2]) - self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) - self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2]) + self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2]) # now search the cache -> expect empty result - self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) - self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), []) - self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) - self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) + self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), []) + self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), []) + self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), []) + self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), []) # search for multiple messages # -> search one twice, expect result to be unique search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] expected_result = [self.message1, self.message2, self.message3] - result = chat_db.msg_find(search_names, source='all') + result = chat_db.msg_find(search_names, loc='all') self.assertSequenceEqual(result, expected_result) def test_msg_latest(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.msg_latest(source='mem'), self.message4) - self.assertEqual(chat_db.msg_latest(source='db'), self.message4) - self.assertEqual(chat_db.msg_latest(source='disk'), self.message4) - self.assertEqual(chat_db.msg_latest(source='all'), self.message4) + self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4) + self.assertEqual(chat_db.msg_latest(loc='db'), self.message4) + self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4) + self.assertEqual(chat_db.msg_latest(loc='all'), self.message4) # the cache is currently empty: - self.assertIsNone(chat_db.msg_latest(source='cache')) + self.assertIsNone(chat_db.msg_latest(loc='cache')) # add new messages to the cache dir new_message = Message(question=Question("New Question"), answer=Answer("New Answer")) chat_db.cache_add([new_message]) - self.assertEqual(chat_db.msg_latest(source='cache'), new_message) - self.assertEqual(chat_db.msg_latest(source='mem'), new_message) - self.assertEqual(chat_db.msg_latest(source='disk'), new_message) - self.assertEqual(chat_db.msg_latest(source='all'), new_message) + self.assertEqual(chat_db.msg_latest(loc='cache'), new_message) + self.assertEqual(chat_db.msg_latest(loc='mem'), new_message) + self.assertEqual(chat_db.msg_latest(loc='disk'), new_message) + self.assertEqual(chat_db.msg_latest(loc='all'), new_message) # the DB does not contain the new message - self.assertEqual(chat_db.msg_latest(source='db'), self.message4) + self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)