chat: added new functions to ChatDB: msg_gather(), msg_find(), msg_remove()

This commit is contained in:
juk0de 2023-09-15 09:28:39 +02:00
parent 378bba6002
commit 44fbff33fe

View File

@ -16,6 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
valid_sources = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception): class ChatError(Exception):
@ -118,6 +119,16 @@ class Chat:
messages: list[Message] messages: list[Message]
def msg_name_matches(self, file_path: Path, name: str) -> bool:
    """
    Tell whether 'name' refers to the message file at 'file_path'.

    A name is considered a match when it equals any of:
    * the complete path ('file_path' itself)
    * the base filename including its suffix ('file_path.name')
    * the base filename without its suffix ('file_path.stem')
    """
    # full path comparison is done on Path objects, not on raw strings
    if Path(name) == file_path:
        return True
    # otherwise accept the bare filename, with or without suffix
    return name in (file_path.name, file_path.stem)
def msg_filter(self, mfilter: MessageFilter) -> None: def msg_filter(self, mfilter: MessageFilter) -> None:
""" """
Use 'Message.match(mfilter) to remove all messages that Use 'Message.match(mfilter) to remove all messages that
@ -164,19 +175,19 @@ class Chat:
def msg_find(self, msg_names: list[str]) -> list[Message]: def msg_find(self, msg_names: list[str]) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the (with or without suffix) or full paths. Messages that can't be found are ignored
caller should check the result if he requires all messages). (i. e. the caller should check the result if they require all messages).
""" """
return [m for m in self.messages return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths. (with or without suffix) or full paths.
""" """
self.messages = [m for m in self.messages self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.msg_sort() self.msg_sort()
def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
@ -308,8 +319,8 @@ class ChatDB(Chat):
def msg_update(self, messages: list[Message], write: bool = True) -> None: def msg_update(self, messages: list[Message], write: bool = True) -> None:
""" """
Update EXISTING messages. A message is determined as 'existing' if a message with Update EXISTING messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts the same base filename (i. e. 'file_path.name') is already in the list.
existing messages. Only accepts existing messages.
""" """
if any(not message_in(m, self.messages) for m in messages): if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list") raise ChatError("Can't update messages that are not in the internal list")
@ -321,29 +332,80 @@ class ChatDB(Chat):
if write: if write:
self.msg_write(messages) self.msg_write(messages)
def msg_latest(self, def msg_gather(self,
mfilter: Optional[MessageFilter] = None, source: valid_sources,
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]: require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Gather and return messages from the given source:
Only consider messages with a valid file_path (except if source is 'mem'). * 'mem' : messages currently in memory
Searches one of the following sources: * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'mem' : only search messages currently in memory * 'cache': messages in the cache directory
* 'disk' : search messages on disk (cache + DB directory), but not in memory * 'db' : messages in the DB directory
* 'cache': only search messages in the cache directory * 'all' : all messages ('mem' + 'disk')
* 'db' : only search messages in the DB directory
* 'all' : search all messages ('mem' + 'disk') If 'require_file_path' is True, return only files with a valid file_path.
""" """
source_messages: list[Message] = [] source_messages: list[Message] = []
if source == 'mem': if source in ['mem', 'all']:
return super().msg_latest(mfilter) if require_file_path:
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if source in ['cache', 'disk', 'all']: if source in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter) source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']: if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter) source_messages += read_dir(self.db_path, mfilter=mfilter)
if source in ['all']: return source_messages
def msg_find(self,
             msg_names: list[str],
             source: valid_sources = 'mem',
             ) -> list[Message]:
    """
    Look up messages by name in the given source and return the ones found.

    Each entry of 'msg_names' may be a full path or a filename (with or
    without suffix); matching is delegated to 'msg_name_matches()'. Names
    without a matching message are silently skipped, so callers that need
    every requested message must verify the result themselves.

    The candidates are collected with 'msg_gather()' from one of:
    * 'mem'  : messages currently in memory
    * 'disk' : messages on disk (cache + DB directory), but not in memory
    * 'cache': messages in the cache directory
    * 'db'   : messages in the DB directory
    * 'all'  : all messages ('mem' + 'disk')
    """
    found: list[Message] = []
    for candidate in self.msg_gather(source, require_file_path=True):
        path = candidate.file_path
        # a message without a file_path can't be matched by name
        if not path:
            continue
        if any(self.msg_name_matches(path, wanted) for wanted in msg_names):
            found.append(candidate)
    return found
def msg_remove(self, msg_names: list[str]) -> None:
    """
    Remove the messages with the given names. Names can either be filenames
    (with or without suffix) or full paths. Also deletes the files of all given
    messages with a valid file_path.

    Files that are already gone are skipped instead of raising an error.
    """
    # Delete the message files first. Searching 'all' sources can yield the
    # same file more than once (e. g. a message that is in memory AND on
    # disk), so collect the distinct paths before unlinking.
    rm_paths = {m.file_path
                for m in self.msg_find(msg_names, source='all')
                if m.file_path}
    for path in rm_paths:
        # missing_ok: an in-memory message may point to a file that was
        # never written or has already been removed
        path.unlink(missing_ok=True)
    # then remove them from the internal list
    super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
source: valid_sources = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following sources:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted # only consider messages with a valid file_path so they can be sorted
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] source_messages = self.msg_gather(source, require_file_path=True)
source_messages.sort(key=lambda m: m.msg_id(), reverse=True) source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages: for m in source_messages:
if mfilter is None or m.match(mfilter): if mfilter is None or m.match(mfilter):