From cc76da2ab36ae3cef44bd203018656d3a39501d0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:39:00 +0200 Subject: [PATCH] chat: added 'update_messages()' function and test --- chatmastermind/chat.py | 16 ++++++++++++++++ tests/test_chat.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4e8fb20..ddabb56 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -386,3 +386,19 @@ class ChatDB(Chat): msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. + """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e4aa8c..ed630a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -440,3 +440,31 @@ class TestChatDB(unittest.TestCase): cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1])