Compare commits

..

6 Commits

2 changed files with 14 additions and 6 deletions

View File

@ -320,8 +320,11 @@ class ChatDB(Chat):
def add_to_db(self, messages: list[Message], write: bool = True) -> None: def add_to_db(self, messages: list[Message], write: bool = True) -> None:
""" """
Adds the given messages and sets the file_path to the DB directory. Adds the given new messages and sets the file_path to the DB directory.
Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
write_dir(self.db_path, write_dir(self.db_path,
messages, messages,
@ -335,8 +338,11 @@ class ChatDB(Chat):
def add_to_cache(self, messages: list[Message], write: bool = True) -> None: def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
""" """
Adds the given messages and sets the file_path to the cache directory. Adds the given new messages and sets the file_path to the cache directory.
Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
write_dir(self.cache_path, write_dir(self.cache_path,
messages, messages,

View File

@ -369,16 +369,18 @@ class TestChatDB(CmmTestCase):
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
chat_db.add_to_cache([message1]) chat_db.add_to_cache([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIsNotNone(chat_db.messages[4].file_path)
self.assertEqual(chat_db.messages[4].file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
# add new messages to the DB dir # add new messages to the DB dir
message2 = Message(question=Question("Question 2"), message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2")) answer=Answer("Answer 2"))
chat_db.add_to_db([message2]) chat_db.add_to_db([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
self.assertIsNotNone(chat_db.messages[5].file_path)
self.assertEqual(chat_db.messages[5].file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]