chat: added new functions to ChatDB: msg_gather(), msg_find(), msg_remove()
This commit is contained in:
parent
bbcff17558
commit
454ce84d71
@ -16,6 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
|
||||
db_next_file = '.next'
|
||||
ignored_files = [db_next_file, default_config_file]
|
||||
valid_sources = Literal['mem', 'disk', 'cache', 'db', 'all']
|
||||
|
||||
|
||||
class ChatError(Exception):
|
||||
@ -118,6 +119,16 @@ class Chat:
|
||||
|
||||
messages: list[Message]
|
||||
|
||||
def msg_name_matches(self, file_path: Path, name: str) -> bool:
|
||||
"""
|
||||
Return True if the given name matches the given file_path.
|
||||
Matching is True if:
|
||||
* 'name' matches the full 'file_path'
|
||||
* 'name' matches 'file_path.name' (i. e. including the suffix)
|
||||
* 'name' matches 'file_path.stem' (i. e. without a suffix)
|
||||
"""
|
||||
return Path(name) == file_path or name == file_path.name or name == file_path.stem
|
||||
|
||||
def msg_filter(self, mfilter: MessageFilter) -> None:
|
||||
"""
|
||||
Use 'Message.match(mfilter) to remove all messages that
|
||||
@ -164,19 +175,19 @@ class Chat:
|
||||
def msg_find(self, msg_names: list[str]) -> list[Message]:
|
||||
"""
|
||||
Search and return the messages with the given names. Names can either be filenames
|
||||
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
|
||||
caller should check the result if he requires all messages).
|
||||
(with or without suffix) or full paths. Messages that can't be found are ignored
|
||||
(i. e. the caller should check the result if they require all messages).
|
||||
"""
|
||||
return [m for m in self.messages
|
||||
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
|
||||
|
||||
def msg_remove(self, msg_names: list[str]) -> None:
|
||||
"""
|
||||
Remove the messages with the given names. Names can either be filenames
|
||||
(incl. the suffix) or full paths.
|
||||
(with or without suffix) or full paths.
|
||||
"""
|
||||
self.messages = [m for m in self.messages
|
||||
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||
if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
|
||||
self.msg_sort()
|
||||
|
||||
def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
@ -308,8 +319,8 @@ class ChatDB(Chat):
|
||||
def msg_update(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
Update EXISTING messages. A message is determined as 'existing' if a message with
|
||||
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
|
||||
existing messages.
|
||||
the same base filename (i. e. 'file_path.name') is already in the list.
|
||||
Only accepts existing messages.
|
||||
"""
|
||||
if any(not message_in(m, self.messages) for m in messages):
|
||||
raise ChatError("Can't update messages that are not in the internal list")
|
||||
@ -321,29 +332,80 @@ class ChatDB(Chat):
|
||||
if write:
|
||||
self.msg_write(messages)
|
||||
|
||||
def msg_latest(self,
|
||||
mfilter: Optional[MessageFilter] = None,
|
||||
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
|
||||
def msg_gather(self,
|
||||
source: valid_sources,
|
||||
require_file_path: bool = False,
|
||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||
"""
|
||||
Return the last added message (according to the file ID) that matches the given filter.
|
||||
Only consider messages with a valid file_path (except if source is 'mem').
|
||||
Searches one of the following sources:
|
||||
* 'mem' : only search messages currently in memory
|
||||
* 'disk' : search messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': only search messages in the cache directory
|
||||
* 'db' : only search messages in the DB directory
|
||||
* 'all' : search all messages ('mem' + 'disk')
|
||||
Gather and return messages from the given source:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
* 'db' : messages in the DB directory
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
|
||||
If 'require_file_path' is True, return only files with a valid file_path.
|
||||
"""
|
||||
source_messages: list[Message] = []
|
||||
if source == 'mem':
|
||||
return super().msg_latest(mfilter)
|
||||
if source in ['mem', 'all']:
|
||||
if require_file_path:
|
||||
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
|
||||
else:
|
||||
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
|
||||
if source in ['cache', 'disk', 'all']:
|
||||
source_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
if source in ['db', 'disk', 'all']:
|
||||
source_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
if source in ['all']:
|
||||
# only consider messages with a valid file_path so they can be sorted
|
||||
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
|
||||
return source_messages
|
||||
|
||||
def msg_find(self,
|
||||
msg_names: list[str],
|
||||
source: valid_sources = 'mem',
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Search and return the messages with the given names. Names can either be filenames
|
||||
(with or without suffix) or full paths. Messages that can't be found are ignored
|
||||
(i. e. the caller should check the result if they require all messages).
|
||||
Searches one of the following sources:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
* 'db' : messages in the DB directory
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
"""
|
||||
source_messages = self.msg_gather(source, require_file_path=True)
|
||||
return [m for m in source_messages
|
||||
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
|
||||
|
||||
def msg_remove(self, msg_names: list[str]) -> None:
|
||||
"""
|
||||
Remove the messages with the given names. Names can either be filenames
|
||||
(with or without suffix) or full paths. Also deletes the files of all given
|
||||
messages with a valid file_path.
|
||||
"""
|
||||
# delete the message files first
|
||||
rm_messages = self.msg_find(msg_names, source='all')
|
||||
for m in rm_messages:
|
||||
if (m.file_path):
|
||||
m.file_path.unlink()
|
||||
# then remove them from the internal list
|
||||
super().msg_remove(msg_names)
|
||||
|
||||
def msg_latest(self,
|
||||
mfilter: Optional[MessageFilter] = None,
|
||||
source: valid_sources = 'mem') -> Optional[Message]:
|
||||
"""
|
||||
Return the last added message (according to the file ID) that matches the given filter.
|
||||
Only consider messages with a valid file_path (except if source is 'mem').
|
||||
Searches one of the following sources:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
* 'db' : messages in the DB directory
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
"""
|
||||
# only consider messages with a valid file_path so they can be sorted
|
||||
source_messages = self.msg_gather(source, require_file_path=True)
|
||||
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
|
||||
for m in source_messages:
|
||||
if mfilter is None or m.match(mfilter):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user