Compare commits

...

2 Commits

4 changed files with 48 additions and 11 deletions

View File

@ -119,6 +119,25 @@ class Chat:
messages: list[Message] messages: list[Message]
# Post-init hook: validate the new instance right away, so that whatever
# validate() rejects surfaces at construction time.
# NOTE(review): presumably invoked by a dataclass-generated __init__; the class
# decorator is outside this excerpt -- confirm.
def __post_init__(self) -> None:
self.validate()
def validate(self) -> None:
    """
    Validate this Chat instance.

    The file stem (file name without suffix) of every message that has a
    file_path must occur only once in the message list. Messages without
    a file_path are ignored. Each duplicated stem is reported exactly once
    on stdout, together with all distinct paths that share it.

    Raises ChatError if at least one stem occurs more than once.
    """
    # Group in a single pass: how often each stem occurs and which distinct
    # paths carry it. Counting occurrences (not distinct paths) keeps the
    # case where the very same path is listed twice an error as well.
    stem_counts: dict[str, int] = {}
    stem_paths: dict[str, set[str]] = {}
    for message in self.messages:
        if message.file_path is None:
            continue
        stem = message.file_path.stem
        stem_counts[stem] = stem_counts.get(stem, 0) + 1
        stem_paths.setdefault(stem, set()).add(str(message.file_path))

    # Sorted so that the report is deterministic and independent of set order.
    duplicates = sorted(stem for stem, count in stem_counts.items() if count > 1)
    for stem in duplicates:
        print(f"ERROR: File '{stem}' appears twice in the message list: {sorted(stem_paths[stem])}")
    if duplicates:
        raise ChatError("Validation failed")
def msg_name_matches(self, file_path: Path, name: str) -> bool: def msg_name_matches(self, file_path: Path, name: str) -> bool:
""" """
Return True if the given name matches the given file_path. Return True if the given name matches the given file_path.
@ -194,8 +213,9 @@ class Chat:
def msg_find(self, msg_names: list[str]) -> list[Message]: def msg_find(self, msg_names: list[str]) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(with or without suffix) or full paths. Messages that can't be found are ignored (with or without suffix), full paths or Message.msg_id(). Messages that can't be
(i. e. the caller should check the result if they require all messages). found are ignored (i. e. the caller should check the result if they require all
messages).
""" """
return [m for m in self.messages return [m for m in self.messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
@ -203,7 +223,7 @@ class Chat:
def msg_remove(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix) or full paths. (with or without suffix), full paths or Message.msg_id().
""" """
self.messages = [m for m in self.messages self.messages = [m for m in self.messages
if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
@ -275,6 +295,7 @@ class ChatDB(Chat):
# make all paths absolute # make all paths absolute
self.cache_path = self.cache_path.absolute() self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute() self.db_path = self.db_path.absolute()
self.validate()
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
@ -389,8 +410,9 @@ class ChatDB(Chat):
) -> list[Message]: ) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(with or without suffix) or full paths. Messages that can't be found are ignored (with or without suffix), full paths or Message.msg_id(). Messages that can't be
(i. e. the caller should check the result if they require all messages). found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following sources: Searches one of the following sources:
* 'mem' : messages currently in memory * 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory * 'disk' : messages on disk (cache + DB directory), but not in memory
@ -405,8 +427,8 @@ class ChatDB(Chat):
def msg_remove(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(with or without suffix) or full paths. Also deletes the files of all given (with or without suffix), full paths or Message.msg_id(). Also deletes the
messages with a valid file_path. files of all given messages with a valid file_path.
""" """
# delete the message files first # delete the message files first
rm_messages = self.msg_find(msg_names, source='all') rm_messages = self.msg_find(msg_names, source='all')

View File

@ -546,10 +546,11 @@ class Message():
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name. The ID is also used for sorting messages. Currently this is the file name without suffix. The ID is also used for sorting
messages.
""" """
if self.file_path: if self.file_path:
return self.file_path.name return self.file_path.stem
else: else:
raise MessageError("Can't create file ID without a file path") raise MessageError("Can't create file ID without a file path")

View File

@ -198,6 +198,16 @@ class TestChatDB(unittest.TestCase):
self.cache_path.cleanup() self.cache_path.cleanup()
pass pass
def test_validate(self) -> None:
    """
    Loading the DB must fail with ChatError once an extra '0004.txt'
    message file has been written into the DB directory.
    """
    # NOTE(review): presumably the fixture already contains a message whose
    # file stem is '0004', which makes this file a duplicate -- confirm
    # against setUp().
    db_dir = pathlib.Path(self.db_path.name)
    cache_dir = pathlib.Path(self.cache_path.name)
    clashing_message = Message(
        Question('Question 4'),
        Answer('Answer 4'),
        {Tag('tag4')},
        file_path=pathlib.Path('0004.txt'),
    )
    clashing_message.to_file(db_dir / '0004.txt')
    with self.assertRaises(ChatError):
        ChatDB.from_dir(cache_dir, db_dir)
def test_from_dir(self) -> None: def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -516,18 +526,22 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
# search for a DB file in memory # search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1])
# and on disk # and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2])
# now search the cache -> expect empty result # now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) self.assertEqual(chat_db.msg_find(['0003'], source='cache'), [])
# search for multiple messages # search for multiple messages
search_names = ['0001', '0002.yaml', str(self.message3.file_path)] # -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, source='all') result = chat_db.msg_find(search_names, source='all')
self.assertSequenceEqual(result, expected_result) self.assertSequenceEqual(result, expected_result)

View File

@ -730,7 +730,7 @@ class MessageIDTestCase(unittest.TestCase):
self.file_path.unlink() self.file_path.unlink()
def test_msg_id_txt(self) -> None: def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.name) self.assertEqual(self.message.msg_id(), self.file_path.stem)
def test_msg_id_txt_exception(self) -> None: def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):