From dd661f822b4301c976f99da71cfb662e4580e194 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 19 Sep 2023 09:48:12 +0200 Subject: [PATCH] chat: improved message equality checks --- tests/test_chat.py | 83 ++++++++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 36 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index 4558e4d..a69f92c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -10,7 +10,18 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, ChatError -class TestChat(unittest.TestCase): +class TestChatBase(unittest.TestCase): + def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: + """ + Compare messages using more than just Question and Answer. + """ + self.assertEqual(len(msg1), len(msg2)) + for m1, m2 in zip(msg1, msg2): + # exclude the file_path, compare only Q, A and metadata + self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) + + +class TestChat(TestChatBase): def setUp(self) -> None: self.chat = Chat([]) self.message1 = Message(Question('Question 1'), @@ -26,24 +37,24 @@ class TestChat(unittest.TestCase): def test_unique_id(self) -> None: # test with two identical messages self.chat.msg_add([self.message1, self.message1]) - self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.assert_messages_equal(self.chat.messages, [self.message1, self.message1]) self.chat.msg_unique_id() - self.assertSequenceEqual(self.chat.messages, [self.message1]) + self.assert_messages_equal(self.chat.messages, [self.message1]) # test with two different messages self.chat.msg_add([self.message2]) self.chat.msg_unique_id() - self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) + self.assert_messages_equal(self.chat.messages, [self.message1, self.message2]) def test_unique_content(self) -> None: # test with two identical messages self.chat.msg_add([self.message1, self.message1]) - self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.assert_messages_equal(self.chat.messages, [self.message1, self.message1]) self.chat.msg_unique_content() - self.assertSequenceEqual(self.chat.messages, [self.message1]) + self.assert_messages_equal(self.chat.messages, [self.message1]) # test with two different messages self.chat.msg_add([self.message2]) self.chat.msg_unique_content() - self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) + self.assert_messages_equal(self.chat.messages, [self.message1, self.message2]) def test_filter(self) -> None: self.chat.msg_add([self.message1, self.message2]) @@ -150,7 +161,7 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) -class TestChatDB(unittest.TestCase): +class TestChatDB(TestChatBase): def setUp(self) -> None: self.db_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory() @@ -569,7 +580,7 @@ class TestChatDB(unittest.TestCase): search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] expected_result = [self.message1, self.message2, self.message3] result = chat_db.msg_find(search_names, loc='all') - self.assertSequenceEqual(result, expected_result) + self.assert_messages_equal(result, expected_result) def test_msg_latest(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), @@ -595,47 +606,47 @@ class TestChatDB(unittest.TestCase): chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) all_messages = [self.message1, self.message2, self.message3, self.message4] - self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) # add a new message, but only to the internal list new_message = Message(Question("What?")) all_messages_mem = all_messages + [new_message] chat_db.msg_add([new_message]) - self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem) - self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem) + self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem) + self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem) # the nr. of messages on disk did not change -> expect old result - self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) # test with MessageFilter - self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), - [self.message1]) - self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), - [self.message2]) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), - []) - self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), - [new_message]) + self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), + [self.message1]) + self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), + [self.message2]) + self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), + []) + self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), + [new_message]) def test_msg_move_and_gather(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) all_messages = [self.message1, self.message2, self.message3, self.message4] - self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) # move first message to the cache chat_db.cache_move(self.message1) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1]) + self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1]) self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] - self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) - self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) - self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) + self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages) # now move first message back to the DB chat_db.db_move(self.message1) - self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) + self.assert_messages_equal(chat_db.msg_gather(loc='cache'), []) self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] - self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) + self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)