424 lines
16 KiB
Python

"""
Module implementing message-related functions and classes.
"""
import pathlib
import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
from dataclasses import dataclass, asdict
from .tags import Tag, TagLine, TagError, match_tags
# Type variables bound to the classes defined below. They annotate 'cls' and the
# return value of the '__new__' methods and of the alternate constructors.
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
# Annotation of the dict that 'Message.__to_file_yaml()' hands to 'yaml.dump()'.
# NOTE(review): the values stored in that dict are plain 'str' and 'list[str]'
# objects, which this alias does not list -- confirm whether it should be widened.
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
class MessageError(Exception):
    """
    Error raised by this module for message related problems, e. g. a header
    inside a question / answer, a missing message file or an unsupported file type.
    """
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    YAML representer for 'str' values: strings consisting of more than one
    line are dumped in literal block style ('|'), all other strings in the
    dumper's default style.
    """
    # 'None' is the default value of 'style' in 'represent_scalar()'
    block_style = '|' if len(data.splitlines()) > 1 else None
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style=block_style)


# Module-level side effect: register the representer above for 'str' values.
yaml.add_representer(str, str_presenter)
def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Collect all source code sections of the given text, i. e. all lines
    enclosed by a pair of lines starting with '```' (leading whitespace is
    ignored when looking for the delimiter).

    If 'include_delims' is True, the delimiter lines are part of the returned
    sections, otherwise they are left out. Every section is returned as one
    string terminated by a newline, in the order of appearance in the text.
    A section that is opened but never closed is not returned.
    """
    sections: list[str] = []
    current: list[str] = []
    inside = False
    for raw_line in text.split('\n'):
        if not raw_line.strip().startswith('```'):
            # ordinary line: only collected while inside of a code section
            if inside:
                current.append(raw_line)
            continue
        # delimiter line: either opens or closes a section
        if include_delims:
            current.append(raw_line)
        if inside:
            sections.append('\n'.join(current) + '\n')
            current = []
        inside = not inside
    return sections
@dataclass(kw_only=True)
class MessageFilter:
    """
    Various filters for a Message.

    All fields are keyword-only and optional. A field that is left at 'None'
    (or is otherwise empty) is not taken into account by 'Message.match()'.
    """
    # tag sets that 'Message.match()' passes on to 'match_tags()' together
    # with the tags of the message
    tags_or: Optional[set[Tag]] = None
    tags_and: Optional[set[Tag]] = None
    tags_not: Optional[set[Tag]] = None
    # the message's 'ai' / 'model' must be equal to the given string
    ai: Optional[str] = None
    model: Optional[str] = None
    # substring that must be contained in the question / answer
    # (a message without an answer never matches 'answer_contains')
    question_contains: Optional[str] = None
    answer_contains: Optional[str] = None
    # 'available': the message must have the attribute set,
    # 'missing': the message must NOT have the attribute set
    answer_state: Optional[Literal['available', 'missing']] = None
    ai_state: Optional[Literal['available', 'missing']] = None
    model_state: Optional[Literal['available', 'missing']] = None
class AILine(str):
    """
    A line that holds the name of the AI in a '.txt' file, i. e. a
    string starting with the prefix 'AI:'.
    """
    prefix: Final[str] = 'AI:'

    def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
        """
        Create the line from the given string. Raises a TagError if
        the string does not start with the prefix.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'")

    def ai(self) -> str:
        """
        Return the AI name: everything that follows the prefix, without
        surrounding whitespace.
        """
        name = self[len(self.prefix):]
        return name.strip()

    @classmethod
    def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
        """
        Build the line for the given AI name (prefix and name separated
        by a single space).
        """
        return cls(cls.prefix + ' ' + ai)
class ModelLine(str):
    """
    A line that holds the name of the model in a '.txt' file, i. e. a
    string starting with the prefix 'MODEL:'.
    """
    prefix: Final[str] = 'MODEL:'

    def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
        """
        Create the line from the given string. Raises a TagError if
        the string does not start with the prefix.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")

    def model(self) -> str:
        """
        Return the model name: everything that follows the prefix, without
        surrounding whitespace.
        """
        name = self[len(self.prefix):]
        return name.strip()

    @classmethod
    def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
        """
        Build the line for the given model name (prefix and name separated
        by a single space).
        """
        return cls(cls.prefix + ' ' + model)
class Question(str):
    """
    The text of a single question. The text itself never contains the
    header that introduces a question in a '.txt' file.
    """
    txt_header: ClassVar[str] = '=== QUESTION ==='
    yaml_key: ClassVar[str] = 'question'

    def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
        """
        Create the question from the given string. Raises a MessageError
        if the string contains the header.
        """
        if cls.txt_header not in string:
            return super().__new__(cls, string)
        raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")

    @classmethod
    def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
        """
        Create the question from a list of lines. The lines are joined with
        newlines and surrounding whitespace is removed from the result.
        Raises a MessageError if one of the lines contains the header.
        """
        for string in strings:
            if cls.txt_header in string:
                raise MessageError(f"Question contains the header '{cls.txt_header}'")
        joined = '\n'.join(strings)
        return super().__new__(cls, joined.strip())

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Return all source code sections of the question (see the module
        level function 'source_code()').
        """
        sections = source_code(self, include_delims)
        return sections
class Answer(str):
    """
    A single answer with a defined header.
    """
    txt_header: ClassVar[str] = '=== ANSWER ==='
    yaml_key: ClassVar[str] = 'answer'

    def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
        """
        Make sure the answer string does not contain the header.
        Raises a MessageError otherwise.
        """
        if cls.txt_header in string:
            raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build Answer from a list of strings. The strings are joined with newlines
        and surrounding whitespace is removed from the result. Raises a MessageError
        if one of the strings contains the header.
        """
        if any(cls.txt_header in string for string in strings):
            # this is the Answer class: name it correctly in the error message
            raise MessageError(f"Answer contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)
@dataclass
class Message():
    """
    Single message. Consists of a question and optionally an answer, a set of tags,
    the names of the AI and the model, and a file path.
    """
    question: Question
    answer: Optional[Answer] = None
    tags: Optional[set[Tag]] = None
    ai: Optional[str] = None
    model: Optional[str] = None
    file_path: Optional[pathlib.Path] = None
    # class variables
    file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
    tags_yaml_key: ClassVar[str] = 'tags'
    file_yaml_key: ClassVar[str] = 'file_path'
    ai_yaml_key: ClassVar[str] = 'ai'
    model_yaml_key: ClassVar[str] = 'model'

    @classmethod
    def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
        """
        Create a Message from the given dict. The dict must contain the question
        (KeyError otherwise), all other entries are optional. A tags entry that
        is missing or 'None' results in an empty tag set.
        """
        return cls(question=data[Question.yaml_key],
                   answer=data.get(Answer.yaml_key, None),
                   # the key may be present with the value 'None' (e. g. in the
                   # result of 'as_dict()'), which must not be passed to 'set()'
                   tags=set(data.get(cls.tags_yaml_key) or []),
                   ai=data.get(cls.ai_yaml_key, None),
                   model=data.get(cls.model_yaml_key, None),
                   file_path=data.get(cls.file_yaml_key, None))

    @staticmethod
    def __load_yaml(file_path: pathlib.Path) -> dict[str, Any]:
        """
        Load the given YAML file and return its content as a dict. Raises a
        MessageError if the content is not a mapping (e. g. for an empty file).
        """
        with open(file_path, "r") as fd:
            # NOTE(review): PyYAML recommends 'yaml.safe_load()' for input from
            # untrusted sources. 'FullLoader' is kept here in order not to change
            # which documents are accepted -- confirm if message files can be untrusted.
            data = yaml.load(fd, Loader=yaml.FullLoader)
        if not isinstance(data, dict):
            raise MessageError(f"'{file_path}' does not contain a valid message")
        return data

    @classmethod
    def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
        """
        Return only the tags from the given Message file. Tags are optional in
        both file formats, so a file without tags results in an empty set.
        Raises a MessageError if the file does not exist, has an unsupported
        suffix or (YAML only) does not contain a mapping.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        tags: set[Tag] = set()
        if file_path.suffix == '.txt':
            with open(file_path, "r") as fd:
                try:
                    tags = TagLine(fd.readline()).tags()
                except TagError:
                    # the first line is no TagLine: handled like in
                    # '__from_file_txt()', i. e. as a message without tags
                    pass
        else:  # '.yaml'
            data = cls.__load_yaml(file_path)
            # the tags key is not written for messages without tags
            tags = set(data.get(cls.tags_yaml_key) or [])
        return tags

    @classmethod
    def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
                  mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given file. Returns 'None' if the message does
        not fulfill the filter requirements. For TXT files, the tags are matched
        before building the whole message. The other filters are applied afterwards.
        Raises a MessageError if the file does not exist, has an unsupported suffix
        or does not contain a valid message.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        message: Optional[MessageInst]
        if file_path.suffix == '.txt':
            message = cls.__from_file_txt(file_path,
                                          mfilter.tags_or if mfilter else None,
                                          mfilter.tags_and if mfilter else None,
                                          mfilter.tags_not if mfilter else None)
        else:
            message = cls.__from_file_yaml(file_path)
        # 'None' here means that the tags of a TXT file did not match
        if message is None:
            return None
        if mfilter and not message.match(mfilter):
            return None
        return message

    @classmethod
    def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path,  # noqa: 11
                        tags_or: Optional[set[Tag]] = None,
                        tags_and: Optional[set[Tag]] = None,
                        tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given TXT file. Expects the following file structures:
        For '.txt':
            * TagLine [Optional]
            * AI [Optional]
            * Model [Optional]
            * Question.txt_header
            * Question
            * Answer.txt_header [Optional]
            * Answer [Optional]
        Returns 'None' if the message does not fulfill the tag requirements.
        Raises a MessageError if the file contains no question header.
        """
        tags: set[Tag] = set()
        question: Question
        answer: Optional[Answer] = None
        ai: Optional[str] = None
        model: Optional[str] = None
        with open(file_path, "r") as fd:
            # Each optional line is read tentatively: if it does not parse,
            # the file position is restored so the next parser sees the same line.
            # TagLine (Optional)
            try:
                pos = fd.tell()
                tags = TagLine(fd.readline()).tags()
            except TagError:
                fd.seek(pos)
            if tags_or or tags_and or tags_not:
                # match with an empty set if the file has no tags
                if not match_tags(tags, tags_or, tags_and, tags_not):
                    return None
            # AILine (Optional)
            try:
                pos = fd.tell()
                ai = AILine(fd.readline()).ai()
            except TagError:
                fd.seek(pos)
            # ModelLine (Optional)
            try:
                pos = fd.tell()
                model = ModelLine(fd.readline()).model()
            except TagError:
                fd.seek(pos)
            # Question and Answer
            text = fd.read().strip().split('\n')
        try:
            question_idx = text.index(Question.txt_header) + 1
        except ValueError as err:
            # report a missing question header as a message error instead of
            # leaking the bare ValueError of 'list.index()'
            raise MessageError(f"'{file_path}' does not contain a valid message") from err
        try:
            answer_idx = text.index(Answer.txt_header)
        except ValueError:
            # no answer header: everything after the question header is the question
            question = Question.from_list(text[question_idx:])
        else:
            question = Question.from_list(text[question_idx:answer_idx])
            answer = Answer.from_list(text[answer_idx + 1:])
        return cls(question, answer, tags, ai, model, file_path)

    @classmethod
    def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
        """
        Create a Message from the given YAML file. Expects the following file structures:
            * Question.yaml_key: single or multiline string
            * Answer.yaml_key: single or multiline string [Optional]
            * Message.tags_yaml_key: list of strings [Optional]
            * Message.ai_yaml_key: str [Optional]
            * Message.model_yaml_key: str [Optional]
        Raises a MessageError if the file does not contain a mapping with a question.
        """
        data = cls.__load_yaml(file_path)
        if Question.yaml_key not in data:
            raise MessageError(f"'{file_path}' does not contain a valid message")
        # the path is not stored in the file, it is always the path read from
        data[cls.file_yaml_key] = file_path
        return cls.from_dict(data)

    def to_file(self, file_path: Optional[pathlib.Path] = None) -> None:  # noqa: 11
        """
        Write a Message to the given file. Type is determined based on the suffix.
        Currently supported suffixes: ['.txt', '.yaml']
        If 'file_path' is given, it replaces the file path stored in the message,
        otherwise the stored path is used. Raises a MessageError if there is no
        path at all or if the suffix is not supported.
        """
        if file_path:
            self.file_path = file_path
        if not self.file_path:
            raise MessageError("Got no valid path to write message")
        if self.file_path.suffix not in self.file_suffixes:
            raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
        # TXT
        if self.file_path.suffix == '.txt':
            return self.__to_file_txt(self.file_path)
        elif self.file_path.suffix == '.yaml':
            return self.__to_file_yaml(self.file_path)

    def __to_file_txt(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in TXT format.
        Creates the following file structures:
            * TagLine [Optional]
            * AI [Optional]
            * Model [Optional]
            * Question.txt_header
            * Question
            * Answer.txt_header [Optional]
            * Answer [Optional]
        Optional parts are only written if the corresponding attribute is set.
        """
        with open(file_path, "w") as fd:
            if self.tags:
                fd.write(f'{TagLine.from_set(self.tags)}\n')
            if self.ai:
                fd.write(f'{AILine.from_ai(self.ai)}\n')
            if self.model:
                fd.write(f'{ModelLine.from_model(self.model)}\n')
            fd.write(f'{Question.txt_header}\n{self.question}\n')
            if self.answer:
                fd.write(f'{Answer.txt_header}\n{self.answer}\n')

    def __to_file_yaml(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in YAML format.
        Creates the following file structures:
            * Question.yaml_key: single or multiline string
            * Answer.yaml_key: single or multiline string [Optional]
            * Message.ai_yaml_key: str [Optional]
            * Message.model_yaml_key: str [Optional]
            * Message.tags_yaml_key: sorted list of strings [Optional]
        Optional parts are only written if the corresponding attribute is set.
        The file path itself is not written to the file.
        """
        with open(file_path, "w") as fd:
            # plain 'str' objects are dumped, not the 'str' subclasses
            data: YamlDict = {Question.yaml_key: str(self.question)}
            if self.answer:
                data[Answer.yaml_key] = str(self.answer)
            if self.ai:
                data[self.ai_yaml_key] = self.ai
            if self.model:
                data[self.model_yaml_key] = self.model
            if self.tags:
                data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
            # keep the insertion order above instead of sorting the keys
            yaml.dump(data, fd, sort_keys=False)

    def match(self, mfilter: MessageFilter) -> bool:  # noqa: 13
        """
        Matches the current Message to the given filter attributes.
        Return True if all attributes match, else False. Filter attributes
        that are not set are ignored.
        """
        # a message without tags is matched with an empty set
        mytags = self.tags or set()
        if ((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
                and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)):
            return False
        # exact match of AI and model name
        if mfilter.ai and (not self.ai or mfilter.ai != self.ai):
            return False
        if mfilter.model and (not self.model or mfilter.model != self.model):
            return False
        # substring match of question and answer
        if mfilter.question_contains and mfilter.question_contains not in self.question:
            return False
        if mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer):
            return False
        # attributes that must be set
        if mfilter.answer_state == 'available' and not self.answer:
            return False
        if mfilter.ai_state == 'available' and not self.ai:
            return False
        if mfilter.model_state == 'available' and not self.model:
            return False
        # attributes that must not be set
        if mfilter.answer_state == 'missing' and self.answer:
            return False
        if mfilter.ai_state == 'missing' and self.ai:
            return False
        if mfilter.model_state == 'missing' and self.model:
            return False
        return True

    def msg_id(self) -> str:
        """
        Returns an ID that is unique throughout all messages in the same (DB) directory.
        Currently this is the file name. The ID is also used for sorting messages.
        Raises a MessageError if the message has no file path.
        """
        if self.file_path:
            return self.file_path.name
        else:
            raise MessageError("Can't create file ID without a file path")

    def as_dict(self) -> dict[str, Any]:
        """
        Return the fields of the message as a dict (see 'dataclasses.asdict()').
        """
        return asdict(self)