diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py
index 8823da4..0aee2fe 100644
--- a/chatmastermind/chat.py
+++ b/chatmastermind/chat.py
@@ -16,6 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
 
 db_next_file = '.next'
 ignored_files = [db_next_file, default_config_file]
+valid_sources = Literal['mem', 'disk', 'cache', 'db', 'all']
 
 
 class ChatError(Exception):
@@ -118,6 +119,16 @@ class Chat:
 
     messages: list[Message]
 
+    def msg_name_matches(self, file_path: Path, name: str) -> bool:
+        """
+        Return True if the given name matches the given file_path.
+        Matching is True if:
+        * 'name' matches the full 'file_path'
+        * 'name' matches 'file_path.name' (i. e. including the suffix)
+        * 'name' matches 'file_path.stem' (i. e. without a suffix)
+        """
+        return Path(name) == file_path or name == file_path.name or name == file_path.stem
+
     def msg_filter(self, mfilter: MessageFilter) -> None:
         """
         Use 'Message.match(mfilter) to remove all messages that
@@ -164,19 +175,19 @@ class Chat:
     def msg_find(self, msg_names: list[str]) -> list[Message]:
         """
         Search and return the messages with the given names. Names can either be filenames
-        (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
-        caller should check the result if he requires all messages).
+        (with or without suffix) or full paths. Messages that can't be found are ignored
+        (i. e. the caller should check the result if they require all messages).
         """
         return [m for m in self.messages
-                if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
+                if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
 
     def msg_remove(self, msg_names: list[str]) -> None:
         """
         Remove the messages with the given names. Names can either be filenames
-        (incl. the suffix) or full paths.
+        (with or without suffix) or full paths.
         """
         self.messages = [m for m in self.messages
-                         if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
+                         if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
         self.msg_sort()
 
     def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
@@ -308,8 +319,8 @@ class ChatDB(Chat):
     def msg_update(self, messages: list[Message], write: bool = True) -> None:
         """
         Update EXISTING messages. A message is determined as 'existing' if a message with
-        the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
-        existing messages.
+        the same base filename (i. e. 'file_path.name') is already in the list.
+        Only accepts existing messages.
         """
         if any(not message_in(m, self.messages) for m in messages):
             raise ChatError("Can't update messages that are not in the internal list")
@@ -321,29 +332,80 @@ class ChatDB(Chat):
         if write:
             self.msg_write(messages)
 
-    def msg_latest(self,
-                   mfilter: Optional[MessageFilter] = None,
-                   source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
+    def msg_gather(self,
+                   source: valid_sources,
+                   require_file_path: bool = False,
+                   mfilter: Optional[MessageFilter] = None) -> list[Message]:
         """
-        Return the last added message (according to the file ID) that matches the given filter.
-        Only consider messages with a valid file_path (except if source is 'mem').
-        Searches one of the following sources:
-        * 'mem'  : only search messages currently in memory
-        * 'disk' : search messages on disk (cache + DB directory), but not in memory
-        * 'cache': only search messages in the cache directory
-        * 'db'   : only search messages in the DB directory
-        * 'all'  : search all messages ('mem' + 'disk')
+        Gather and return messages from the given source:
+        * 'mem'  : messages currently in memory
+        * 'disk' : messages on disk (cache + DB directory), but not in memory
+        * 'cache': messages in the cache directory
+        * 'db'   : messages in the DB directory
+        * 'all'  : all messages ('mem' + 'disk')
+
+        If 'require_file_path' is True, return only messages with a valid file_path.
         """
         source_messages: list[Message] = []
-        if source == 'mem':
-            return super().msg_latest(mfilter)
+        if source in ['mem', 'all']:
+            if require_file_path:
+                source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
+            else:
+                source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
         if source in ['cache', 'disk', 'all']:
             source_messages += read_dir(self.cache_path, mfilter=mfilter)
         if source in ['db', 'disk', 'all']:
             source_messages += read_dir(self.db_path, mfilter=mfilter)
-        if source in ['all']:
-            # only consider messages with a valid file_path so they can be sorted
-            source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
+        return source_messages
+
+    def msg_find(self,
+                 msg_names: list[str],
+                 source: valid_sources = 'mem',
+                 ) -> list[Message]:
+        """
+        Search and return the messages with the given names. Names can either be filenames
+        (with or without suffix) or full paths. Messages that can't be found are ignored
+        (i. e. the caller should check the result if they require all messages).
+        Searches one of the following sources:
+        * 'mem'  : messages currently in memory
+        * 'disk' : messages on disk (cache + DB directory), but not in memory
+        * 'cache': messages in the cache directory
+        * 'db'   : messages in the DB directory
+        * 'all'  : all messages ('mem' + 'disk')
+        """
+        source_messages = self.msg_gather(source, require_file_path=True)
+        return [m for m in source_messages
+                if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
+
+    def msg_remove(self, msg_names: list[str]) -> None:
+        """
+        Remove the messages with the given names. Names can either be filenames
+        (with or without suffix) or full paths. Also deletes the files of all given
+        messages with a valid file_path.
+        """
+        # delete the message files first ('all' may list one file twice: mem + disk, hence missing_ok)
+        rm_messages = self.msg_find(msg_names, source='all')
+        for m in rm_messages:
+            if m.file_path:
+                m.file_path.unlink(missing_ok=True)
+        # then remove them from the internal list
+        super().msg_remove(msg_names)
+
+    def msg_latest(self,
+                   mfilter: Optional[MessageFilter] = None,
+                   source: valid_sources = 'mem') -> Optional[Message]:
+        """
+        Return the last added message (according to the file ID) that matches the given filter.
+        Only consider messages with a valid file_path (for all sources, including 'mem').
+        Searches one of the following sources:
+        * 'mem'  : messages currently in memory
+        * 'disk' : messages on disk (cache + DB directory), but not in memory
+        * 'cache': messages in the cache directory
+        * 'db'   : messages in the DB directory
+        * 'all'  : all messages ('mem' + 'disk')
+        """
+        # only consider messages with a valid file_path so they can be sorted
+        source_messages = self.msg_gather(source, require_file_path=True)
         source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
         for m in source_messages:
             if mfilter is None or m.match(mfilter):