From 3fc23feae3fd52f5153bddb29ba355bdfdb846f1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:24:20 +0200 Subject: [PATCH] message: fixed matching with empty tag sets --- chatmastermind/message.py | 4 ++-- tests/test_chat.py | 22 ++++++++++++++++++++-- tests/test_message.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 7107c13..df59ed6 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -312,7 +312,7 @@ class Message(): mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) - if message and (not mfilter or (mfilter and message.match(mfilter))): + if message and (mfilter is None or message.match(mfilter)): return message else: return None @@ -508,7 +508,7 @@ class Message(): Return True if all attributes match, else False. """ mytags = self.tags or set() - if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 diff --git a/tests/test_chat.py b/tests/test_chat.py index ed630a4..1916a2b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_filter(self) -> None: + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messges(self) -> None: + def test_chat_db_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, diff --git a/tests/test_message.py b/tests/test_message.py index 57d5982..1f440df 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -300,6 +300,12 @@ This is a question. MessageFilter(tags_or={Tag('tag1')})) self.assertIsNone(message) + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + def test_from_file_txt_no_tags_match_tags_not(self) -> None: message = Message.from_file(self.file_path_min, MessageFilter(tags_not={Tag('tag1')}))