Compare commits

..

2 Commits

2 changed files with 9 additions and 8 deletions

View File

@ -16,7 +16,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next' db_next_file = '.next'
ignored_files = [db_next_file, default_config_file] ignored_files = [db_next_file, default_config_file]
valid_sources = Literal['mem', 'disk', 'cache', 'db', 'all'] msg_place = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception): class ChatError(Exception):
@ -133,7 +133,7 @@ class Chat:
error = False error = False
for fp in file_paths: for fp in file_paths:
if file_stems.count(fp.stem) > 1: if file_stems.count(fp.stem) > 1:
print(f"ERROR: File '{fp.stem}' appears twice in the message list: {msg_paths(fp.stem)}") print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True error = True
if error: if error:
raise ChatError("Validation failed") raise ChatError("Validation failed")
@ -373,7 +373,7 @@ class ChatDB(Chat):
self.msg_write(messages) self.msg_write(messages)
def msg_gather(self, def msg_gather(self,
source: valid_sources, source: msg_place,
require_file_path: bool = False, require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]: mfilter: Optional[MessageFilter] = None) -> list[Message]:
""" """
@ -406,14 +406,14 @@ class ChatDB(Chat):
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
source: valid_sources = 'mem', source: msg_place = 'mem',
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be (with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all found are ignored (i. e. the caller should check the result if they require all
messages). messages).
Searches one of the following sources: Searches one of the following places:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory
@ -440,11 +440,11 @@ class ChatDB(Chat):
def msg_latest(self, def msg_latest(self,
mfilter: Optional[MessageFilter] = None, mfilter: Optional[MessageFilter] = None,
source: valid_sources = 'mem') -> Optional[Message]: source: msg_place = 'mem') -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem'). Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following sources: Searches one of the following places:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory * 'cache': messages in the cache directory

View File

@ -204,9 +204,10 @@ class TestChatDB(unittest.TestCase):
{Tag('tag4')}, {Tag('tag4')},
file_path=pathlib.Path('0004.txt')) file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt')) duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError): with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name), ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_from_dir(self) -> None: def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),