""" Module implementing various chat classes and functions for managing a chat history. """ import shutil from pprint import PrettyPrinter import pathlib from dataclasses import dataclass, field from typing import TypeVar, Type, Optional, ClassVar, Any from .message import Message, MessageFilter, MessageError ChatInst = TypeVar('ChatInst', bound='Chat') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') class ChatError(Exception): pass def terminal_width() -> int: return shutil.get_terminal_size().columns def pp(*args: Any, **kwargs: Any) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) @dataclass class Chat: """ A class containing a complete chat history. """ messages: list[Message] def filter(self, mfilter: MessageFilter) -> None: """ Use 'Message.match(mfilter) to remove all messages that don't fulfill the filter requirements. """ self.messages = [m for m in self.messages if m.match(mfilter)] def sort(self, reverse: bool = False) -> None: """ Sort the messages according to 'Message.msg_id()'. """ try: # the message may not have an ID if it doesn't have a file_path self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) except MessageError: pass def clear(self) -> None: """ Delete all messages. """ self.messages = [] def add_msgs(self, msgs: list[Message]) -> None: """ Add new messages and sort them if possible. """ self.messages += msgs self.sort() def print(self, dump: bool = False) -> None: if dump: pp(self) return # for message in self.messages: # text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 # if source_code: # display_source_code(message['content']) # continue # if message['role'] == 'user': # print('-' * terminal_width()) # if text_too_long: # print(f"{message['role'].upper()}:") # print(message['content']) # else: # print(f"{message['role'].upper()}: {message['content']}") @dataclass class ChatDB(Chat): """ A 'Chat' class that is bound to a given directory structure. Supports reading and writing messages from / to that structure. Such a structure consists of two directories: a 'cache directory', where all messages are temporarily stored, and a 'DB' directory, where selected messages can be stored persistently. """ default_file_suffix: ClassVar[str] = '.txt' cache_path: pathlib.Path db_path: pathlib.Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix # the glob pattern for all messages glob: Optional[str] = None # set containing all file names of the current messages message_files: set[str] = field(default_factory=set, repr=False) @classmethod def from_dir(cls: Type[ChatDBInst], cache_path: pathlib.Path, db_path: pathlib.Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ Create a 'ChatDB' instance from the given directory structure. Reads all messages from 'db_path' into the local message list. Parameters: * 'cache_path': path to the directory for temporary messages * 'db_path': path to the directory for persistent messages * 'glob' fs specified, files will be filtered using 'path.glob()', otherwise it uses 'path.iterdir()'. * 'mfilter': use with 'Message.from_file()' to filter messages when reading them. """ messages: list[Message] = [] message_files: set[str] = set() file_iter = db_path.glob(glob) if glob else db_path.iterdir() for file_path in sorted(file_iter): if file_path.is_file(): try: message = Message.from_file(file_path, mfilter) if message: messages.append(message) message_files.add(file_path.name) except MessageError as e: print(f"Error processing message in '{file_path}': {str(e)}") return cls(messages, cache_path, db_path, mfilter, cls.default_file_suffix, glob, message_files) @classmethod def from_messages(cls: Type[ChatDBInst], cache_path: pathlib.Path, db_path: pathlib.Path, messages: list[Message], mfilter: Optional[MessageFilter]) -> ChatDBInst: """ Create a ChatDB instance from the given message list. Note that the next call to 'dump()' will write all files in order to synchronize the messages. Similarly, 'update()' will read all messages, so you may end up with a lot of duplicates when using 'update()' first. """ return cls(messages, cache_path, db_path, mfilter) def get_next_fid(self) -> int: next_fname = self.db_path / '.next' try: with open(next_fname, 'r') as f: return int(f.read()) + 1 except Exception: return 1 def set_next_fid(self, fid: int) -> None: next_fname = self.db_path / '.next' with open(next_fname, 'w') as f: f.write(f'{fid}') def dump(self, to_db: bool = False, force_all: bool = False) -> None: """ Write all messages to 'cache_path' (or 'db_path' if 'to_db' is True). If a message has no file_path, a new one will be created. By default, only messages that have not been written (or read) before will be dumped. Use 'force_all' to force writing all message files. """ for message in self.messages: # skip messages that we have already written (or read) if message.file_path and message.file_path in self.message_files and not force_all: continue file_path = message.file_path if not file_path: fid = self.get_next_fid() fname = f"{fid:04d}{self.file_suffix}" file_path = self.db_path / fname if to_db else self.cache_path / fname self.set_next_fid(fid) message.to_file(file_path) def update(self, from_cache: bool = False, force_all: bool = False) -> None: """ Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true). By default, only messages that have not been read (or written) before will be read. Use 'force_all' to force reading all messages. """ if from_cache: file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() else: file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() for file_path in sorted(file_iter): if file_path.is_file(): if file_path.name in self.message_files and not force_all: continue try: message = Message.from_file(file_path, self.mfilter) if message: self.messages.append(message) self.message_files.add(file_path.name) except MessageError as e: print(f"Error processing message in '{file_path}': {str(e)}") self.sort()