diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..bc13b25 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -219,9 +219,12 @@ class Message(): file_path=data.get(cls.file_yaml_key, None)) @classmethod - def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + def tags_from_file(cls: Type[MessageInst], + file_path: pathlib.Path, + prefix: Optional[str] = None) -> set[Tag]: """ - Return only the tags from the given Message file. + Return only the tags from the given Message file, + optionally filtered based on prefix. """ if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") @@ -229,11 +232,34 @@ class Message(): raise MessageError(f"File type '{file_path.suffix}' is not supported") if file_path.suffix == '.txt': with open(file_path, "r") as fd: - tags = TagLine(fd.readline()).tags() + tags = TagLine(fd.readline()).tags(prefix) else: # '.yaml' with open(file_path, "r") as fd: data = yaml.load(fd, Loader=yaml.FullLoader) - tags = set(sorted(data[cls.tags_yaml_key])) + if prefix and len(prefix) > 0: + tags = set(sorted([t.strip() for t in data[cls.tags_yaml_key] if t.startswith(prefix)])) + else: + tags = set(sorted(data[cls.tags_yaml_key])) + return tags + + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None) -> set[Tag]: + + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 544270c..b03fa7f 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -118,9 +118,10 @@ class TagLine(str): """ return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) - def tags(self) -> set[Tag]: + def tags(self, prefix: Optional[str] = None) -> set[Tag]: """ - Returns all tags contained in this line as a set. + Returns all tags contained in this line as a set, optionally + filtered based on prefix. """ tagstr = self[len(self.prefix):].strip() separator = Tag.default_separator @@ -130,7 +131,10 @@ class TagLine(str): if s in tagstr: separator = s break - return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator) if t.startswith(prefix)])) + else: + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) def merge(self, taglines: set['TagLine']) -> 'TagLine': """ diff --git a/tests/test_message.py b/tests/test_message.py index 0e326b4..23e27f3 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -543,7 +543,7 @@ class TagsFromFileTestCase(CmmTestCase): self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: - fd.write(f"""{TagLine.prefix} tag1 tag2 + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 {Question.txt_header} This is a question. {Answer.txt_header} @@ -560,6 +560,7 @@ This is an answer. {Message.tags_yaml_key}: - tag1 - tag2 + - ptag3 """) def tearDown(self) -> None: @@ -570,11 +571,19 @@ This is an answer. def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) class MessageIDTestCase(CmmTestCase): diff --git a/tests/test_tags.py b/tests/test_tags.py index 9ac9746..44e3a83 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -49,6 +49,13 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertEqual(tags, {Tag('stag2'), Tag('stag3')}) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3')