From a5ca1c11071dbb2e9e913e024b93b3bb671eb456 Mon Sep 17 00:00:00 2001
From: juk0de
Date: Sat, 26 Aug 2023 12:50:47 +0200
Subject: [PATCH] Added prefix filtering to TagLine.tags() and Message.tags_from_file()

---
 Review notes (ignored by `git am`, not part of the commit message):
 - Line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of this patch; confirm the context lines against the tree (or apply
   with `--ignore-whitespace`) before relying on it.
 - Added lines were revised during review (see below), so the commit and blob
   hashes in the headers no longer describe the resulting content.
 - Prefix filtering now runs on the stripped tag in both TagLine.tags() and
   the YAML branch of Message.tags_from_file().
 - The YAML branch tolerates a missing/empty tag list instead of raising
   KeyError.
 - The tags_from_dir test case was renamed and now cleans up both temporary
   directories.

 chatmastermind/message.py | 35 +++++++++++++++++++++---
 chatmastermind/tags.py    | 12 ++++++--
 tests/test_message.py     | 64 ++++++++++++++++++++++++++++++++++++++++++--
 tests/test_tags.py        |  7 +++++
 4 files changed, 108 insertions(+), 10 deletions(-)

diff --git a/chatmastermind/message.py b/chatmastermind/message.py
index 157cd46..bc13b25 100644
--- a/chatmastermind/message.py
+++ b/chatmastermind/message.py
@@ -219,9 +219,12 @@ class Message():
                    file_path=data.get(cls.file_yaml_key, None))
 
     @classmethod
-    def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
+    def tags_from_file(cls: Type[MessageInst],
+                       file_path: pathlib.Path,
+                       prefix: Optional[str] = None) -> set[Tag]:
         """
-        Return only the tags from the given Message file.
+        Return only the tags from the given Message file,
+        optionally filtered based on prefix.
         """
         if not file_path.exists():
             raise MessageError(f"Message file '{file_path}' does not exist")
@@ -229,11 +232,35 @@ class Message():
             raise MessageError(f"File type '{file_path.suffix}' is not supported")
         if file_path.suffix == '.txt':
             with open(file_path, "r") as fd:
-                tags = TagLine(fd.readline()).tags()
+                tags = TagLine(fd.readline()).tags(prefix)
         else:  # '.yaml'
             with open(file_path, "r") as fd:
                 data = yaml.load(fd, Loader=yaml.FullLoader)
-                tags = set(sorted(data[cls.tags_yaml_key]))
+                # Strip before filtering so the prefix is matched against the
+                # tag itself; a file without tags yields an empty set.
+                yaml_tags = [t.strip() for t in (data.get(cls.tags_yaml_key) or [])]
+                if prefix:
+                    yaml_tags = [t for t in yaml_tags if t.startswith(prefix)]
+                tags = set(yaml_tags)
+        return tags
+
+    @classmethod
+    def tags_from_dir(cls: Type[MessageInst],
+                      path: pathlib.Path,
+                      glob: Optional[str] = None,
+                      prefix: Optional[str] = None) -> set[Tag]:
+        """
+        Return only the tags from message files in the given directory.
+        The files can be filtered using 'glob', the tags by using 'prefix'.
+        """
+        tags: set[Tag] = set()
+        file_iter = path.glob(glob) if glob else path.iterdir()
+        for file_path in sorted(file_iter):
+            if file_path.is_file():
+                try:
+                    tags |= cls.tags_from_file(file_path, prefix)
+                except MessageError as e:
+                    print(f"Error processing message in '{file_path}': {e}")
         return tags
 
     @classmethod
diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py
index 544270c..b03fa7f 100644
--- a/chatmastermind/tags.py
+++ b/chatmastermind/tags.py
@@ -118,9 +118,10 @@ class TagLine(str):
         """
         return cls(' '.join([cls.prefix] + sorted([t for t in tags])))
 
-    def tags(self) -> set[Tag]:
+    def tags(self, prefix: Optional[str] = None) -> set[Tag]:
         """
-        Returns all tags contained in this line as a set.
+        Returns all tags contained in this line as a set, optionally
+        filtered based on prefix.
         """
         tagstr = self[len(self.prefix):].strip()
         separator = Tag.default_separator
@@ -130,7 +131,12 @@ class TagLine(str):
             if s in tagstr:
                 separator = s
                 break
-        return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)]))
+        names = [t.strip() for t in tagstr.split(separator)]
+        if prefix:
+            # Match the prefix against the stripped tag so that whitespace
+            # left over from splitting cannot hide a match.
+            names = [t for t in names if t.startswith(prefix)]
+        return {Tag(t) for t in names}
 
     def merge(self, taglines: set['TagLine']) -> 'TagLine':
         """
diff --git a/tests/test_message.py b/tests/test_message.py
index 0e326b4..f13a33d 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -543,7 +543,7 @@ class TagsFromFileTestCase(CmmTestCase):
         self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
         self.file_path_txt = pathlib.Path(self.file_txt.name)
         with open(self.file_path_txt, "w") as fd:
-            fd.write(f"""{TagLine.prefix} tag1 tag2
+            fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
 {Question.txt_header}
 This is a question.
 {Answer.txt_header}
@@ -560,6 +560,7 @@ This is an answer.
 {Message.tags_yaml_key}:
   - tag1
   - tag2
+  - ptag3
 """)
 
     def tearDown(self) -> None:
@@ -570,11 +571,68 @@ This is an answer.
 
     def test_tags_from_file_txt(self) -> None:
         tags = Message.tags_from_file(self.file_path_txt)
-        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')})
+        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
 
     def test_tags_from_file_yaml(self) -> None:
         tags = Message.tags_from_file(self.file_path_yaml)
-        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')})
+        self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
+
+    def test_tags_from_file_txt_prefix(self) -> None:
+        tags = Message.tags_from_file(self.file_path_txt, prefix='p')
+        self.assertSetEqual(tags, {Tag('ptag3')})
+
+    def test_tags_from_file_yaml_prefix(self) -> None:
+        tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
+        self.assertSetEqual(tags, {Tag('ptag3')})
+
+
+class TagsFromDirTestCase(CmmTestCase):
+    def setUp(self) -> None:
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.temp_dir_no_tags = tempfile.TemporaryDirectory()
+        self.tag_sets = [
+            {Tag('atag1'), Tag('atag2')},
+            {Tag('btag3'), Tag('btag4')},
+            {Tag('ctag5'), Tag('ctag6')}
+        ]
+        self.files = [
+            pathlib.Path(self.temp_dir.name, 'file1.txt'),
+            pathlib.Path(self.temp_dir.name, 'file2.yaml'),
+            pathlib.Path(self.temp_dir.name, 'file3.txt')
+        ]
+        self.files_no_tags = [
+            pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
+            pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
+            pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
+        ]
+        for file, tags in zip(self.files, self.tag_sets):
+            message = Message(Question('This is a question.'),
+                              Answer('This is an answer.'),
+                              tags)
+            message.to_file(file)
+        for file in self.files_no_tags:
+            message = Message(Question('This is a question.'),
+                              Answer('This is an answer.'))
+            message.to_file(file)
+
+    def tearDown(self) -> None:
+        self.temp_dir.cleanup()
+        self.temp_dir_no_tags.cleanup()
+
+    def test_tags_from_dir(self) -> None:
+        all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
+        expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2]
+        self.assertEqual(all_tags, expected_tags)
+
+    def test_tags_from_dir_prefix(self) -> None:
+        atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a')
+        expected_tags = self.tag_sets[0]
+        self.assertEqual(atags, expected_tags)
+
+    # FIXME
+    # def test_tags_from_dir_no_tags(self) -> None:
+    #     all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name))
+    #     self.assertSetEqual(all_tags, set())
 
 
 class MessageIDTestCase(CmmTestCase):
diff --git a/tests/test_tags.py b/tests/test_tags.py
index 9ac9746..44e3a83 100644
--- a/tests/test_tags.py
+++ b/tests/test_tags.py
@@ -49,6 +49,13 @@ class TestTagLine(CmmTestCase):
         tags = tagline.tags()
         self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
 
+    def test_tags_prefix(self) -> None:
+        tagline = TagLine('TAGS: atag1 stag2 stag3')
+        tags = tagline.tags(prefix='a')
+        self.assertEqual(tags, {Tag('atag1')})
+        tags = tagline.tags(prefix='s')
+        self.assertEqual(tags, {Tag('stag2'), Tag('stag3')})
+
     def test_merge(self) -> None:
         tagline1 = TagLine('TAGS: tag1 tag2')
         tagline2 = TagLine('TAGS: tag2 tag3')