Compare commits

...

4 Commits

4 changed files with 306 additions and 5 deletions

View File

@ -64,3 +64,6 @@ class Config():
def to_file(self, path: str) -> None:
with open(path, 'w') as f:
yaml.dump(asdict(self), f)
def as_dict(self) -> dict[str, Any]:
return asdict(self)

210
chatmastermind/message.py Normal file
View File

@ -0,0 +1,210 @@
"""
Module implementing message related functions and classes.
"""
import pathlib
import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any
from dataclasses import dataclass, asdict
from .tags import Tag, TagLine
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
class MessageError(Exception):
pass
def source_code(text: str, include_delims: bool = False) -> list[str]:
"""
Extract all source code sections from the given text, i. e. all lines
surrounded by lines tarting with '```'. If 'include_delims' is True,
the surrounding lines are included, otherwise they are omitted. The
result list contains every source code section as a single string.
The order in the list represents the order of the sections in the text.
"""
code_sections: list[str] = []
code_lines: list[str] = []
in_code_block = False
for line in text.split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
if in_code_block:
code_sections.append('\n'.join(code_lines) + '\n')
code_lines.clear()
in_code_block = not in_code_block
elif in_code_block:
code_lines.append(line)
return code_sections
class Question(str):
"""
A single question with a defined header.
"""
header: ClassVar[str] = '=== QUESTION ==='
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
"""
Make sure the question string does not contain the header.
"""
if cls.header in string:
raise MessageError(f"Question '{string}' contains the header '{cls.header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if any(cls.header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
class Answer(str):
"""
A single answer with a defined header.
"""
header: ClassVar[str] = '=== ANSWER ==='
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Make sure the answer string does not contain the header.
"""
if cls.header in string:
raise MessageError(f"Answer '{string}' contains the header '{cls.header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if any(cls.header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
@dataclass
class Message():
"""
Single message. Consists of a question and optionally an answer, a set of tags
and a file path.
"""
question: Question
answer: Optional[Answer]
tags: Optional[set[Tag]]
file_path: Optional[pathlib.Path]
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
@classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
"""
Create a Message from the given dict.
"""
return cls(question=data['question'],
answer=data.get('answer', None),
tags=set(data.get('tags', [])),
file_path=data.get('file_path', None))
@classmethod
def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
"""
Return only the tags from the given Message file.
"""
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags()
else: # '.yaml'
tags = set() # FIXME
return tags
@classmethod
def from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
"""
Create a Message from the given file. Expects the following file structures:
For '.txt':
* TagLine
* Question.Header
* Question
* Answer.Header
For '.yaml':
TODO
"""
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
tags: set[Tag]
question: Question
answer: Answer
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags()
text = fd.read().strip().split('\n')
question_idx = text.index(Question.header) + 1
answer_idx = text.index(Answer.header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
return cls(question, answer, tags, file_path)
else: # '.yaml'
with open(file_path, "r") as fd:
# FIXME: use the actual YAML format
data = yaml.load(fd, Loader=yaml.FullLoader)
data['file_path'] = file_path
return cls.from_dict(data)
def to_file(self, file_path: Optional[pathlib.Path]) -> None:
"""
Write Message to the given file. Creates the following file structures:
For '.txt':
* TagLine
* Question.Header
* Question
* Answer.Header
* Answer
For '.yaml':
TODO
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise MessageError("Got no valid path to write message")
if self.file_path.suffix not in self.file_suffixes:
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
if self.file_path.suffix == '.txt':
with open(self.file_path, "w") as fd:
msg_tags = self.tags or set()
fd.write(f'{TagLine.from_set(msg_tags)}\n')
fd.write(f'{Question.header}\n{self.question}\n')
fd.write(f'{Answer.header}\n{self.answer}\n')
# FIXME: write YAML format
def as_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@ -1,7 +1,7 @@
"""
Module implementing tag related functions and classes.
"""
from typing import Type, TypeVar, Optional
from typing import Type, TypeVar, Optional, Final
TagInst = TypeVar('TagInst', bound='Tag')
TagLineInst = TypeVar('TagLineInst', bound='TagLine')
@ -16,9 +16,9 @@ class Tag(str):
A single tag. A string that can contain anything but the default separator (' ').
"""
# default separator
default_separator = ' '
default_separator: Final[str] = ' '
# alternative separators (e. g. for backwards compatibility)
alternative_separators = [',']
alternative_separators: Final[list[str]] = [',']
def __new__(cls: Type[TagInst], string: str) -> TagInst:
"""
@ -98,14 +98,16 @@ class TagLine(str):
the tags.
"""
# the prefix
prefix = 'TAGS:'
prefix: Final[str] = 'TAGS:'
def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst:
"""
Make sure the tagline string starts with the prefix.
Make sure the tagline string starts with the prefix. Also replace newlines
and multiple spaces with ' ', in order to support multiline TagLines.
"""
if not string.startswith(cls.prefix):
raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'")
string = ' '.join(string.split())
instance = super().__new__(cls, string)
return instance

View File

@ -8,6 +8,7 @@ from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from chatmastermind.tags import Tag, TagLine, TagError
from chatmastermind.message import source_code, MessageError, Question, Answer
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
@ -255,6 +256,10 @@ class TestTagLine(CmmTestCase):
tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_valid_tagline_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_invalid_tagline(self) -> None:
with self.assertRaises(TagError):
TagLine('tag1 tag2')
@ -272,6 +277,11 @@ class TestTagLine(CmmTestCase):
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_tags_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_merge(self) -> None:
tagline1 = TagLine('TAGS: tag1 tag2')
tagline2 = TagLine('TAGS: tag2 tag3')
@ -345,3 +355,79 @@ class TestTagLine(CmmTestCase):
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not))
class SourceCodeTestCase(CmmTestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" ```python\n print(\"Hello, World!\")\n ```\n",
" ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_without_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" print(\"Hello, World!\")\n",
" x = 10\n y = 20\n print(x + y)\n"
]
result = source_code(text, include_delims=False)
self.assertEqual(result, expected_result)
def test_source_code_with_single_code_block(self) -> None:
text = "```python\nprint(\"Hello, World!\")\n```"
expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_with_no_code_blocks(self) -> None:
text = "Some text without any code blocks"
expected_result: list[str] = []
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase):
def test_question_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Question("=== QUESTION === What is your name?")
def test_question_without_prefix(self) -> None:
question = Question("What is your favorite color?")
self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase):
def test_answer_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Answer("=== ANSWER === Yes")
def test_answer_without_prefix(self) -> None:
answer = Answer("No")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No")