""" Module implementing message related functions and classes. """ import pathlib import yaml from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags QuestionInst = TypeVar('QuestionInst', bound='Question') AnswerInst = TypeVar('AnswerInst', bound='Answer') MessageInst = TypeVar('MessageInst', bound='Message') AILineInst = TypeVar('AILineInst', bound='AILine') ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] class MessageError(Exception): pass def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: """ Changes the YAML dump style to multiline syntax for multiline strings. """ if len(data.splitlines()) > 1: return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') return dumper.represent_scalar('tag:yaml.org,2002:str', data) yaml.add_representer(str, str_presenter) def source_code(text: str, include_delims: bool = False) -> list[str]: """ Extract all source code sections from the given text, i. e. all lines surrounded by lines tarting with '```'. If 'include_delims' is True, the surrounding lines are included, otherwise they are omitted. The result list contains every source code section as a single string. The order in the list represents the order of the sections in the text. """ code_sections: list[str] = [] code_lines: list[str] = [] in_code_block = False for line in text.split('\n'): if line.strip().startswith('```'): if include_delims: code_lines.append(line) if in_code_block: code_sections.append('\n'.join(code_lines) + '\n') code_lines.clear() in_code_block = not in_code_block elif in_code_block: code_lines.append(line) return code_sections def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool: """ Searches the given message list for a message with the same file name as the given one (i. e. it compares Message.file_path.name). If the given message has no file_path, False is returned. """ if not message.file_path: return False for m in messages: if m.file_path and m.file_path.name == message.file_path.name: return True return False @dataclass(kw_only=True) class MessageFilter: """ Various filters for a Message. """ tags_or: Optional[set[Tag]] = None tags_and: Optional[set[Tag]] = None tags_not: Optional[set[Tag]] = None ai: Optional[str] = None model: Optional[str] = None question_contains: Optional[str] = None answer_contains: Optional[str] = None answer_state: Optional[Literal['available', 'missing']] = None ai_state: Optional[Literal['available', 'missing']] = None model_state: Optional[Literal['available', 'missing']] = None class AILine(str): """ A line that represents the AI name in a '.txt' file.. """ prefix: Final[str] = 'AI:' def __new__(cls: Type[AILineInst], string: str) -> AILineInst: if not string.startswith(cls.prefix): raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance def ai(self) -> str: return self[len(self.prefix):].strip() @classmethod def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst: return cls(' '.join([cls.prefix, ai])) class ModelLine(str): """ A line that represents the model name in a '.txt' file.. """ prefix: Final[str] = 'MODEL:' def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: if not string.startswith(cls.prefix): raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance def model(self) -> str: return self[len(self.prefix):].strip() @classmethod def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst: return cls(' '.join([cls.prefix, model])) class Answer(str): """ A single answer with a defined header. """ tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ Make sure the answer string does not contain the header as a whole line. """ if cls.txt_header in string.split('\n'): raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance def source_code(self, include_delims: bool = False) -> list[str]: """ Extract and return all source code sections. """ return source_code(self, include_delims) class Question(str): """ A single question with a defined header. """ tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'question' def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ Make sure the question string does not contain the header as a whole line (also not that from 'Answer', so it's always clear where the answer starts). """ string_lines = string.split('\n') if cls.txt_header in string_lines: raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") if Answer.txt_header in string_lines: raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance def source_code(self, include_delims: bool = False) -> list[str]: """ Extract and return all source code sections. """ return source_code(self, include_delims) @dataclass class Message(): """ Single message. Consists of a question and optionally an answer, a set of tags and a file path. """ question: Question answer: Optional[Answer] = None # metadata, ignored when comparing messages tags: Optional[set[Tag]] = field(default=None, compare=False) ai: Optional[str] = field(default=None, compare=False) model: Optional[str] = field(default=None, compare=False) file_path: Optional[pathlib.Path] = field(default=None, compare=False) # class variables file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] tags_yaml_key: ClassVar[str] = 'tags' file_yaml_key: ClassVar[str] = 'file_path' ai_yaml_key: ClassVar[str] = 'ai' model_yaml_key: ClassVar[str] = 'model' def __hash__(self) -> int: """ The hash value is computed based on immutable members. """ return hash((self.question, self.answer)) @classmethod def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: """ Create a Message from the given dict. """ return cls(question=data[Question.yaml_key], answer=data.get(Answer.yaml_key, None), tags=set(data.get(cls.tags_yaml_key, [])), ai=data.get(cls.ai_yaml_key, None), model=data.get(cls.model_yaml_key, None), file_path=data.get(cls.file_yaml_key, None)) @classmethod def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Return only the tags from the given Message file, optionally filtered based on prefix or contained string. """ tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") # for TXT, it's enough to read the TagLine if file_path.suffix == '.txt': with open(file_path, "r") as fd: try: tags = TagLine(fd.readline()).tags(prefix, contain) except TagError: pass # message without tags else: # '.yaml' try: message = cls.from_file(file_path) if message: msg_tags = message.filter_tags(prefix=prefix, contain=contain) except MessageError as e: print(f"Error processing message in '{file_path}': {str(e)}") if msg_tags: tags = msg_tags return tags @classmethod def tags_from_dir(cls: Type[MessageInst], path: pathlib.Path, glob: Optional[str] = None, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Return only the tags from message files in the given directory. The files can be filtered using 'glob', the tags by using 'prefix' and 'contain'. """ tags: set[Tag] = set() file_iter = path.glob(glob) if glob else path.iterdir() for file_path in sorted(file_iter): if file_path.is_file(): try: tags |= cls.tags_from_file(file_path, prefix, contain) except MessageError as e: print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod def from_file(cls: Type[MessageInst], file_path: pathlib.Path, mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: """ Create a Message from the given file. Returns 'None' if the message does not fulfill the filter requirements. For TXT files, the tags are matched before building the whole message. The other filters are applied afterwards. """ if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") if file_path.suffix == '.txt': message = cls.__from_file_txt(file_path, mfilter.tags_or if mfilter else None, mfilter.tags_and if mfilter else None, mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) if message and (not mfilter or (mfilter and message.match(mfilter))): return message else: return None @classmethod def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 tags_or: Optional[set[Tag]] = None, tags_and: Optional[set[Tag]] = None, tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: """ Create a Message from the given TXT file. Expects the following file structures: For '.txt': * TagLine [Optional] * AI [Optional] * Model [Optional] * Question.txt_header * Question * Answer.txt_header [Optional] * Answer [Optional] Returns 'None' if the message does not fulfill the tag requirements. """ tags: set[Tag] = set() question: Question answer: Optional[Answer] = None ai: Optional[str] = None model: Optional[str] = None with open(file_path, "r") as fd: # TagLine (Optional) try: pos = fd.tell() tags = TagLine(fd.readline()).tags() except TagError: fd.seek(pos) if tags_or or tags_and or tags_not: # match with an empty set if the file has no tags if not match_tags(tags, tags_or, tags_and, tags_not): return None # AILine (Optional) try: pos = fd.tell() ai = AILine(fd.readline()).ai() except MessageError: fd.seek(pos) # ModelLine (Optional) try: pos = fd.tell() model = ModelLine(fd.readline()).model() except MessageError: fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') try: question_idx = text.index(Question.txt_header) + 1 except ValueError: raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) answer = Answer.from_list(text[answer_idx + 1:]) except ValueError: question = Question.from_list(text[question_idx:]) return cls(question, answer, tags, ai, model, file_path) @classmethod def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: """ Create a Message from the given YAML file. Expects the following file structures: * Question.yaml_key: single or multiline string * Answer.yaml_key: single or multiline string [Optional] * Message.tags_yaml_key: list of strings [Optional] * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ with open(file_path, "r") as fd: data = yaml.load(fd, Loader=yaml.FullLoader) data[cls.file_yaml_key] = file_path return cls.from_dict(data) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ Write a Message to the given file. Type is determined based on the suffix. Currently supported suffixes: ['.txt', '.yaml'] """ if file_path: self.file_path = file_path if not self.file_path: raise MessageError("Got no valid path to write message") if self.file_path.suffix not in self.file_suffixes: raise MessageError(f"File type '{self.file_path.suffix}' is not supported") # TXT if self.file_path.suffix == '.txt': return self.__to_file_txt(self.file_path) elif self.file_path.suffix == '.yaml': return self.__to_file_yaml(self.file_path) def __to_file_txt(self, file_path: pathlib.Path) -> None: """ Write a Message to the given file in TXT format. Creates the following file structures: * TagLine * AI [Optional] * Model [Optional] * Question.txt_header * Question * Answer.txt_header * Answer """ with open(file_path, "w") as fd: if self.tags: fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: fd.write(f'{ModelLine.from_model(self.model)}\n') fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: fd.write(f'{Answer.txt_header}\n{self.answer}\n') def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ Write a Message to the given file in YAML format. Creates the following file structures: * Question.yaml_key: single or multiline string * Answer.yaml_key: single or multiline string * Message.tags_yaml_key: list of strings * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ with open(file_path, "w") as fd: data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) if self.ai: data[self.ai_yaml_key] = self.ai if self.model: data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) yaml.dump(data, fd, sort_keys=False) def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string. """ if not self.tags: return set() res_tags = self.tags.copy() if prefix and len(prefix) > 0: res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} if contain and len(contain) > 0: res_tags -= {tag for tag in res_tags if contain not in tag} return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ Returns all tags as a string with the TagLine prefix. Optionally filtered using 'Message.filter_tags()'. """ if self.tags: return str(TagLine.from_set(self.filter_tags(prefix, contain))) else: return str(TagLine.from_set(set())) def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 """ Matches the current Message to the given filter atttributes. Return True if all attributes match, else False. """ mytags = self.tags or set() if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503 or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503 or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503 or (mfilter.model_state == 'missing' and self.model)): # noqa: W503 return False return True def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. Currently this is the file name. The ID is also used for sorting messages. """ if self.file_path: return self.file_path.name else: raise MessageError("Can't create file ID without a file path") def as_dict(self) -> dict[str, Any]: return asdict(self) def tokens(self) -> int: """ Returns the nr. of AI language tokens used by this message. If unknown, 0 is returned. """ if self.answer: return self.question.tokens + self.answer.tokens else: return self.question.tokens