Compare commits
7 Commits
4c674b80a6
...
a4e0d14ca9
| Author | SHA1 | Date | |
|---|---|---|---|
| a4e0d14ca9 | |||
| 52c0d6a4a9 | |||
| fdc2f4aca9 | |||
| 7905390dee | |||
| 15e5423984 | |||
| f433f4c4b2 | |||
| a1e55104b0 |
@ -59,14 +59,15 @@ class Question(str):
|
||||
"""
|
||||
A single question with a defined header.
|
||||
"""
|
||||
header: ClassVar[str] = '=== QUESTION ==='
|
||||
txt_header: ClassVar[str] = '=== QUESTION ==='
|
||||
yaml_key: ClassVar[str] = 'question'
|
||||
|
||||
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
|
||||
"""
|
||||
Make sure the question string does not contain the header.
|
||||
"""
|
||||
if cls.header in string:
|
||||
raise MessageError(f"Question '{string}' contains the header '{cls.header}'")
|
||||
if cls.txt_header in string:
|
||||
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
@ -75,8 +76,8 @@ class Question(str):
|
||||
"""
|
||||
Build Question from a list of strings. Make sure strings do not contain the header.
|
||||
"""
|
||||
if any(cls.header in string for string in strings):
|
||||
raise MessageError(f"Question contains the header '{cls.header}'")
|
||||
if any(cls.txt_header in string for string in strings):
|
||||
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||
return instance
|
||||
|
||||
@ -91,14 +92,15 @@ class Answer(str):
|
||||
"""
|
||||
A single answer with a defined header.
|
||||
"""
|
||||
header: ClassVar[str] = '=== ANSWER ==='
|
||||
txt_header: ClassVar[str] = '=== ANSWER ==='
|
||||
yaml_key: ClassVar[str] = 'answer'
|
||||
|
||||
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
|
||||
"""
|
||||
Make sure the answer string does not contain the header.
|
||||
"""
|
||||
if cls.header in string:
|
||||
raise MessageError(f"Answer '{string}' contains the header '{cls.header}'")
|
||||
if cls.txt_header in string:
|
||||
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
@ -107,8 +109,8 @@ class Answer(str):
|
||||
"""
|
||||
Build Question from a list of strings. Make sure strings do not contain the header.
|
||||
"""
|
||||
if any(cls.header in string for string in strings):
|
||||
raise MessageError(f"Question contains the header '{cls.header}'")
|
||||
if any(cls.txt_header in string for string in strings):
|
||||
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||
return instance
|
||||
|
||||
@ -130,16 +132,18 @@ class Message():
|
||||
tags: Optional[set[Tag]]
|
||||
file_path: Optional[pathlib.Path]
|
||||
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
|
||||
tags_yaml_key: ClassVar[str] = 'tags'
|
||||
file_yaml_key: ClassVar[str] = 'file_path'
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
|
||||
"""
|
||||
Create a Message from the given dict.
|
||||
"""
|
||||
return cls(question=data['question'],
|
||||
answer=data.get('answer', None),
|
||||
tags=set(data.get('tags', [])),
|
||||
file_path=data.get('file_path', None))
|
||||
return cls(question=data[Question.yaml_key],
|
||||
answer=data.get(Answer.yaml_key, None),
|
||||
tags=set(data.get(cls.tags_yaml_key, [])),
|
||||
file_path=data.get(cls.file_yaml_key, None))
|
||||
|
||||
@classmethod
|
||||
def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
|
||||
@ -154,7 +158,9 @@ class Message():
|
||||
with open(file_path, "r") as fd:
|
||||
tags = TagLine(fd.readline()).tags()
|
||||
else: # '.yaml'
|
||||
tags = set() # FIXME
|
||||
with open(file_path, "r") as fd:
|
||||
data = yaml.load(fd, Loader=yaml.FullLoader)
|
||||
tags = set(sorted(data[cls.tags_yaml_key]))
|
||||
return tags
|
||||
|
||||
@classmethod
|
||||
@ -163,9 +169,9 @@ class Message():
|
||||
Create a Message from the given file. Expects the following file structures:
|
||||
For '.txt':
|
||||
* TagLine
|
||||
* Question.Header
|
||||
* Question.txt_header
|
||||
* Question
|
||||
* Answer.Header
|
||||
* Answer.txt_header
|
||||
For '.yaml':
|
||||
* question: single or multiline string
|
||||
* answer: single or multiline string
|
||||
@ -183,15 +189,15 @@ class Message():
|
||||
with open(file_path, "r") as fd:
|
||||
tags = TagLine(fd.readline()).tags()
|
||||
text = fd.read().strip().split('\n')
|
||||
question_idx = text.index(Question.header) + 1
|
||||
answer_idx = text.index(Answer.header)
|
||||
question_idx = text.index(Question.txt_header) + 1
|
||||
answer_idx = text.index(Answer.txt_header)
|
||||
question = Question.from_list(text[question_idx:answer_idx])
|
||||
answer = Answer.from_list(text[answer_idx + 1:])
|
||||
return cls(question, answer, tags, file_path)
|
||||
else: # '.yaml'
|
||||
with open(file_path, "r") as fd:
|
||||
data = yaml.load(fd, Loader=yaml.FullLoader)
|
||||
data['file_path'] = file_path
|
||||
data[cls.file_yaml_key] = file_path
|
||||
return cls.from_dict(data)
|
||||
|
||||
def to_file(self, file_path: Optional[pathlib.Path]) -> None:
|
||||
@ -199,13 +205,13 @@ class Message():
|
||||
Write Message to the given file. Creates the following file structures:
|
||||
For '.txt':
|
||||
* TagLine
|
||||
* Question.Header
|
||||
* Question.txt_header
|
||||
* Question
|
||||
* Answer.Header
|
||||
* Answer.txt_header
|
||||
* Answer
|
||||
For '.yaml':
|
||||
* question: single or multiline string
|
||||
* answer: single or multiline string
|
||||
* Question.yaml_key: single or multiline string
|
||||
* Answer.yaml_key: single or multiline string
|
||||
* tags: list of strings
|
||||
"""
|
||||
if file_path:
|
||||
@ -218,15 +224,15 @@ class Message():
|
||||
with open(self.file_path, "w") as fd:
|
||||
msg_tags = self.tags or set()
|
||||
fd.write(f'{TagLine.from_set(msg_tags)}\n')
|
||||
fd.write(f'{Question.header}\n{self.question}\n')
|
||||
fd.write(f'{Answer.header}\n{self.answer}\n')
|
||||
fd.write(f'{Question.txt_header}\n{self.question}\n')
|
||||
fd.write(f'{Answer.txt_header}\n{self.answer}\n')
|
||||
elif self.file_path.suffix == '.yaml':
|
||||
with open(self.file_path, "w") as fd:
|
||||
data: YamlDict = {'question': str(self.question)}
|
||||
data: YamlDict = {Question.yaml_key: str(self.question)}
|
||||
if self.answer:
|
||||
data['answer'] = str(self.answer)
|
||||
data[Answer.yaml_key] = str(self.answer)
|
||||
if self.tags:
|
||||
data['tags'] = sorted([str(tag) for tag in self.tags])
|
||||
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
|
||||
yaml.dump(data, fd)
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
|
||||
@ -3,7 +3,7 @@ import tempfile
|
||||
from typing import cast
|
||||
from .test_main import CmmTestCase
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer
|
||||
from chatmastermind.tags import Tag
|
||||
from chatmastermind.tags import Tag, TagLine
|
||||
|
||||
|
||||
class SourceCodeTestCase(CmmTestCase):
|
||||
@ -100,10 +100,10 @@ class MessageToFileTxtTestCase(CmmTestCase):
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
expected_content = """TAGS: tag1 tag2
|
||||
=== QUESTION ===
|
||||
expected_content = f"""{TagLine.prefix} tag1 tag2
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
=== ANSWER ===
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
"""
|
||||
self.assertEqual(content, expected_content)
|
||||
@ -157,13 +157,13 @@ class MessageToFileYamlTestCase(CmmTestCase):
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
expected_content = """answer: |-
|
||||
expected_content = f"""{Answer.yaml_key}: |-
|
||||
This is a
|
||||
multiline answer.
|
||||
question: |-
|
||||
{Question.yaml_key}: |-
|
||||
This is a
|
||||
multiline question.
|
||||
tags:
|
||||
{Message.tags_yaml_key}:
|
||||
- tag1
|
||||
- tag2
|
||||
"""
|
||||
@ -175,7 +175,12 @@ class MessageFromFileTxtTestCase(CmmTestCase):
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
with open(self.file_path, "w") as fd:
|
||||
fd.write("TAGS: tag1 tag2\n=== QUESTION ===\nThis is a question.\n=== ANSWER ===\nThis is an answer.\n")
|
||||
fd.write(f"""{TagLine.prefix} tag1 tag2
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file.close()
|
||||
@ -201,7 +206,15 @@ class MessageFromFileYamlTestCase(CmmTestCase):
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
with open(self.file_path, "w") as fd:
|
||||
fd.write("question: |-\n This is a question.\nanswer: |-\n This is an answer.\ntags:\n- tag1\n- tag2")
|
||||
fd.write(f"""
|
||||
{Question.yaml_key}: |-
|
||||
This is a question.
|
||||
{Answer.yaml_key}: |-
|
||||
This is an answer.
|
||||
{Message.tags_yaml_key}:
|
||||
- tag1
|
||||
- tag2
|
||||
""")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file.close()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user