Compare commits
2 Commits
b454010d6c
...
eaf04e36df
| Author | SHA1 | Date | |
|---|---|---|---|
| eaf04e36df | |||
| c76bd273e6 |
55
chatmastermind/message.py
Normal file
55
chatmastermind/message.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
Module implementing message related functions and classes.
|
||||||
|
"""
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
QuestionInst = TypeVar('QuestionInst', bound='Question')
|
||||||
|
|
||||||
|
|
||||||
|
class MessageError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def source_code(text: str, include_delims: bool = False) -> list[str]:
|
||||||
|
"""
|
||||||
|
Extracts all source code sections from the given text, i. e. all lines
|
||||||
|
surrounded by lines tarting with '```'. If 'include_delims' is True,
|
||||||
|
the surrounding lines are included, otherwise they are omitted. The
|
||||||
|
result list contains every source code section as a single string.
|
||||||
|
The order in the list represents the order of the sections in the text.
|
||||||
|
"""
|
||||||
|
code_sections: list[str] = []
|
||||||
|
code_lines: list[str] = []
|
||||||
|
in_code_block = False
|
||||||
|
|
||||||
|
print(text)
|
||||||
|
|
||||||
|
for line in text.split('\n'):
|
||||||
|
if line.strip().startswith("```"):
|
||||||
|
if include_delims:
|
||||||
|
code_lines.append(line)
|
||||||
|
if in_code_block:
|
||||||
|
code_sections.append("\n".join(code_lines))
|
||||||
|
code_lines.clear()
|
||||||
|
print(f" --> New code section:\n{code_sections}")
|
||||||
|
in_code_block = not in_code_block
|
||||||
|
elif in_code_block:
|
||||||
|
code_lines.append(line)
|
||||||
|
|
||||||
|
return code_sections
|
||||||
|
|
||||||
|
|
||||||
|
class Question(str):
|
||||||
|
"""
|
||||||
|
A single question with a defined prefix.
|
||||||
|
"""
|
||||||
|
prefix = '=== QUESTION ==='
|
||||||
|
|
||||||
|
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
|
||||||
|
"""
|
||||||
|
Make sure the question string does not contain the prefix.
|
||||||
|
"""
|
||||||
|
if cls.prefix in string:
|
||||||
|
raise MessageError(f"Question '{string}' contains the prefix '{cls.prefix}'")
|
||||||
|
instance = super().__new__(cls, string)
|
||||||
|
return instance
|
||||||
@ -8,6 +8,7 @@ from chatmastermind.api_client import ai
|
|||||||
from chatmastermind.configuration import Config
|
from chatmastermind.configuration import Config
|
||||||
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
||||||
from chatmastermind.tags import Tag, TagLine, TagError
|
from chatmastermind.tags import Tag, TagLine, TagError
|
||||||
|
from chatmastermind.message import source_code
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch, MagicMock, Mock, ANY
|
from unittest.mock import patch, MagicMock, Mock, ANY
|
||||||
|
|
||||||
@ -345,3 +346,57 @@ class TestTagLine(CmmTestCase):
|
|||||||
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
|
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
|
||||||
tags_not = {Tag('tag2')}
|
tags_not = {Tag('tag2')}
|
||||||
self.assertFalse(tagline.match_tags(None, None, tags_not))
|
self.assertFalse(tagline.match_tags(None, None, tags_not))
|
||||||
|
|
||||||
|
|
||||||
|
class SourceCodeTestCase(CmmTestCase):
|
||||||
|
def test_source_code_with_include_delims(self) -> None:
|
||||||
|
text = """
|
||||||
|
Some text before the code block
|
||||||
|
```python
|
||||||
|
print("Hello, World!")
|
||||||
|
```
|
||||||
|
Some text after the code block
|
||||||
|
```python
|
||||||
|
x = 10
|
||||||
|
y = 20
|
||||||
|
print(x + y)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected_result = [
|
||||||
|
" ```python\n print(\"Hello, World!\")\n ```",
|
||||||
|
" ```python\n x = 10\n y = 20\n print(x + y)\n```"
|
||||||
|
]
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_without_include_delims(self) -> None:
|
||||||
|
text = """
|
||||||
|
Some text before the code block
|
||||||
|
```python
|
||||||
|
print("Hello, World!")
|
||||||
|
```
|
||||||
|
Some text after the code block
|
||||||
|
```python
|
||||||
|
x = 10
|
||||||
|
y = 20
|
||||||
|
print(x + y)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected_result = [
|
||||||
|
" print(\"Hello, World!\")\n",
|
||||||
|
" x = 10\n y = 20\n print(x + y\n)"
|
||||||
|
]
|
||||||
|
result = source_code(text, include_delims=False)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_with_single_code_block(self) -> None:
|
||||||
|
text = "```python\nprint(\"Hello, World!\")\n```"
|
||||||
|
expected_result = ["```python\nprint(\"Hello, World!\")\n```"]
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_source_code_with_no_code_blocks(self) -> None:
|
||||||
|
text = "Some text without any code blocks"
|
||||||
|
expected_result: list[str] = []
|
||||||
|
result = source_code(text, include_delims=True)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user