Compare commits

...

2 Commits

Author SHA1 Message Date
eb3dd98adb added testcases for messages.py 2023-08-18 16:35:34 +02:00
5862071b3b added new module 'message.py' 2023-08-18 16:35:34 +02:00
2 changed files with 107 additions and 0 deletions

52
chatmastermind/message.py Normal file
View File

@ -0,0 +1,52 @@
"""
Module implementing message related functions and classes.
"""
from typing import Type, TypeVar
QuestionInst = TypeVar('QuestionInst', bound='Question')
class MessageError(Exception):
pass
def source_code(text: str, include_delims: bool = False) -> list[str]:
"""
Extracts all source code sections from the given text, i. e. all lines
surrounded by lines tarting with '```'. If 'include_delims' is True,
the surrounding lines are included, otherwise they are omitted. The
result list contains every source code section as a single string.
The order in the list represents the order of the sections in the text.
"""
code_sections: list[str] = []
code_lines: list[str] = []
in_code_block = False
for line in text.split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
if in_code_block:
code_sections.append('\n'.join(code_lines) + '\n')
code_lines.clear()
in_code_block = not in_code_block
elif in_code_block:
code_lines.append(line)
return code_sections
class Question(str):
"""
A single question with a defined prefix.
"""
prefix = '=== QUESTION ==='
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
"""
Make sure the question string does not contain the prefix.
"""
if cls.prefix in string:
raise MessageError(f"Question '{string}' contains the prefix '{cls.prefix}'")
instance = super().__new__(cls, string)
return instance

View File

@ -8,6 +8,7 @@ from chatmastermind.api_client import ai
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from chatmastermind.tags import Tag, TagLine, TagError from chatmastermind.tags import Tag, TagLine, TagError
from chatmastermind.message import source_code
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock, ANY
@ -345,3 +346,57 @@ class TestTagLine(CmmTestCase):
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')} tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not)) self.assertFalse(tagline.match_tags(None, None, tags_not))
class SourceCodeTestCase(CmmTestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" ```python\n print(\"Hello, World!\")\n ```\n",
" ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_without_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" print(\"Hello, World!\")\n",
" x = 10\n y = 20\n print(x + y)\n"
]
result = source_code(text, include_delims=False)
self.assertEqual(result, expected_result)
def test_source_code_with_single_code_block(self) -> None:
text = "```python\nprint(\"Hello, World!\")\n```"
expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_with_no_code_blocks(self) -> None:
text = "Some text without any code blocks"
expected_result: list[str] = []
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)