diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 902aaa2..820d104 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -436,13 +436,14 @@ class Message(): Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string. """ - res_tags = self.tags - if res_tags: - if prefix and len(prefix) > 0: - res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} - if contain and len(contain) > 0: - res_tags -= {tag for tag in res_tags if contain not in tag} - return res_tags or set() + if not self.tags: + return set() + res_tags = self.tags.copy() + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ diff --git a/tests/test_message.py b/tests/test_message.py index 83a73ea..2a9d0ff 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -746,3 +746,18 @@ class MessageTagsStrTestCase(CmmTestCase): def test_tags_str(self) -> None: self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')})