chat: msg_remove() now supports multiple locations
This commit is contained in:
parent
19c2b16301
commit
a571307d77
@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
|
||||
db_next_file = '.next'
|
||||
ignored_files = [db_next_file, default_config_file]
|
||||
msg_place = Literal['mem', 'disk', 'cache', 'db', 'all']
|
||||
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
|
||||
|
||||
|
||||
class ChatError(Exception):
|
||||
@ -373,11 +373,11 @@ class ChatDB(Chat):
|
||||
self.msg_write(messages)
|
||||
|
||||
def msg_gather(self,
|
||||
source: msg_place,
|
||||
loc: msg_location,
|
||||
require_file_path: bool = False,
|
||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||
"""
|
||||
Gather and return messages from the given source:
|
||||
Gather and return messages from the given locations:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
@ -386,19 +386,19 @@ class ChatDB(Chat):
|
||||
|
||||
If 'require_file_path' is True, return only messages with a valid file_path.
|
||||
"""
|
||||
source_messages: list[Message] = []
|
||||
if source in ['mem', 'all']:
|
||||
loc_messages: list[Message] = []
|
||||
if loc in ['mem', 'all']:
|
||||
if require_file_path:
|
||||
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
|
||||
loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
|
||||
else:
|
||||
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
|
||||
if source in ['cache', 'disk', 'all']:
|
||||
source_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
if source in ['db', 'disk', 'all']:
|
||||
source_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
|
||||
if loc in ['cache', 'disk', 'all']:
|
||||
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
if loc in ['db', 'disk', 'all']:
|
||||
loc_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
# remove_duplicates and sort the list
|
||||
unique_messages: list[Message] = []
|
||||
for m in source_messages:
|
||||
for m in loc_messages:
|
||||
if not message_in(m, unique_messages):
|
||||
unique_messages.append(m)
|
||||
unique_messages.sort(key=lambda m: m.msg_id())
|
||||
@ -406,32 +406,39 @@ class ChatDB(Chat):
|
||||
|
||||
def msg_find(self,
|
||||
msg_names: list[str],
|
||||
source: msg_place = 'mem',
|
||||
loc: msg_location = 'mem',
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Search and return the messages with the given names. Names can either be filenames
|
||||
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
|
||||
found are ignored (i.e. the caller should check the result if they require all
|
||||
messages).
|
||||
Searches one of the following places:
|
||||
Searches one of the following locations:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
* 'db' : messages in the DB directory
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
"""
|
||||
source_messages = self.msg_gather(source, require_file_path=True)
|
||||
return [m for m in source_messages
|
||||
loc_messages = self.msg_gather(loc, require_file_path=True)
|
||||
return [m for m in loc_messages
|
||||
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
|
||||
|
||||
def msg_remove(self, msg_names: list[str]) -> None:
|
||||
def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None:
|
||||
"""
|
||||
Remove the messages with the given names. Names can either be filenames
|
||||
(with or without suffix), full paths or Message.msg_id(). Also deletes the
|
||||
files of all given messages with a valid file_path.
|
||||
Delete files from one of the following locations:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
* 'db' : messages in the DB directory
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
"""
|
||||
if loc != 'mem':
|
||||
# delete the message files first
|
||||
rm_messages = self.msg_find(msg_names, source='all')
|
||||
rm_messages = self.msg_find(msg_names, loc=loc)
|
||||
for m in rm_messages:
|
||||
if (m.file_path):
|
||||
m.file_path.unlink()
|
||||
@ -440,11 +447,11 @@ class ChatDB(Chat):
|
||||
|
||||
def msg_latest(self,
|
||||
mfilter: Optional[MessageFilter] = None,
|
||||
source: msg_place = 'mem') -> Optional[Message]:
|
||||
loc: msg_location = 'mem') -> Optional[Message]:
|
||||
"""
|
||||
Return the last added message (according to the file ID) that matches the given filter.
|
||||
Only consider messages with a valid file_path (except if source is 'mem').
|
||||
Searches one of the following places:
|
||||
Only consider messages with a valid file_path (required so they can be sorted by ID).
|
||||
Searches one of the following locations:
|
||||
* 'mem' : messages currently in memory
|
||||
* 'disk' : messages on disk (cache + DB directory), but not in memory
|
||||
* 'cache': messages in the cache directory
|
||||
@ -452,9 +459,9 @@ class ChatDB(Chat):
|
||||
* 'all' : all messages ('mem' + 'disk')
|
||||
"""
|
||||
# only consider messages with a valid file_path so they can be sorted
|
||||
source_messages = self.msg_gather(source, require_file_path=True)
|
||||
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
|
||||
for m in source_messages:
|
||||
loc_messages = self.msg_gather(loc, require_file_path=True)
|
||||
loc_messages.sort(key=lambda m: m.msg_id(), reverse=True)
|
||||
for m in loc_messages:
|
||||
if mfilter is None or m.match(mfilter):
|
||||
return m
|
||||
return None
|
||||
|
||||
@ -526,43 +526,43 @@ class TestChatDB(unittest.TestCase):
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# search for a DB file in memory
|
||||
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
|
||||
# and on disk
|
||||
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
|
||||
# now search the cache -> expect empty result
|
||||
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
|
||||
# search for multiple messages
|
||||
# -> search one twice, expect result to be unique
|
||||
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
|
||||
expected_result = [self.message1, self.message2, self.message3]
|
||||
result = chat_db.msg_find(search_names, source='all')
|
||||
result = chat_db.msg_find(search_names, loc='all')
|
||||
self.assertSequenceEqual(result, expected_result)
|
||||
|
||||
def test_msg_latest(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(source='db'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(source='disk'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(source='all'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(loc='mem'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(loc='disk'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(loc='all'), self.message4)
|
||||
# the cache is currently empty:
|
||||
self.assertIsNone(chat_db.msg_latest(source='cache'))
|
||||
self.assertIsNone(chat_db.msg_latest(loc='cache'))
|
||||
# add new messages to the cache dir
|
||||
new_message = Message(question=Question("New Question"),
|
||||
answer=Answer("New Answer"))
|
||||
chat_db.cache_add([new_message])
|
||||
self.assertEqual(chat_db.msg_latest(source='cache'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(source='mem'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(source='disk'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(source='all'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(loc='cache'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(loc='mem'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(loc='disk'), new_message)
|
||||
self.assertEqual(chat_db.msg_latest(loc='all'), new_message)
|
||||
# the DB does not contain the new message
|
||||
self.assertEqual(chat_db.msg_latest(source='db'), self.message4)
|
||||
self.assertEqual(chat_db.msg_latest(loc='db'), self.message4)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user