"""
Module implementing message related functions and classes.
"""
import pathlib
import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final
from dataclasses import dataclass, asdict
from .tags import Tag, TagLine, TagError, match_tags


QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
# Shape of the dict written by 'Message.__to_file_yaml': plain strings for
# question / answer / ai / model and a sorted list of strings for the tags.
YamlDict = dict[str, Union[str, list[str]]]


class MessageError(Exception):
    """
    Raised for invalid message content (e. g. a header inside a question or
    answer) and for missing, unsupported or invalid message files.
    """
    pass


def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    Changes the YAML dump style to multiline syntax for multiline strings.
    Strings with more than one line are emitted as literal block scalars
    ('|'), all other strings use the default scalar style.
    """
    if len(data.splitlines()) > 1:
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
    return dumper.represent_scalar('tag:yaml.org,2002:str', data)


# Module-level side effect: registers the representer with PyYAML, so it
# applies to 'yaml.dump' calls that use the default Dumper.
yaml.add_representer(str, str_presenter)


def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Extract all source code sections from the given text, i. e. all lines
    surrounded by lines starting with '```' (leading whitespace is ignored).
    If 'include_delims' is True, the surrounding lines are included,
    otherwise they are omitted. The result list contains every source code
    section as a single string terminated by a newline. The order in the
    list represents the order of the sections in the text. A section that
    is opened but never closed is not part of the result.
    """
    code_sections: list[str] = []
    code_lines: list[str] = []
    in_code_block = False

    for line in text.split('\n'):
        if line.strip().startswith('```'):
            if include_delims:
                code_lines.append(line)
            if in_code_block:
                # closing delimiter: the collected lines form one section
                code_sections.append('\n'.join(code_lines) + '\n')
                code_lines.clear()
            in_code_block = not in_code_block
        elif in_code_block:
            code_lines.append(line)

    return code_sections


class AILine(str):
    """
    A line that represents the AI name in a '.txt' file.
    """
    prefix: Final[str] = 'AI:'

    def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
        """
        Make sure the line starts with the prefix.
        Raises TagError otherwise (the '.txt' parser relies on that
        exception to detect that the optional line is absent).
        """
        if not string.startswith(cls.prefix):
            raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'")
        instance = super().__new__(cls, string)
        return instance

    def ai(self) -> str:
        """
        Return the AI name following the prefix, stripped of surrounding whitespace.
        """
        return self[len(self.prefix):].strip()

    @classmethod
    def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
        """
        Build an AILine ('AI: <ai>') from the given AI name.
        """
        return cls(' '.join([cls.prefix, ai]))


class ModelLine(str):
    """
    A line that represents the model name in a '.txt' file.
    """
    prefix: Final[str] = 'MODEL:'

    def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
        """
        Make sure the line starts with the prefix.
        Raises TagError otherwise (the '.txt' parser relies on that
        exception to detect that the optional line is absent).
        """
        if not string.startswith(cls.prefix):
            raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")
        instance = super().__new__(cls, string)
        return instance

    def model(self) -> str:
        """
        Return the model name following the prefix, stripped of surrounding whitespace.
        """
        return self[len(self.prefix):].strip()

    @classmethod
    def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
        """
        Build a ModelLine ('MODEL: <model>') from the given model name.
        """
        return cls(' '.join([cls.prefix, model]))


class Question(str):
    """
    A single question with a defined header.
    """
    txt_header: ClassVar[str] = '=== QUESTION ==='
    yaml_key: ClassVar[str] = 'question'

    def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
        """
        Make sure the question string does not contain the header.
        """
        if cls.txt_header in string:
            raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
        """
        Build Question from a list of strings (joined by newlines, then stripped).
        Make sure strings do not contain the header.
        """
        if any(cls.txt_header in string for string in strings):
            raise MessageError(f"Question contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)


class Answer(str):
    """
    A single answer with a defined header.
    """
    txt_header: ClassVar[str] = '=== ANSWER ==='
    yaml_key: ClassVar[str] = 'answer'

    def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
        """
        Make sure the answer string does not contain the header.
        """
        if cls.txt_header in string:
            raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build Answer from a list of strings (joined by newlines, then stripped).
        Make sure strings do not contain the header.
        """
        if any(cls.txt_header in string for string in strings):
            raise MessageError(f"Answer contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)


@dataclass
class Message():
    """
    Single message. Consists of a question and optionally an answer,
    a set of tags, the AI and model names and a file path.
    """
    question: Question
    answer: Optional[Answer] = None
    tags: Optional[set[Tag]] = None
    ai: Optional[str] = None
    model: Optional[str] = None
    file_path: Optional[pathlib.Path] = None
    # class variables
    file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
    tags_yaml_key: ClassVar[str] = 'tags'
    file_yaml_key: ClassVar[str] = 'file_path'
    ai_yaml_key: ClassVar[str] = 'ai'
    model_yaml_key: ClassVar[str] = 'model'

    @classmethod
    def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
        """
        Create a Message from the given dict.
        Only the 'Question.yaml_key' entry is required; all other entries are
        optional, and missing or None tags result in an empty tag set.
        Values are converted to Question / Answer / Tag / pathlib.Path, so the
        result has the same types as a Message read from a '.txt' file. This
        also accepts the output of 'as_dict' (its keys match the yaml keys).
        """
        answer = data.get(Answer.yaml_key, None)
        file_path = data.get(cls.file_yaml_key, None)
        return cls(question=Question(data[Question.yaml_key]),
                   answer=Answer(answer) if answer is not None else None,
                   tags={Tag(tag) for tag in (data.get(cls.tags_yaml_key) or [])},
                   ai=data.get(cls.ai_yaml_key, None),
                   model=data.get(cls.model_yaml_key, None),
                   file_path=pathlib.Path(file_path) if file_path is not None else None)

    @classmethod
    def _tags_from_yaml_data(cls, data: Any) -> set[Tag]:
        """
        Return the tags stored in loaded YAML data as a set of Tag
        (empty if the data is not a mapping or has no / empty tags entry).
        """
        if not isinstance(data, dict):
            return set()
        return {Tag(tag) for tag in (data.get(cls.tags_yaml_key) or [])}

    @classmethod
    def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]:
        """
        Return only the tags from the given Message file.
        Returns an empty set if the file contains no tags (the writers only
        emit tags when the message has any).
        Raises MessageError if the file does not exist or has an unsupported suffix.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        if file_path.suffix == '.txt':
            with open(file_path, "r") as fd:
                try:
                    return TagLine(fd.readline()).tags()
                except TagError:
                    # the TagLine is optional
                    return set()
        # '.yaml'
        with open(file_path, "r") as fd:
            # NOTE(review): FullLoader can construct some Python objects;
            # consider yaml.safe_load if message files may be untrusted.
            data = yaml.load(fd, Loader=yaml.FullLoader)
        return cls._tags_from_yaml_data(data)

    @classmethod
    def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
                  tags_or: Optional[set[Tag]] = None,
                  tags_and: Optional[set[Tag]] = None,
                  tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given file.
        Returns 'None' if the message does not fulfill the tag requirements.
        Raises MessageError if the file does not exist or has an unsupported suffix.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")

        if file_path.suffix == '.txt':
            return cls.__from_file_txt(file_path, tags_or, tags_and, tags_not)
        else:
            return cls.__from_file_yaml(file_path, tags_or, tags_and, tags_not)

    @classmethod
    def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path,  # noqa: 11
                        tags_or: Optional[set[Tag]] = None,
                        tags_and: Optional[set[Tag]] = None,
                        tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given TXT file. Expects the following file structures:
        For '.txt':
          * TagLine [Optional]
          * AI [Optional]
          * Model [Optional]
          * Question.txt_header
          * Question
          * Answer.txt_header [Optional]
          * Answer [Optional]
        Returns 'None' if the message does not fulfill the tag requirements.
        Raises ValueError if the file contains no 'Question.txt_header' line.
        """
        tags: set[Tag] = set()
        question: Question
        answer: Optional[Answer] = None
        ai: Optional[str] = None
        model: Optional[str] = None
        with open(file_path, "r") as fd:
            # TagLine (Optional): rewind if the first line is not a TagLine
            try:
                pos = fd.tell()
                tags = TagLine(fd.readline()).tags()
            except TagError:
                fd.seek(pos)
            if tags_or or tags_and or tags_not:
                # match with an empty set if the file has no tags
                if not match_tags(tags, tags_or, tags_and, tags_not):
                    return None
            # AILine (Optional)
            try:
                pos = fd.tell()
                ai = AILine(fd.readline()).ai()
            except TagError:
                fd.seek(pos)
            # ModelLine (Optional)
            try:
                pos = fd.tell()
                model = ModelLine(fd.readline()).model()
            except TagError:
                fd.seek(pos)
            # Question and Answer
            text = fd.read().strip().split('\n')
        question_idx = text.index(Question.txt_header) + 1
        try:
            answer_idx = text.index(Answer.txt_header)
        except ValueError:
            # no answer header: everything after the question header is the question
            question = Question.from_list(text[question_idx:])
        else:
            question = Question.from_list(text[question_idx:answer_idx])
            answer = Answer.from_list(text[answer_idx + 1:])
        return cls(question, answer, tags, ai, model, file_path)

    @classmethod
    def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path,
                         tags_or: Optional[set[Tag]] = None,
                         tags_and: Optional[set[Tag]] = None,
                         tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given YAML file. Expects the following file structures:
        * Question.yaml_key: single or multiline string
        * Answer.yaml_key: single or multiline string [Optional]
        * Message.tags_yaml_key: list of strings [Optional]
        * Message.ai_yaml_key: str [Optional]
        * Message.model_yaml_key: str [Optional]
        Returns 'None' if the message does not fulfill the tag requirements.
        Raises MessageError if the file does not contain a YAML mapping.
        """
        with open(file_path, "r") as fd:
            # NOTE(review): FullLoader can construct some Python objects;
            # consider yaml.safe_load if message files may be untrusted.
            data = yaml.load(fd, Loader=yaml.FullLoader)
        if not isinstance(data, dict):
            raise MessageError(f"Message file '{file_path}' does not contain a valid message")
        if tags_or or tags_and or tags_not:
            # match with an empty set if the file has no tags
            if not match_tags(cls._tags_from_yaml_data(data), tags_or, tags_and, tags_not):
                return None
        data[cls.file_yaml_key] = file_path
        return cls.from_dict(data)

    def to_file(self, file_path: Optional[pathlib.Path] = None) -> None:  # noqa: 11
        """
        Write a Message to the given file. Type is determined based on the suffix.
        If 'file_path' is not given, the message's own 'file_path' is used.
        On success, 'self.file_path' is set to the written path; it is left
        unchanged if the path is rejected.
        Currently supported suffixes: ['.txt', '.yaml']
        Raises MessageError if no path is available or the suffix is not supported.
        """
        target = file_path if file_path else self.file_path
        if not target:
            raise MessageError("Got no valid path to write message")
        if target.suffix not in self.file_suffixes:
            raise MessageError(f"File type '{target.suffix}' is not supported")
        self.file_path = target
        if target.suffix == '.txt':
            self.__to_file_txt(target)
        else:  # '.yaml'
            self.__to_file_yaml(target)

    def __to_file_txt(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in TXT format.
        Creates the following file structures:
          * TagLine [Optional]
          * AI [Optional]
          * Model [Optional]
          * Question.txt_header
          * Question
          * Answer.txt_header [Optional]
          * Answer [Optional]
        """
        with open(file_path, "w") as fd:
            if self.tags:
                fd.write(f'{TagLine.from_set(self.tags)}\n')
            if self.ai:
                fd.write(f'{AILine.from_ai(self.ai)}\n')
            if self.model:
                fd.write(f'{ModelLine.from_model(self.model)}\n')
            fd.write(f'{Question.txt_header}\n{self.question}\n')
            if self.answer:
                fd.write(f'{Answer.txt_header}\n{self.answer}\n')

    def __to_file_yaml(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in YAML format.
        Creates the following file structures:
        * Question.yaml_key: single or multiline string
        * Answer.yaml_key: single or multiline string [Optional]
        * Message.tags_yaml_key: sorted list of strings [Optional]
        * Message.ai_yaml_key: str [Optional]
        * Message.model_yaml_key: str [Optional]
        """
        with open(file_path, "w") as fd:
            data: YamlDict = {Question.yaml_key: str(self.question)}
            if self.answer:
                data[Answer.yaml_key] = str(self.answer)
            if self.ai:
                data[self.ai_yaml_key] = self.ai
            if self.model:
                data[self.model_yaml_key] = self.model
            if self.tags:
                data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
            yaml.dump(data, fd, sort_keys=False)

    def as_dict(self) -> dict[str, Any]:
        """
        Return the message fields as a dict (via 'dataclasses.asdict').
        """
        return asdict(self)