"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
import pathlib
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .tags import Tag

ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')


class ChatError(Exception):
    pass


def terminal_width() -> int:
    """
    Return the current width of the terminal in columns.
    """
    return shutil.get_terminal_size().columns


def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print the given object(s), wrapping at the current terminal width.
    """
    return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)


def print_paged(text: str) -> None:
    """
    Display the given text through 'pydoc.pager'.
    """
    pager(text)


def read_dir(dir_path: pathlib.Path,
             glob: Optional[str] = None,
             mfilter: Optional[MessageFilter] = None) -> list[Message]:
    """
    Reads the messages from the given folder.
    Parameters:
    * 'dir_path': source directory
    * 'glob': if specified, files will be filtered using 'path.glob()',
              otherwise it uses 'path.iterdir()'.
    * 'mfilter': use with 'Message.from_file()' to filter messages
                 when reading them.
    Files that can't be parsed ('MessageError') are reported on stdout
    and skipped; files for which 'Message.from_file()' returns a falsy
    value are skipped silently.
    """
    messages: list[Message] = []
    file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
    # sort the paths so the reading order is deterministic
    for file_path in sorted(file_iter):
        if file_path.is_file():
            try:
                message = Message.from_file(file_path, mfilter)
                if message:
                    messages.append(message)
            except MessageError as e:
                print(f"Error processing message in '{file_path}': {str(e)}")
    return messages


def _in_dir(file_path: pathlib.Path, dir_path: pathlib.Path) -> bool:
    """
    Return True if 'file_path' is located directly in 'dir_path'.
    If either directory does not exist, they can't be the same
    directory, so False is returned instead of raising FileNotFoundError.
    """
    try:
        return file_path.parent.samefile(dir_path)
    except FileNotFoundError:
        return False


def write_dir(dir_path: pathlib.Path,
              messages: list[Message],
              file_suffix: str,
              next_fid: Callable[[], int]) -> None:
    """
    Write all messages to the given directory.
    If a message has no file_path, a new file name is created from the
    next file ID (zero-padded to 4 digits) and 'file_suffix'.
    If message.file_path exists but is located in another directory, the
    message is written to a file with the same name in 'dir_path'.
    Parameters:
    * 'dir_path': destination directory
    * 'messages': list of messages to write
    * 'file_suffix': suffix for the message files ['.txt'|'.yaml']
    * 'next_fid': callable that returns the next file ID
    """
    for message in messages:
        file_path = message.file_path
        # message has no file_path: create one
        if not file_path:
            fid = next_fid()
            fname = f"{fid:04d}{file_suffix}"
            file_path = dir_path / fname
        # file_path does not point to given directory: modify it
        # (a vanished original directory is treated as "not the same")
        elif not _in_dir(file_path, dir_path):
            file_path = dir_path / file_path.name
        message.to_file(file_path)


@dataclass
class Chat:
    """
    A class containing a complete chat history.
    """

    messages: list[Message]

    def filter(self, mfilter: MessageFilter) -> None:
        """
        Use 'Message.match(mfilter) to remove all messages that
        don't fulfill the filter requirements.
        """
        self.messages = [m for m in self.messages if m.match(mfilter)]

    def sort(self, reverse: bool = False) -> None:
        """
        Sort the messages according to 'Message.msg_id()'.
        If any message raises 'MessageError' for its ID, the order
        is left as it was.
        """
        try:
            # the message may not have an ID if it doesn't have a file_path
            self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
        except MessageError:
            pass

    def clear(self) -> None:
        """
        Delete all messages.
        """
        self.messages = []

    def add_msgs(self, msgs: list[Message]) -> None:
        """
        Add new messages and sort them if possible.
        """
        self.messages += msgs
        self.sort()

    def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Get the tags of all messages, optionally filtered by prefix or substring.
        """
        tags: set[Tag] = set()
        for m in self.messages:
            tags |= m.filter_tags(prefix, contain)
        return tags

    def print(self, dump: bool = False, source_code_only: bool = False,
              with_tags: bool = False, with_file: bool = False,
              paged: bool = True) -> None:
        """
        Print the chat history.
        Parameters:
        * 'dump': pretty-print the raw object instead and return
        * 'source_code_only': only print the source code found in each
                              question (including delimiters)
        * 'with_tags': append the tags of each message
        * 'with_file': append the file path of each message
        * 'paged': show the output in a pager instead of printing it
        """
        if dump:
            pp(self)
            return
        output: list[str] = []
        for message in self.messages:
            if source_code_only:
                output.extend(source_code(message.question, include_delims=True))
                continue
            output.append('-' * terminal_width())
            output.append(Question.txt_header)
            output.append(message.question)
            if message.answer:
                output.append(Answer.txt_header)
                output.append(message.answer)
            if with_tags:
                output.append(message.tags_str())
            if with_file:
                output.append('FILE: ' + str(message.file_path))
        if paged:
            print_paged('\n'.join(output))
        else:
            print(*output, sep='\n')


@dataclass
class ChatDB(Chat):
    """
    A 'Chat' class that is bound to a given directory structure. Supports reading
    and writing messages from / to that structure. Such a structure consists of
    two directories: a 'cache directory', where all messages are temporarily
    stored, and a 'DB' directory, where selected messages can be stored
    persistently.
    """

    default_file_suffix: ClassVar[str] = '.txt'

    # directory for temporary messages
    cache_path: pathlib.Path
    # directory for persistent messages (also holds the '.next' ID file)
    db_path: pathlib.Path
    # a MessageFilter that all messages must match (if given)
    mfilter: Optional[MessageFilter] = None
    # suffix used when creating new message files
    file_suffix: str = default_file_suffix
    # the glob pattern for all messages
    glob: Optional[str] = None

    def __post_init__(self) -> None:
        # contains the latest message ID
        self.next_fname = self.db_path / '.next'

    @classmethod
    def from_dir(cls: Type[ChatDBInst],
                 cache_path: pathlib.Path,
                 db_path: pathlib.Path,
                 glob: Optional[str] = None,
                 mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
        """
        Create a 'ChatDB' instance from the given directory structure.
        Reads all messages from 'db_path' into the local message list.
        Parameters:
        * 'cache_path': path to the directory for temporary messages
        * 'db_path': path to the directory for persistent messages
        * 'glob': if specified, files will be filtered using 'path.glob()',
                  otherwise it uses 'path.iterdir()'.
        * 'mfilter': use with 'Message.from_file()' to filter messages
                     when reading them.
        """
        messages = read_dir(db_path, glob, mfilter)
        return cls(messages, cache_path, db_path, mfilter,
                   cls.default_file_suffix, glob)

    @classmethod
    def from_messages(cls: Type[ChatDBInst],
                      cache_path: pathlib.Path,
                      db_path: pathlib.Path,
                      messages: list[Message],
                      mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
        """
        Create a ChatDB instance from the given message list.
        """
        return cls(messages, cache_path, db_path, mfilter)

    def get_next_fid(self) -> int:
        """
        Return the next file ID and store it in the '.next' file.
        The ID is the stored value plus one. If the file does not exist
        yet, or its content is not a valid integer, numbering starts at 1.
        Other I/O errors (e.g. missing permissions) are propagated instead
        of silently resetting the counter, which could lead to existing
        message files being overwritten.
        """
        try:
            with open(self.next_fname, 'r', encoding='utf-8') as f:
                next_fid = int(f.read()) + 1
        except (FileNotFoundError, ValueError):
            next_fid = 1
        self.set_next_fid(next_fid)
        return next_fid

    def set_next_fid(self, fid: int) -> None:
        """
        Store the given file ID in the '.next' file.
        """
        with open(self.next_fname, 'w', encoding='utf-8') as f:
            f.write(f'{fid}')

    def _merge_from_dir(self, dir_path: pathlib.Path) -> None:
        """
        Read messages from 'dir_path' and merge them into the internal list,
        replacing messages for which 'message_in()' finds a match.
        """
        new_messages = read_dir(dir_path, self.glob, self.mfilter)
        # remove all messages from self.messages that are in the new list
        self.messages = [m for m in self.messages if not message_in(m, new_messages)]
        # copy the messages from the temporary list to self.messages and sort them
        self.messages += new_messages
        self.sort()

    def read_db(self) -> None:
        """
        Reads new messages from the DB directory. New ones are added to the internal list,
        existing ones are kept or replaced. A message is determined as 'existing' if a message
        with the same base filename (i. e. 'file_path.name') is already in the list.
        """
        self._merge_from_dir(self.db_path)

    def read_cache(self) -> None:
        """
        Reads new messages from the cache directory. New ones are added to the internal list,
        existing ones are kept or replaced. A message is determined as 'existing' if a message
        with the same base filename (i. e. 'file_path.name') is already in the list.
        """
        # NOTE: previously read from 'db_path' by mistake
        self._merge_from_dir(self.cache_path)

    def write_db(self) -> None:
        """
        Write all messages to the DB directory. If a message has no file_path, a new one
        will be created. If message.file_path points to another directory, the message is
        written to a file with the same name in the DB directory.
        """
        write_dir(self.db_path, self.messages, self.file_suffix, self.get_next_fid)

    def write_cache(self) -> None:
        """
        Write all messages to the cache directory. If a message has no file_path, a new one
        will be created. If message.file_path points to another directory, the message is
        written to a file with the same name in the cache directory.
        """
        write_dir(self.cache_path, self.messages, self.file_suffix, self.get_next_fid)