chat: msg_remove() now supports multiple locations

This commit is contained in:
juk0de 2023-09-15 16:00:05 +02:00
parent 19c2b16301
commit a571307d77
2 changed files with 58 additions and 51 deletions

View File

@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
msg_place = Literal['mem', 'disk', 'cache', 'db', 'all'] msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception): class ChatError(Exception):
@ -373,11 +373,11 @@ class ChatDB(Chat):
self.msg_write(messages) self.msg_write(messages)
def msg_gather(self, def msg_gather(self,
source: msg_place, loc: msg_location,
require_file_path: bool = False, require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
Gather and return messages from the given source: Gather and return messages from the given locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
@ -386,19 +386,19 @@ class ChatDB(Chat):
If 'require_file_path' is True, return only files with a valid file_path. If 'require_file_path' is True, return only files with a valid file_path.
""" """
source_messages: list[Message] = [] loc_messages: list[Message] = []
if source in ['mem', 'all']: if loc in ['mem', 'all']:
if require_file_path: if require_file_path:
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else: else:
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if source in ['cache', 'disk', 'all']: if loc in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter) loc_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']: if loc in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter) loc_messages += read_dir(self.db_path, mfilter=mfilter)
# remove_duplicates and sort the list # remove_duplicates and sort the list
unique_messages: list[Message] = [] unique_messages: list[Message] = []
for m in source_messages: for m in loc_messages:
if not message_in(m, unique_messages): if not message_in(m, unique_messages):
unique_messages.append(m) unique_messages.append(m)
unique_messages.sort(key=lambda m: m.msg_id()) unique_messages.sort(key=lambda m: m.msg_id())
@ -406,32 +406,39 @@ class ChatDB(Chat):
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
source: msg_place = 'mem', loc: msg_location = 'mem',
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be (with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all found are ignored (i. e. the caller should check the result if they require all
messages). messages).
Searches one of the following places: Searches one of the following locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
* 'db' : messages in the DB directory * 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
source_messages = self.msg_gather(source, require_file_path=True) loc_messages = self.msg_gather(loc, require_file_path=True)
return [m for m in source_messages return [m for m in loc_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the (with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path. files of all given messages with a valid file_path.
Delete files from one of the following locations:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
""" """
if loc != 'mem':
# delete the message files first # delete the message files first
rm_messages = self.msg_find(msg_names, source='all') rm_messages = self.msg_find(msg_names, loc=loc)
for m in rm_messages: for m in rm_messages:
if (m.file_path): if (m.file_path):
m.file_path.unlink() m.file_path.unlink()
@ -440,11 +447,11 @@ class ChatDB(Chat):
def msg_latest(self, def msg_latest(self,
mfilter: Optional[MessageFilter] = None, mfilter: Optional[MessageFilter] = None,
source: msg_place = 'mem') -> Optional[Message]: loc: msg_location = 'mem') -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem'). Only consider messages with a valid file_path (except if loc is 'mem').
Searches one of the following places: Searches one of the following locations:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
@ -452,9 +459,9 @@ class ChatDB(Chat):
* 'all' : all messages ('mem' + 'disk') * 'all' : all messages ('mem' + 'disk')
""" """
# only consider messages with a valid file_path so they can be sorted # only consider messages with a valid file_path so they can be sorted
source_messages = self.msg_gather(source, require_file_path=True) loc_messages = self.msg_gather(loc, require_file_path=True)
source_messages.sort(key=lambda m: m.msg_id(), reverse=True) loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages: for m in loc_messages:
if mfilter is None or m.match(mfilter): if mfilter is None or m.match(mfilter):
return m return m
return None return None

View File

@ -526,43 +526,43 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# search for a DB file in memory # search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk # and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result # now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), []) self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages # search for multiple messages
# -> search one twice, expect result to be unique # -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, source='all') result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result) self.assertSequenceEqual(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4) self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
self.assertEqual(chat_db.msg_latest(source='disk'), self.message4) self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(source='all'), self.message4) self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(source='cache')) self.assertIsNone(chat_db.msg_latest(loc='cache'))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.cache_add([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(source='cache'), new_message) self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
self.assertEqual(chat_db.msg_latest(source='mem'), new_message) self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
self.assertEqual(chat_db.msg_latest(source='disk'), new_message) self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
self.assertEqual(chat_db.msg_latest(source='all'), new_message) self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)