From bbc51c2f51b4ef85130d03197c0ce64e41db794b Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 15 Sep 2023 10:17:20 +0200 Subject: [PATCH] chat: added new functions: msg_unique_id(), msg_unique_content() and tests --- chatmastermind/chat.py | 29 ++++++++++++++++- tests/test_chat.py | 73 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 87 insertions(+), 15 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 0aee2fe..083b91e 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -146,6 +146,25 @@ class Chat: except MessageError: pass + def msg_unique_id(self) -> None: + """ + Remove duplicates from the internal messages, based on the msg_id (i. e. file_path). + Messages without a file_path are kept. + """ + old_msgs = self.messages.copy() + self.messages = [] + for m in old_msgs: + if not message_in(m, self.messages): + self.messages.append(m) + self.msg_sort() + + def msg_unique_content(self) -> None: + """ + Remove duplicates from the internal messages, based on the content (i. e. question + answer). + """ + self.messages = list(set(self.messages)) + self.msg_sort() + def msg_clear(self) -> None: """ Delete all messages. @@ -356,7 +375,13 @@ class ChatDB(Chat): source_messages += read_dir(self.cache_path, mfilter=mfilter) if source in ['db', 'disk', 'all']: source_messages += read_dir(self.db_path, mfilter=mfilter) - return source_messages + # remove_duplicates and sort the list + unique_messages: list[Message] = [] + for m in source_messages: + if not message_in(m, unique_messages): + unique_messages.append(m) + unique_messages.sort(key=lambda m: m.msg_id()) + return unique_messages def msg_find(self, msg_names: list[str], @@ -430,6 +455,7 @@ class ChatDB(Chat): Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. + Does NOT add the messages to the internal list (use 'cache_add()' for that)! """ write_dir(self.cache_path, messages if messages else self.messages, @@ -480,6 +506,7 @@ class ChatDB(Chat): Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. + Does NOT add the messages to the internal list (use 'db_add()' for that)! """ write_dir(self.db_path, messages if messages else self.messages, diff --git a/tests/test_chat.py b/tests/test_chat.py index dbd9915..ab37a6b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -20,6 +20,29 @@ class TestChat(unittest.TestCase): Answer('Answer 2'), {Tag('btag2')}, file_path=pathlib.Path('0002.txt')) + self.maxDiff = None + + def test_unique_id(self) -> None: + # test with two identical messages + self.chat.msg_add([self.message1, self.message1]) + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.chat.msg_unique_id() + self.assertSequenceEqual(self.chat.messages, [self.message1]) + # test with two different messages + self.chat.msg_add([self.message2]) + self.chat.msg_unique_id() + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) + + def test_unique_content(self) -> None: + # test with two identical messages + self.chat.msg_add([self.message1, self.message1]) + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) + self.chat.msg_unique_content() + self.assertSequenceEqual(self.chat.messages, [self.message1]) + # test with two different messages + self.chat.msg_add([self.message2]) + self.chat.msg_unique_content() + self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) def test_filter(self) -> None: self.chat.msg_add([self.message1, self.message2]) @@ -161,6 +184,7 @@ class TestChatDB(unittest.TestCase): for file in self.trash_files: with open(pathlib.Path(self.db_path.name) / file, 'w') as f: f.write('test trash') + self.maxDiff = None def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: """ @@ -174,7 +198,7 @@ class TestChatDB(unittest.TestCase): self.cache_path.cleanup() pass - def test_chat_db_from_dir(self) -> None: + def test_from_dir(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(len(chat_db.messages), 4) @@ -190,7 +214,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) - def test_chat_db_from_dir_glob(self) -> None: + def test_from_dir_glob(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), glob='*.txt') @@ -202,7 +226,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_from_dir_filter_tags(self) -> None: + def test_from_dir_filter_tags(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(tags_or={Tag('tag1')})) @@ -212,7 +236,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) - def test_chat_db_from_dir_filter_tags_empty(self) -> None: + def test_from_dir_filter_tags_empty(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(tags_or=set(), @@ -220,7 +244,7 @@ class TestChatDB(unittest.TestCase): tags_not=set())) self.assertEqual(len(chat_db.messages), 0) - def test_chat_db_from_dir_filter_answer(self) -> None: + def test_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -231,7 +255,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messages(self) -> None: + def test_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, @@ -240,7 +264,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) - def test_chat_db_fids(self) -> None: + def test_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.get_next_fid(), 5) @@ -249,7 +273,7 @@ class TestChatDB(unittest.TestCase): with open(chat_db.next_path, 'r') as f: self.assertEqual(f.read(), '7') - def test_chat_db_write(self) -> None: + def test_db_write(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -297,7 +321,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) - def test_chat_db_read(self) -> None: + def test_db_read(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -360,7 +384,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) - def test_chat_db_clear(self) -> None: + def test_cache_clear(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -405,7 +429,7 @@ class TestChatDB(unittest.TestCase): # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) - def test_chat_db_add(self) -> None: + def test_add(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -436,7 +460,7 @@ class TestChatDB(unittest.TestCase): with self.assertRaises(ChatError): chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))]) - def test_chat_db_write_messages(self) -> None: + def test_msg_write(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -459,7 +483,7 @@ class TestChatDB(unittest.TestCase): self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) - def test_chat_db_update_messages(self) -> None: + def test_msg_update(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) @@ -487,7 +511,28 @@ class TestChatDB(unittest.TestCase): with self.assertRaises(ChatError): chat_db.msg_update([message1]) - def test_chat_db_latest_message(self) -> None: + def test_msg_find(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # search for a DB file in memory + self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1]) + self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1]) + # and on disk + self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2]) + self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2]) + # now search the cache -> expect empty result + self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), []) + self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), []) + self.assertEqual(chat_db.msg_find(['0003'], source='cache'), []) + # search for multiple messages + search_names = ['0001', '0002.yaml', str(self.message3.file_path)] + expected_result = [self.message1, self.message2, self.message3] + result = chat_db.msg_find(search_names, source='all') + self.assertSequenceEqual(result, expected_result) + + def test_msg_latest(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)