diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..db32b22 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,55 @@ +""" +Module implementing message related functions and classes. +""" +from typing import Type, TypeVar + +QuestionInst = TypeVar('QuestionInst', bound='Question') + + +class MessageError(Exception): + pass + + +def source_code(text: str, include_delims: bool = False) -> list[str]: + """ + Extracts all source code sections from the given text, i. e. all lines + surrounded by lines tarting with '```'. If 'include_delims' is True, + the surrounding lines are included, otherwise they are omitted. The + result list contains every source code section as a single string. + The order in the list represents the order of the sections in the text. + """ + code_sections: list[str] = [] + code_lines: list[str] = [] + in_code_block = False + + print(text) + + for line in text.split('\n'): + if line.strip().startswith("```"): + if include_delims: + code_lines.append(line) + if in_code_block: + code_sections.append("\n".join(code_lines)) + code_lines.clear() + print(f" --> New code section:\n{code_sections}") + in_code_block = not in_code_block + elif in_code_block: + code_lines.append(line) + + return code_sections + + +class Question(str): + """ + A single question with a defined prefix. + """ + prefix = '=== QUESTION ===' + + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + """ + Make sure the question string does not contain the prefix. + """ + if cls.prefix in string: + raise MessageError(f"Question '{string}' contains the prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..f99b015 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,57 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```", + " ```python\n x = 10\n y = 20\n print(x + y)\n```" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y\n)" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result)