Compare commits

..

7 Commits

2 changed files with 57 additions and 38 deletions

View File

@ -59,14 +59,15 @@ class Question(str):
"""
A single question with a defined header.
"""
header: ClassVar[str] = '=== QUESTION ==='
txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'question'
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
"""
Make sure the question string does not contain the header.
"""
if cls.header in string:
raise MessageError(f"Question '{string}' contains the header '{cls.header}'")
if cls.txt_header in string:
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@ -75,8 +76,8 @@ class Question(str):
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if any(cls.header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.header}'")
if any(cls.txt_header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
@ -91,14 +92,15 @@ class Answer(str):
"""
A single answer with a defined header.
"""
header: ClassVar[str] = '=== ANSWER ==='
txt_header: ClassVar[str] = '=== ANSWER ==='
yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Make sure the answer string does not contain the header.
"""
if cls.header in string:
raise MessageError(f"Answer '{string}' contains the header '{cls.header}'")
if cls.txt_header in string:
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@ -107,8 +109,8 @@ class Answer(str):
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if any(cls.header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.header}'")
if any(cls.txt_header in string for string in strings):
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
@ -130,16 +132,18 @@ class Message():
tags: Optional[set[Tag]]
file_path: Optional[pathlib.Path]
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path'
@classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
"""
Create a Message from the given dict.
"""
return cls(question=data['question'],
answer=data.get('answer', None),
tags=set(data.get('tags', [])),
file_path=data.get('file_path', None))
return cls(question=data[Question.yaml_key],
answer=data.get(Answer.yaml_key, None),
tags=set(data.get(cls.tags_yaml_key, [])),
file_path=data.get(cls.file_yaml_key, None))
@classmethod
def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
@ -154,7 +158,9 @@ class Message():
with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags()
else: # '.yaml'
tags = set() # FIXME
with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
tags = set(sorted(data[cls.tags_yaml_key]))
return tags
@classmethod
@ -163,9 +169,9 @@ class Message():
Create a Message from the given file. Expects the following file structures:
For '.txt':
* TagLine
* Question.Header
* Question.txt_header
* Question
* Answer.Header
* Answer.txt_header
For '.yaml':
* question: single or multiline string
* answer: single or multiline string
@ -183,15 +189,15 @@ class Message():
with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags()
text = fd.read().strip().split('\n')
question_idx = text.index(Question.header) + 1
answer_idx = text.index(Answer.header)
question_idx = text.index(Question.txt_header) + 1
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
return cls(question, answer, tags, file_path)
else: # '.yaml'
with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
data['file_path'] = file_path
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
def to_file(self, file_path: Optional[pathlib.Path]) -> None:
@ -199,13 +205,13 @@ class Message():
Write Message to the given file. Creates the following file structures:
For '.txt':
* TagLine
* Question.Header
* Question.txt_header
* Question
* Answer.Header
* Answer.txt_header
* Answer
For '.yaml':
* question: single or multiline string
* answer: single or multiline string
* Question.yaml_key: single or multiline string
* Answer.yaml_key: single or multiline string
* tags: list of strings
"""
if file_path:
@ -218,15 +224,15 @@ class Message():
with open(self.file_path, "w") as fd:
msg_tags = self.tags or set()
fd.write(f'{TagLine.from_set(msg_tags)}\n')
fd.write(f'{Question.header}\n{self.question}\n')
fd.write(f'{Answer.header}\n{self.answer}\n')
fd.write(f'{Question.txt_header}\n{self.question}\n')
fd.write(f'{Answer.txt_header}\n{self.answer}\n')
elif self.file_path.suffix == '.yaml':
with open(self.file_path, "w") as fd:
data: YamlDict = {'question': str(self.question)}
data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer:
data['answer'] = str(self.answer)
data[Answer.yaml_key] = str(self.answer)
if self.tags:
data['tags'] = sorted([str(tag) for tag in self.tags])
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd)
def as_dict(self) -> dict[str, Any]:

View File

@ -3,7 +3,7 @@ import tempfile
from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer
from chatmastermind.tags import Tag
from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(CmmTestCase):
@ -100,10 +100,10 @@ class MessageToFileTxtTestCase(CmmTestCase):
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = """TAGS: tag1 tag2
=== QUESTION ===
expected_content = f"""{TagLine.prefix} tag1 tag2
{Question.txt_header}
This is a question.
=== ANSWER ===
{Answer.txt_header}
This is an answer.
"""
self.assertEqual(content, expected_content)
@ -157,13 +157,13 @@ class MessageToFileYamlTestCase(CmmTestCase):
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = """answer: |-
expected_content = f"""{Answer.yaml_key}: |-
This is a
multiline answer.
question: |-
{Question.yaml_key}: |-
This is a
multiline question.
tags:
{Message.tags_yaml_key}:
- tag1
- tag2
"""
@ -175,7 +175,12 @@ class MessageFromFileTxtTestCase(CmmTestCase):
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write("TAGS: tag1 tag2\n=== QUESTION ===\nThis is a question.\n=== ANSWER ===\nThis is an answer.\n")
fd.write(f"""{TagLine.prefix} tag1 tag2
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
def tearDown(self) -> None:
self.file.close()
@ -201,7 +206,15 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write("question: |-\n This is a question.\nanswer: |-\n This is an answer.\ntags:\n- tag1\n- tag2")
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
{Message.tags_yaml_key}:
- tag1
- tag2
""")
def tearDown(self) -> None:
self.file.close()