diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..e70135c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -236,6 +236,25 @@ class Message(): tags = set(sorted(data[cls.tags_yaml_key])) return tags + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None) -> set[Tag]: + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return tags + @classmethod def from_file(cls: Type[MessageInst], file_path: pathlib.Path, mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..8ce06cb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,79 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No")