569 lines
22 KiB
Python
569 lines
22 KiB
Python
"""
|
|
Module implementing message related functions and classes.
|
|
"""
|
|
import pathlib
|
|
import yaml
|
|
import tempfile
|
|
import shutil
|
|
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
|
|
from dataclasses import dataclass, asdict, field
|
|
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
|
|
|
# Type variables bound to the classes defined below. They are used to annotate
# 'cls' and the return type of '__new__' and the alternate-constructor
# classmethods of those classes.
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
# Annotation for the dict that 'Message' dumps to a YAML file.
# NOTE(review): the values actually stored there are plain 'str' and
# 'list[str]' objects - confirm whether this alias should reflect that.
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
|
|
|
|
|
|
class MessageError(Exception):
    """
    Raised for all message related errors (invalid content or invalid files).
    """
|
|
|
|
|
|
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
    """
    YAML representer for strings: strings spanning more than one line
    are dumped in literal block style ('|'), all others in the default style.
    """
    block_style = '|' if len(data.splitlines()) > 1 else None
    return dumper.represent_scalar('tag:yaml.org,2002:str', data, style=block_style)


# register the representer for all 'str' objects dumped via 'yaml.dump()'
yaml.add_representer(str, str_presenter)
|
|
|
|
|
|
def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Collect all source code sections of the given text, i. e. all lines
    enclosed by lines starting with '```' (leading whitespace is ignored).
    With 'include_delims' set to True, the enclosing delimiter lines are
    part of the result, otherwise they are dropped. Each section is returned
    as a single, newline terminated string. The sections are listed in the
    order they appear in the text. A section without a closing delimiter
    is not returned.
    """
    sections: list[str] = []
    current: list[str] = []
    inside = False

    for row in text.split('\n'):
        if not row.strip().startswith('```'):
            # ordinary line: only relevant while a section is open
            if inside:
                current.append(row)
            continue
        # delimiter line: either opens or closes a section
        if include_delims:
            current.append(row)
        if inside:
            sections.append('\n'.join(current) + '\n')
            current = []
        inside = not inside

    return sections
|
|
|
|
|
|
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
    """
    Tell whether 'messages' contains a message stored under the same file
    name as 'message' (only 'Message.file_path.name' is compared, not the
    directory). A message without a file_path is never found.
    """
    if not message.file_path:
        return False
    wanted_name = message.file_path.name
    return any(other.file_path and other.file_path.name == wanted_name
               for other in messages)
|
|
|
|
|
|
@dataclass(kw_only=True)
class MessageFilter:
    """
    Various filters for a Message. All fields are keyword-only and default
    to 'None'. They are evaluated by 'Message.match()'.
    """
    # tag sets handed to 'match_tags()' as its or / and / not arguments
    tags_or: Optional[set[Tag]] = None
    tags_and: Optional[set[Tag]] = None
    tags_not: Optional[set[Tag]] = None
    # the message's AI / model name must be equal to this value
    ai: Optional[str] = None
    model: Optional[str] = None
    # the question / answer must contain this substring
    question_contains: Optional[str] = None
    answer_contains: Optional[str] = None
    # require the answer / AI name / model name to be set ('available')
    # or to be unset ('missing')
    answer_state: Optional[Literal['available', 'missing']] = None
    ai_state: Optional[Literal['available', 'missing']] = None
    model_state: Optional[Literal['available', 'missing']] = None
|
|
|
|
|
|
class AILine(str):
    """
    A line that represents the AI name in a '.txt' file.
    """
    prefix: Final[str] = 'AI:'

    def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
        """
        Create the line, rejecting strings that do not start with the prefix.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'")

    def ai(self) -> str:
        """
        Return the AI name, i. e. the text after the prefix without
        surrounding whitespace.
        """
        name = self[len(self.prefix):]
        return name.strip()

    @classmethod
    def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
        """
        Build the line from an AI name (prefix and name separated by a space).
        """
        return cls(cls.prefix + ' ' + ai)
|
|
|
|
|
|
class ModelLine(str):
    """
    A line that represents the model name in a '.txt' file.
    """
    prefix: Final[str] = 'MODEL:'

    def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
        """
        Create the line, rejecting strings that do not start with the prefix.
        """
        if string.startswith(cls.prefix):
            return super().__new__(cls, string)
        raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")

    def model(self) -> str:
        """
        Return the model name, i. e. the text after the prefix without
        surrounding whitespace.
        """
        name = self[len(self.prefix):]
        return name.strip()

    @classmethod
    def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
        """
        Build the line from a model name (prefix and name separated by a space).
        """
        return cls(cls.prefix + ' ' + model)
|
|
|
|
|
|
class Answer(str):
    """
    A single answer with a defined header.
    """
    tokens: int = 0  # tokens used by this answer
    txt_header: ClassVar[str] = '==== ANSWER ===='
    yaml_key: ClassVar[str] = 'answer'

    def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
        """
        Make sure the answer string does not contain the header as a whole line.
        Raises a MessageError if it does.
        """
        if cls.txt_header in string.split('\n'):
            raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build Answer from a list of strings (one per line). The lines are joined
        with newlines and surrounding whitespace is stripped. Raises a MessageError
        if one of the strings is the header.
        """
        if cls.txt_header in strings:
            # the message names the 'Answer' (it used to say 'Question')
            raise MessageError(f"Answer contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)
|
|
|
|
|
|
class Question(str):
    """
    A single question with a defined header.
    """
    tokens: int = 0  # tokens used by this question
    txt_header: ClassVar[str] = '=== QUESTION ==='
    yaml_key: ClassVar[str] = 'question'

    def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
        """
        Make sure the question string does not contain the header as a whole line
        (also not that from 'Answer', so it's always clear where the answer starts).
        Raises a MessageError if it does.
        """
        string_lines = string.split('\n')
        if cls.txt_header in string_lines:
            raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
        if Answer.txt_header in string_lines:
            raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
        """
        Build Question from a list of strings (one per line). The lines are joined
        with newlines and surrounding whitespace is stripped. Raises a MessageError
        if one of the strings is the question header or the answer header.
        """
        # enforce the same rules as '__new__()': neither header may show up
        # as a whole line of the question
        for header in (cls.txt_header, Answer.txt_header):
            if header in strings:
                raise MessageError(f"Question contains the header '{header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)
|
|
|
|
|
|
@dataclass
class Message():
    """
    Single message. Consists of a question and optionally an answer, a set of tags
    and a file path. Only question and answer are taken into account when messages
    are compared or hashed.
    """
    question: Question
    answer: Optional[Answer] = None
    # metadata, ignored when comparing messages
    tags: Optional[set[Tag]] = field(default=None, compare=False)
    ai: Optional[str] = field(default=None, compare=False)
    model: Optional[str] = field(default=None, compare=False)
    file_path: Optional[pathlib.Path] = field(default=None, compare=False)
    # class variables
    file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
    tags_yaml_key: ClassVar[str] = 'tags'
    file_yaml_key: ClassVar[str] = 'file_path'
    ai_yaml_key: ClassVar[str] = 'ai'
    model_yaml_key: ClassVar[str] = 'model'

    def __hash__(self) -> int:
        """
        The hash value is computed based on immutable members.
        """
        return hash((self.question, self.answer))

    @classmethod
    def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
        """
        Create a Message from the given dict. The question is mandatory (a missing
        key raises a KeyError), all other entries are optional. Plain strings given
        as question or answer are converted to 'Question' / 'Answer', which raises
        a MessageError if they contain one of the TXT headers.
        """
        question = data[Question.yaml_key]
        # other methods rely on the 'Question' / 'Answer' API (e. g. 'tokens',
        # 'source_code()'), so plain strings (e. g. loaded from YAML) are wrapped;
        # existing instances are kept as they are to preserve their attributes
        if isinstance(question, str) and not isinstance(question, Question):
            question = Question(question)
        answer = data.get(Answer.yaml_key, None)
        if isinstance(answer, str) and not isinstance(answer, Answer):
            answer = Answer(answer)
        return cls(question=question,
                   answer=answer,
                   # 'or []' also covers a tags entry that is present but empty (None)
                   tags=set(data.get(cls.tags_yaml_key) or []),
                   ai=data.get(cls.ai_yaml_key, None),
                   model=data.get(cls.model_yaml_key, None),
                   file_path=data.get(cls.file_yaml_key, None))

    @classmethod
    def tags_from_file(cls: Type[MessageInst],
                       file_path: pathlib.Path,
                       prefix: Optional[str] = None,
                       contain: Optional[str] = None) -> set[Tag]:
        """
        Return only the tags from the given Message file,
        optionally filtered based on prefix or contained string.
        Raises a MessageError if the file does not exist or has an
        unsupported suffix. A YAML file that can't be read as a message
        is reported on stdout and yields an empty set.
        """
        tags: set[Tag] = set()
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")
        # for TXT, it's enough to read the TagLine
        if file_path.suffix == '.txt':
            with open(file_path, "r") as fd:
                try:
                    tags = TagLine(fd.readline()).tags(prefix, contain)
                except TagError:
                    pass  # message without tags
        else:  # '.yaml'
            try:
                message = cls.from_file(file_path)
            except MessageError as e:
                # 'tags' stays empty if the message can't be read
                print(f"Error processing message in '{file_path}': {str(e)}")
            else:
                if message:
                    tags = message.filter_tags(prefix=prefix, contain=contain)
        return tags

    @classmethod
    def tags_from_dir(cls: Type[MessageInst],
                      path: pathlib.Path,
                      glob: Optional[str] = None,
                      prefix: Optional[str] = None,
                      contain: Optional[str] = None) -> set[Tag]:

        """
        Return only the tags from message files in the given directory.
        The files can be filtered using 'glob', the tags by using 'prefix'
        and 'contain'. Files that can't be processed are reported on stdout
        and skipped.
        """
        tags: set[Tag] = set()
        file_iter = path.glob(glob) if glob else path.iterdir()
        for file_path in sorted(file_iter):
            if not file_path.is_file():
                continue
            try:
                tags |= cls.tags_from_file(file_path, prefix, contain)
            except MessageError as e:
                print(f"Error processing message in '{file_path}': {str(e)}")
        return tags

    @classmethod
    def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
                  mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given file. Returns 'None' if the message does
        not fulfill the filter requirements. For TXT files, the tags are matched
        before building the whole message. The other filters are applied afterwards.
        Raises a MessageError if the file does not exist, has an unsupported suffix
        or does not contain a valid message.
        """
        if not file_path.exists():
            raise MessageError(f"Message file '{file_path}' does not exist")
        if file_path.suffix not in cls.file_suffixes:
            raise MessageError(f"File type '{file_path.suffix}' is not supported")

        if file_path.suffix == '.txt':
            message = cls.__from_file_txt(file_path,
                                          mfilter.tags_or if mfilter else None,
                                          mfilter.tags_and if mfilter else None,
                                          mfilter.tags_not if mfilter else None)
        else:
            message = cls.__from_file_yaml(file_path)
        if message and (mfilter is None or message.match(mfilter)):
            return message
        else:
            return None

    @classmethod
    def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path,  # noqa: 11
                        tags_or: Optional[set[Tag]] = None,
                        tags_and: Optional[set[Tag]] = None,
                        tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
        """
        Create a Message from the given TXT file. Expects the following file structures:
        For '.txt':
         * TagLine [Optional]
         * AI [Optional]
         * Model [Optional]
         * Question.txt_header
         * Question
         * Answer.txt_header [Optional]
         * Answer [Optional]

        Returns 'None' if the message does not fulfill the tag requirements.
        Raises a MessageError if the file contains no question header.
        """
        tags: set[Tag] = set()
        question: Question
        answer: Optional[Answer] = None
        ai: Optional[str] = None
        model: Optional[str] = None
        with open(file_path, "r") as fd:
            # TagLine (Optional)
            try:
                pos = fd.tell()
                tags = TagLine(fd.readline()).tags()
            except TagError:
                # not a TagLine -> re-read the line as the next element
                fd.seek(pos)
            if tags_or or tags_and or tags_not:
                # match with an empty set if the file has no tags
                if not match_tags(tags, tags_or, tags_and, tags_not):
                    return None
            # AILine (Optional)
            try:
                pos = fd.tell()
                ai = AILine(fd.readline()).ai()
            except MessageError:
                fd.seek(pos)
            # ModelLine (Optional)
            try:
                pos = fd.tell()
                model = ModelLine(fd.readline()).model()
            except MessageError:
                fd.seek(pos)
            # Question and Answer
            text = fd.read().strip().split('\n')
        try:
            question_idx = text.index(Question.txt_header) + 1
        except ValueError as err:
            raise MessageError(f"'{file_path}' does not contain a valid message") from err
        try:
            # the answer starts at the first answer header AFTER the question
            # header; a stray header line in front of the question is ignored
            answer_idx = text.index(Answer.txt_header, question_idx)
        except ValueError:
            # no answer header -> everything after the question header is the question
            question = Question.from_list(text[question_idx:])
        else:
            question = Question.from_list(text[question_idx:answer_idx])
            answer = Answer.from_list(text[answer_idx + 1:])
        return cls(question, answer, tags, ai, model, file_path)

    @classmethod
    def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
        """
        Create a Message from the given YAML file. Expects the following file structures:
        * Question.yaml_key: single or multiline string
        * Answer.yaml_key: single or multiline string [Optional]
        * Message.tags_yaml_key: list of strings [Optional]
        * Message.ai_yaml_key: str [Optional]
        * Message.model_yaml_key: str [Optional]

        Raises a MessageError if the file can't be parsed or contains no question.
        """
        with open(file_path, "r") as fd:
            try:
                # NOTE(security): 'FullLoader' accepts more than plain data types.
                # If message files may come from untrusted sources, consider
                # switching to 'yaml.safe_load()'.
                data = yaml.load(fd, Loader=yaml.FullLoader)
                data[cls.file_yaml_key] = file_path
            except Exception as err:
                raise MessageError(f"'{file_path}' does not contain a valid message") from err
        # report a missing question the same way as any other invalid file
        # (callers only handle 'MessageError')
        if Question.yaml_key not in data:
            raise MessageError(f"'{file_path}' does not contain a valid message")
        return cls.from_dict(data)

    def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
        """
        Return the current Message as a string. If 'source_code_only' is True,
        only the source code sections of the answer (including their delimiters)
        are returned. Otherwise the result consists of the question and (if
        available) the answer, each introduced by its header, optionally preceded
        by the tag line ('with_tags') and the file path ('with_file').
        """
        output: list[str] = []
        if source_code_only:
            # use the source code from answer only
            if self.answer:
                output.extend(self.answer.source_code(include_delims=True))
            return '\n'.join(output)
        if with_tags:
            output.append(self.tags_str())
        if with_file:
            output.append('FILE: ' + str(self.file_path))
        output.append(Question.txt_header)
        output.append(self.question)
        if self.answer:
            output.append(Answer.txt_header)
            output.append(self.answer)
        return '\n'.join(output)

    def __str__(self) -> str:
        return self.to_str(True, True, False)

    def to_file(self, file_path: Optional[pathlib.Path]=None) -> None:  # noqa: 11
        """
        Write a Message to the given file. Type is determined based on the suffix.
        Currently supported suffixes: ['.txt', '.yaml']
        If a file path is given, it becomes the new file path of the message,
        otherwise the message's current file path is used. Raises a MessageError
        if there is no file path at all or the suffix is not supported.
        """
        if file_path:
            self.file_path = file_path
        if not self.file_path:
            raise MessageError("Got no valid path to write message")
        if self.file_path.suffix not in self.file_suffixes:
            raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
        # TXT
        if self.file_path.suffix == '.txt':
            return self.__to_file_txt(self.file_path)
        elif self.file_path.suffix == '.yaml':
            return self.__to_file_yaml(self.file_path)

    def __to_file_txt(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in TXT format.
        Creates the following file structures:
         * TagLine [Optional]
         * AI [Optional]
         * Model [Optional]
         * Question.txt_header
         * Question
         * Answer.txt_header [Optional]
         * Answer [Optional]

        The content is written to a temporary file in the target directory
        which is moved to the final location afterwards.
        """
        temp_file_path: Optional[pathlib.Path] = None
        try:
            with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name,
                                             mode="w", delete=False) as temp_fd:
                temp_file_path = pathlib.Path(temp_fd.name)
                if self.tags:
                    temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
                if self.ai:
                    temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
                if self.model:
                    temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
                temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
                if self.answer:
                    temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
            shutil.move(temp_file_path, file_path)
        except BaseException:
            # don't leave the temporary file behind if writing or moving failed
            if temp_file_path is not None:
                temp_file_path.unlink(missing_ok=True)
            raise

    def __to_file_yaml(self, file_path: pathlib.Path) -> None:
        """
        Write a Message to the given file in YAML format.
        Creates the following file structures:
        * Question.yaml_key: single or multiline string
        * Answer.yaml_key: single or multiline string [Optional]
        * Message.ai_yaml_key: str [Optional]
        * Message.model_yaml_key: str [Optional]
        * Message.tags_yaml_key: sorted list of strings [Optional]

        The content is written to a temporary file in the target directory
        which is moved to the final location afterwards.
        """
        temp_file_path: Optional[pathlib.Path] = None
        try:
            with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name,
                                             mode="w", delete=False) as temp_fd:
                temp_file_path = pathlib.Path(temp_fd.name)
                # plain 'str' objects are dumped, not the 'str' subclasses
                data: YamlDict = {Question.yaml_key: str(self.question)}
                if self.answer:
                    data[Answer.yaml_key] = str(self.answer)
                if self.ai:
                    data[self.ai_yaml_key] = self.ai
                if self.model:
                    data[self.model_yaml_key] = self.model
                if self.tags:
                    data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
                yaml.dump(data, temp_fd, sort_keys=False)
            shutil.move(temp_file_path, file_path)
        except BaseException:
            # don't leave the temporary file behind if writing or moving failed
            if temp_file_path is not None:
                temp_file_path.unlink(missing_ok=True)
            raise

    def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Filter tags based on their prefix (i. e. the tag starts with a given string)
        or some contained string. Returns a new set, the tags of the message are
        not modified. Empty or missing filters are ignored.
        """
        if not self.tags:
            return set()
        return {tag for tag in self.tags
                if (not prefix or tag.startswith(prefix))
                and (not contain or contain in tag)}  # noqa: W503

    def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
        """
        Returns all tags as a string with the TagLine prefix. Optionally filtered
        using 'Message.filter_tags()'.
        """
        # 'filter_tags()' returns an empty set for a message without tags
        return str(TagLine.from_set(self.filter_tags(prefix, contain)))

    def match(self, mfilter: MessageFilter) -> bool:  # noqa: 13
        """
        Matches the current Message to the given filter attributes.
        Return True if all attributes match, else False.
        """
        # tags (a message without tags is matched as an empty set)
        if mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None:
            if not match_tags(self.tags or set(), mfilter.tags_or, mfilter.tags_and, mfilter.tags_not):
                return False
        # exact AI / model name
        if mfilter.ai and (not self.ai or mfilter.ai != self.ai):
            return False
        if mfilter.model and (not self.model or mfilter.model != self.model):
            return False
        # substrings of question / answer
        if mfilter.question_contains and mfilter.question_contains not in self.question:
            return False
        if mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer):
            return False
        # members that are required to be set
        if mfilter.answer_state == 'available' and not self.answer:
            return False
        if mfilter.ai_state == 'available' and not self.ai:
            return False
        if mfilter.model_state == 'available' and not self.model:
            return False
        # members that are required to be unset
        if mfilter.answer_state == 'missing' and self.answer:
            return False
        if mfilter.ai_state == 'missing' and self.ai:
            return False
        if mfilter.model_state == 'missing' and self.model:
            return False
        return True

    def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
        """
        Renames the given tags. The first tuple element is the old name,
        the second one is the new name.
        """
        if self.tags:
            # calls the module level function imported from '.tags'
            self.tags = rename_tags(self.tags, tags_rename)

    def clear_answer(self) -> None:
        """
        Remove the answer from the message.
        """
        self.answer = None

    def msg_id(self) -> str:
        """
        Returns an ID that is unique throughout all messages in the same (DB) directory.
        Currently this is the file name without suffix. The ID is also used for sorting
        messages. Raises a MessageError if the message has no file path.
        """
        if self.file_path:
            return self.file_path.stem
        else:
            raise MessageError("Can't create file ID without a file path")

    def as_dict(self) -> dict[str, Any]:
        """
        Return the dataclass fields of the message as a dict.
        """
        return asdict(self)

    def tokens(self) -> int:
        """
        Returns the nr. of AI language tokens used by this message.
        If unknown, 0 is returned.
        """
        if self.answer:
            return self.question.tokens + self.answer.tokens
        else:
            return self.question.tokens
|