""" Module implementing message related functions and classes. """
import pathlib
import yaml
import tempfile
import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags

QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'


class MessageError(Exception):
    pass


def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    Changes the YAML dump style to multiline syntax for multiline strings.
    Single-line strings keep the default scalar style.
    """
    if len(data.splitlines()) > 1:
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
    return dumper.represent_scalar('tag:yaml.org,2002:str', data)


# Registered globally at import time: affects every 'yaml.dump()' of str values.
yaml.add_representer(str, str_presenter)


def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Extract all source code sections from the given text, i. e. all lines
    surrounded by lines starting with '```' (leading whitespace is ignored).
    If 'include_delims' is True, the surrounding lines are included, otherwise
    they are omitted. The result list contains every source code section as a
    single string (terminated by a newline). The order in the list represents
    the order of the sections in the text. A section that is opened but never
    closed is not included.
    """
    code_sections: list[str] = []
    code_lines: list[str] = []
    in_code_block = False

    for line in text.split('\n'):
        if line.strip().startswith('```'):
            if include_delims:
                code_lines.append(line)
            if in_code_block:
                # closing delimiter -> emit the collected section
                code_sections.append('\n'.join(code_lines) + '\n')
                code_lines.clear()
            in_code_block = not in_code_block
        elif in_code_block:
            code_lines.append(line)

    return code_sections


def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
    """
    Searches the given message list for a message with the same file
    name as the given one (i. e. it compares Message.file_path.name).
    If the given message has no file_path, False is returned.
    """
    if not message.file_path:
        return False
    for m in messages:
        if m.file_path and m.file_path.name == message.file_path.name:
            return True
    return False


@dataclass(kw_only=True)
class MessageFilter:
    """
    Various filters for a Message. Attributes left at 'None' are not used
    for filtering (see 'Message.match()').
    """
    tags_or: Optional[set[Tag]] = None
    tags_and: Optional[set[Tag]] = None
    tags_not: Optional[set[Tag]] = None
    ai: Optional[str] = None
    model: Optional[str] = None
    question_contains: Optional[str] = None
    answer_contains: Optional[str] = None
    answer_state: Optional[Literal['available', 'missing']] = None
    ai_state: Optional[Literal['available', 'missing']] = None
    model_state: Optional[Literal['available', 'missing']] = None


class AILine(str):
    """
    A line that represents the AI name in the 'txt' format.
    """
    prefix: Final[str] = 'AI:'

    def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
        if not string.startswith(cls.prefix):
            raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'")
        instance = super().__new__(cls, string)
        return instance

    def ai(self) -> str:
        """ Return the AI name without prefix and surrounding whitespace. """
        return self[len(self.prefix):].strip()

    @classmethod
    def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
        """ Build an AILine ('<prefix> <ai>') from the given AI name. """
        return cls(' '.join([cls.prefix, ai]))


class ModelLine(str):
    """
    A line that represents the model name in the 'txt' format.
    """
    prefix: Final[str] = 'MODEL:'

    def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
        if not string.startswith(cls.prefix):
            raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")
        instance = super().__new__(cls, string)
        return instance

    def model(self) -> str:
        """ Return the model name without prefix and surrounding whitespace. """
        return self[len(self.prefix):].strip()

    @classmethod
    def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
        """ Build a ModelLine ('<prefix> <model>') from the given model name. """
        return cls(' '.join([cls.prefix, model]))


class Answer(str):
    """
    A single answer with a defined header.

    The content may be given as a string or as a string generator (e.g. a
    streamed AI response). The content is consumed lazily and collected in
    an internal buffer; use 'str()' to get the full content.
    NOTE: the underlying 'str' value of the instance is whatever
    'str.__new__' produced from the constructor argument (for a generator
    that is its repr), so str methods like 'len()' or 'in' operate on that
    value and NOT on the content. Always go through 'str(answer)'.
    """
    tokens: int = 0  # tokens used by this answer
    txt_header: ClassVar[str] = '==== ANSWER ===='
    yaml_key: ClassVar[str] = 'answer'

    def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
        # Indicator of whether all of data has been processed
        self.is_exhausted: bool = False
        # Initialize data
        self.iterator: Iterator[str] = self._init_data(data)
        # Set up the buffer to hold the 'Answer' content
        self.buffer: io.StringIO = io.StringIO()

    def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
        """
        Process input data (either a string or a string generator)
        """
        if isinstance(data, str):
            yield data
        else:
            yield from data

    def __str__(self) -> str:
        """
        Output all content when converted into a string
        """
        # Ensure all data has been processed
        for _ in self:
            pass
        # Return the 'Answer' content
        return self.buffer.getvalue()

    def __repr__(self) -> str:
        return repr(str(self))

    def __iter__(self) -> Generator[str, None, None]:
        """
        Allows the object to be iterable. Yields the remaining pieces while
        data is still being consumed, afterwards the whole content at once.
        """
        # Generate content if not all data has been processed
        if not self.is_exhausted:
            yield from self.generator_iter()
        else:
            yield self.buffer.getvalue()

    def generator_iter(self) -> Generator[str, None, None]:
        """
        Main generator method to process data.

        Raises MessageError once all data has been consumed if the content
        contains the answer header at the start of a line.
        """
        for piece in self.iterator:
            # Write to buffer and yield piece for the iterator
            self.buffer.write(piece)
            yield piece
        self.is_exhausted = True  # Set the flag that all data has been processed
        # If the header occurs in the 'Answer' content, raise an error
        if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
            raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")

    def __eq__(self, other: object) -> bool:
        """
        Comparing the object to a string or another object
        """
        if isinstance(other, str):
            return str(self) == other  # Compare the string value of this object to the other string
        # Default behavior for comparing non-string objects
        return super().__eq__(other)

    def __hash__(self) -> int:
        """
        Generate a hash for the object based on its string representation.
        """
        return hash(str(self))

    def __format__(self, format_spec: str) -> str:
        """
        Return a formatted version of the string as per the format specification.
        """
        return str(self).__format__(format_spec)

    def __deepcopy__(self: AnswerInst, memo: dict[int, Any]) -> AnswerInst:
        """
        Return an independent copy holding the fully materialized content.

        The default deepcopy fails because every instance holds a generator
        ('self.iterator'), which cannot be copied. This is required e.g. by
        'dataclasses.asdict()' (used in 'Message.as_dict()').
        """
        copied = type(self)(str(self))
        copied.tokens = self.tokens
        memo[id(self)] = copied
        return copied

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build Answer from a list of strings (joined with newlines).
        The strings must not contain the header; this is checked lazily
        when the content is consumed (see 'generator_iter()').
        """
        def _gen() -> Generator[str, None, None]:
            if len(strings) > 0:
                yield strings[0]
            for s in strings[1:]:
                yield '\n'
                yield s
        return cls(_gen())

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(str(self), include_delims)


class Question(str):
    """
    A single question with a defined header.
    """
    tokens: int = 0  # tokens used by this question
    txt_header: ClassVar[str] = '=== QUESTION ==='
    yaml_key: ClassVar[str] = 'question'

    def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
        """
        Make sure the question string does not contain the header as a whole line
        (also not that from 'Answer', so it's always clear where the answer starts).
        """
        string_lines = string.split('\n')
        if cls.txt_header in string_lines:
            raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
        if Answer.txt_header in string_lines:
            raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
        """
        Build Question from a list of strings (joined with newlines, stripped).
        Like '__new__()', refuse strings containing the question or the
        answer header, so it's always clear where the answer starts.
        """
        if cls.txt_header in strings:
            raise MessageError(f"Question contains the header '{cls.txt_header}'")
        if Answer.txt_header in strings:
            raise MessageError(f"Question contains the header '{Answer.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)


@dataclass
class Message():
    """
    Single message. Consists of a question and optionally an answer, a set of tags
    and a file path.
    """
    question: Question
    answer: Optional[Answer] = None
    # metadata, ignored when comparing messages
    tags: Optional[set[Tag]] = field(default=None, compare=False)
    ai: Optional[str] = field(default=None, compare=False)
    model: Optional[str] = field(default=None, compare=False)
    file_path: Optional[pathlib.Path] = field(default=None, compare=False)
    # class variables
    file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
    file_suffix_write: ClassVar[str] = '.msg'
    default_format: ClassVar[MessageFormat] = message_default_format
    tags_yaml_key: ClassVar[str] = 'tags'
    file_yaml_key: ClassVar[str] = 'file_path'
    ai_yaml_key: ClassVar[str] = 'ai'
    model_yaml_key: ClassVar[str] = 'model'

    def __post_init__(self) -> None:
        # convert some types that are often set wrong
        if self.tags is not None and not isinstance(self.tags, set):
            self.tags = set(self.tags)
        if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
            self.file_path = pathlib.Path(self.file_path)

    def __hash__(self) -> int:
        """
        The hash value is computed based on immutable members.
        """
        return hash((self.question, self.answer))

    def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
               model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
        """
        Compare this message with another one, including the metadata.
        Each metadata comparison can be disabled via the corresponding flag.
        Return True if everything is identical, False otherwise.
        If 'verbose' is True, both messages are printed when they differ.
        """
        equal: bool = ((not tags or (self.tags == other.tags))
                       and (not ai or (self.ai == other.ai))  # noqa: W503
                       and (not model or (self.model == other.model))  # noqa: W503
                       and (not file_path or (self.file_path == other.file_path))  # noqa: W503
                       and (self == other))  # noqa: W503
        if not equal and verbose:
            print("Messages not equal:")
            print(self)
            print(other)
        return equal

    @classmethod
    def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
        """
        Create a Message from the given dict.
        Plain strings for question and answer are converted to 'Question' and
        'Answer' (e.g. when loaded from YAML), so methods like 'tokens()' work.
        Raises MessageError if the question contains a header line.
        """
        question = data[Question.yaml_key]
        if isinstance(question, str) and not isinstance(question, Question):
            question = Question(question)
        answer = data.get(Answer.yaml_key, None)
        if isinstance(answer, str) and not isinstance(answer, Answer):
            answer = Answer(answer)
        return cls(question=question,
                   answer=answer,
                   tags=set(data.get(cls.tags_yaml_key, [])),
                   ai=data.get(cls.ai_yaml_key, None),
                   model=data.get(cls.model_yaml_key, None),
                   file_path=data.get(cls.file_yaml_key, None))

    @classmethod
    def tags_from_file(cls: Type[MessageInst],
                       file_path: pathlib.Path,
                       prefix: Optional[str] = None,
                       contain: Optional[str] = None) -> set[Tag]:
        """
        Return only the tags from the given Message file,
        optionally filtered based on prefix or contained string.
        Raises MessageError if the file does not exist or has an unsupported
        suffix. If the content is not a valid message, the error is printed
        and an empty set is returned.
        """
        tags: set[Tag] = set()
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes_read:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        try:
            message = cls.from_file(file_path)
            if message:
                tags = message.filter_tags(prefix=prefix, contain=contain)
        except MessageError as e:
            print(f"Error processing message in '{file_path}': {str(e)}")
        return tags

    @classmethod
    def tags_from_dir(cls: Type[MessageInst],
                      path: pathlib.Path,
                      glob: Optional[str] = None,
                      prefix: Optional[str] = None,
                      contain: Optional[str] = None) -> set[Tag]:
        """
        Return only the tags from message files in the given directory.
        The files can be filtered using 'glob', the tags by using 'prefix'
        and 'contain'. Errors for single files are printed and skipped.
        """
        tags: set[Tag] = set()
        file_iter = path.glob(glob) if glob else path.iterdir()
        for file_path in sorted(file_iter):
            if file_path.is_file():
                try:
                    tags |= cls.tags_from_file(file_path, prefix, contain)
                except MessageError as e:
                    print(f"Error processing message in '{file_path}': {str(e)}")
        return tags

    @classmethod
    def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
                  mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given file. Returns 'None' if the message does
        not fulfill the filter requirements. The file is parsed as TXT first and
        as YAML if that fails. For TXT files, the tag filters are applied while
        parsing, before the Message is created. All filters are applied to the
        resulting Message afterwards.
        Raises MessageError if the file does not exist, has an unsupported
        suffix or does not contain a valid message.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes_read:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")

        # try TXT first
        try:
            message = cls.__from_file_txt(file_path,
                                          mfilter.tags_or if mfilter else None,
                                          mfilter.tags_and if mfilter else None,
                                          mfilter.tags_not if mfilter else None)
        # then YAML
        except MessageError:
            message = cls.__from_file_yaml(file_path)
        if message and (mfilter is None or message.match(mfilter)):
            return message
        else:
            return None

    @classmethod
    def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path,  # noqa: 11
                        tags_or: Optional[set[Tag]] = None,
                        tags_and: Optional[set[Tag]] = None,
                        tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given TXT file. Expects the following file structures:
        For '.txt':
          * TagLine [Optional]
          * AI [Optional]
          * Model [Optional]
          * Question.txt_header
          * Question
          * Answer.txt_header [Optional]
          * Answer [Optional]
        Returns 'None' if the message does not fulfill the tag requirements.
        Raises MessageError if the question header is missing.
        """
        tags: set[Tag] = set()
        question: Question
        answer: Optional[Answer] = None
        ai: Optional[str] = None
        model: Optional[str] = None
        with open(file_path, "r") as fd:
            # TagLine (Optional)
            try:
                pos = fd.tell()
                tags = TagLine(fd.readline()).tags()
            except TagError:
                fd.seek(pos)
            # AILine (Optional)
            try:
                pos = fd.tell()
                ai = AILine(fd.readline()).ai()
            except MessageError:
                fd.seek(pos)
            # ModelLine (Optional)
            try:
                pos = fd.tell()
                model = ModelLine(fd.readline()).model()
            except MessageError:
                fd.seek(pos)
            # Question and Answer
            text = fd.read().strip().split('\n')
            try:
                question_idx = text.index(Question.txt_header) + 1
            except ValueError:
                raise MessageError(f"'{file_path}' does not contain a valid message")
            try:
                answer_idx = text.index(Answer.txt_header)
                question = Question.from_list(text[question_idx:answer_idx])
                answer = Answer.from_list(text[answer_idx + 1:])
            except ValueError:
                question = Question.from_list(text[question_idx:])
            # match tags AFTER reading the whole file
            # -> make sure it's a valid 'txt' file format
            if tags_or or tags_and or tags_not:
                # match with an empty set if the file has no tags
                if not match_tags(tags, tags_or, tags_and, tags_not):
                    return None
        return cls(question, answer, tags, ai, model, file_path)

    @classmethod
    def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
        """
        Create a Message from the given YAML file. Expects the following file structures:
          * Question.yaml_key: single or multiline string
          * Answer.yaml_key: single or multiline string [Optional]
          * Message.tags_yaml_key: list of strings [Optional]
          * Message.ai_yaml_key: str [Optional]
          * Message.model_yaml_key: str [Optional]
        Raises MessageError for any parsing or conversion failure.
        """
        with open(file_path, "r") as fd:
            try:
                # NOTE(review): 'FullLoader' can construct arbitrary Python
                # objects; consider 'yaml.safe_load' if message files may come
                # from untrusted sources.
                data = yaml.load(fd, Loader=yaml.FullLoader)
                data[cls.file_yaml_key] = file_path
                return cls.from_dict(data)
            except Exception as e:
                raise MessageError(f"'{file_path}' does not contain a valid message") from e

    def to_str(self, with_metadata: bool = False, source_code_only: bool = False) -> str:
        """
        Return the current Message as a string.
        If 'source_code_only' is True, only the source code sections of the
        answer (including delimiters) are returned. If 'with_metadata' is True,
        tags, file path, AI and model are prepended.
        """
        output: list[str] = []
        if source_code_only:
            # use the source code from answer only
            if self.answer:
                output.extend(self.answer.source_code(include_delims=True))
            return '\n'.join(output)
        if with_metadata:
            output.append(self.tags_str())
            output.append('FILE: ' + str(self.file_path))
            output.append('AI: ' + str(self.ai))
            output.append('MODEL: ' + str(self.model))
        output.append(Question.txt_header)
        output.append(self.question)
        if self.answer:
            output.append(Answer.txt_header)
            output.append(str(self.answer))
        return '\n'.join(output)

    def to_file(self, file_path: Optional[pathlib.Path] = None, mformat: MessageFormat = message_default_format) -> None:  # noqa: 11
        """
        Write a Message to the given file. Supported message file formats
        are 'txt' and 'yaml'. Suffix is always '.msg'.
        If 'file_path' is given, it replaces 'self.file_path'. A missing suffix
        is added, any other suffix raises MessageError.
        """
        if file_path:
            self.file_path = file_path
        if not self.file_path:
            raise MessageError("Got no valid path to write message")
        if mformat not in message_valid_formats:
            raise MessageError(f"File format '{mformat}' is not supported")
        # check for valid suffix
        # -> add one if it's empty
        # -> refuse old or otherwise unsupported suffixes
        if not self.file_path.suffix:
            self.file_path = self.file_path.with_suffix(self.file_suffix_write)
        elif self.file_path.suffix != self.file_suffix_write:
            raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
        # TXT
        if mformat == 'txt':
            return self.__to_file_txt(self.file_path)
        # YAML
        elif mformat == 'yaml':
            return self.__to_file_yaml(self.file_path)

    @staticmethod
    def _write_via_tempfile(file_path: pathlib.Path, content: str) -> None:
        """
        Write 'content' to a temporary file in the directory of 'file_path'
        and move it to 'file_path' afterwards, so an existing file is only
        replaced once the new content has been written. If writing fails,
        the temporary file is removed and the exception is re-raised.
        """
        with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name,
                                         mode="w", delete=False) as temp_fd:
            temp_file_path = pathlib.Path(temp_fd.name)
            try:
                temp_fd.write(content)
            except BaseException:
                # 'delete=False' -> clean up ourselves to avoid stray files
                temp_fd.close()
                temp_file_path.unlink(missing_ok=True)
                raise
        shutil.move(temp_file_path, file_path)

    def __to_file_txt(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in TXT format.
        Creates the following file structures:
          * TagLine [Optional]
          * AI [Optional]
          * Model [Optional]
          * Question.txt_header
          * Question
          * Answer.txt_header [Optional]
          * Answer [Optional]
        """
        # Build the whole content first: converting the answer may raise
        # MessageError, which must not leave a partially written file behind.
        parts: list[str] = []
        if self.tags:
            parts.append(f'{TagLine.from_set(self.tags)}\n')
        if self.ai:
            parts.append(f'{AILine.from_ai(self.ai)}\n')
        if self.model:
            parts.append(f'{ModelLine.from_model(self.model)}\n')
        parts.append(f'{Question.txt_header}\n{self.question}\n')
        if self.answer:
            parts.append(f'{Answer.txt_header}\n{str(self.answer)}\n')
        self._write_via_tempfile(file_path, ''.join(parts))

    def __to_file_yaml(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in YAML format.
        Creates the following file structures:
          * Question.yaml_key: single or multiline string
          * Answer.yaml_key: single or multiline string [Optional]
          * Message.tags_yaml_key: list of strings [Optional]
          * Message.ai_yaml_key: str [Optional]
          * Message.model_yaml_key: str [Optional]
        """
        data: YamlDict = {Question.yaml_key: str(self.question)}
        if self.answer:
            data[Answer.yaml_key] = str(self.answer)
        if self.ai:
            data[self.ai_yaml_key] = self.ai
        if self.model:
            data[self.model_yaml_key] = self.model
        if self.tags:
            data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
        self._write_via_tempfile(file_path, yaml.dump(data, sort_keys=False))

    def rm_file(self) -> None:
        """
        Delete the message file. Ignore empty file_path and not existing files.
        """
        if self.file_path is not None:
            self.file_path.unlink(missing_ok=True)

    def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Filter tags based on their prefix (i. e. the tag starts with a given string)
        or some contained string. Returns a new set; 'self.tags' is not modified.
        """
        if not self.tags:
            return set()
        res_tags = self.tags.copy()
        if prefix:
            res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
        if contain:
            res_tags -= {tag for tag in res_tags if contain not in tag}
        return res_tags

    def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
        """
        Returns all tags as a string with the TagLine prefix. Optionally filtered
        using 'Message.filter_tags()'.
        """
        if self.tags:
            return str(TagLine.from_set(self.filter_tags(prefix, contain)))
        else:
            return str(TagLine.from_set(set()))

    def match(self, mfilter: MessageFilter) -> bool:  # noqa: 13
        """
        Matches the current Message to the given filter attributes.
        Return True if all attributes match, else False.
        """
        mytags = self.tags or set()
        if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
                and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not))  # noqa: W503
                or (mfilter.ai and (not self.ai or mfilter.ai != self.ai))  # noqa: W503
                or (mfilter.model and (not self.model or mfilter.model != self.model))  # noqa: W503
                or (mfilter.question_contains and mfilter.question_contains not in self.question)  # noqa: W503
                or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer)))  # noqa: W503
                or (mfilter.answer_state == 'available' and not self.answer)  # noqa: W503
                or (mfilter.ai_state == 'available' and not self.ai)  # noqa: W503
                or (mfilter.model_state == 'available' and not self.model)  # noqa: W503
                or (mfilter.answer_state == 'missing' and self.answer)  # noqa: W503
                or (mfilter.ai_state == 'missing' and self.ai)  # noqa: W503
                or (mfilter.model_state == 'missing' and self.model)):  # noqa: W503
            return False
        return True

    def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
        """
        Renames the given tags. The first tuple element is the old name,
        the second one is the new name.
        """
        if self.tags:
            self.tags = rename_tags(self.tags, tags_rename)

    def clear_answer(self) -> None:
        """ Remove the answer from this message. """
        self.answer = None

    def msg_id(self) -> str:
        """
        Returns an ID that is unique throughout all messages in the same (DB) directory.
        Currently this is the file name without suffix. The ID is also used for sorting
        messages. Raises MessageError if the message has no file path.
        """
        if self.file_path:
            return self.file_path.stem
        else:
            raise MessageError("Can't create file ID without a file path")

    def as_dict(self) -> dict[str, Any]:
        """ Return the message as a dict (via 'dataclasses.asdict()', which deep-copies the values). """
        return asdict(self)

    def tokens(self) -> int:
        """
        Returns the nr. of AI language tokens used by this message.
        If unknown, 0 is returned.
        """
        if self.answer:
            return self.question.tokens + self.answer.tokens
        else:
            return self.question.tokens