"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
import pathlib
from pprint import PrettyPrinter
from dataclasses import dataclass, field
from typing import TypeVar, Type, Optional, ClassVar, Any

from .message import Message, MessageFilter, MessageError

ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDirInst = TypeVar('ChatDirInst', bound='ChatDir')


class ChatError(Exception):
    """Error type for chat-related failures."""
    pass


def terminal_width() -> int:
    """
    Return the width of the terminal in columns, as reported by
    'shutil.get_terminal_size()' (which falls back to 80 columns when the
    size cannot be determined).
    """
    return shutil.get_terminal_size().columns


def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print using a 'PrettyPrinter' whose width matches the current
    terminal width. All arguments are forwarded to 'PrettyPrinter.pprint()'.
    """
    return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)


@dataclass
class Chat:
    """
    A class containing a complete chat history.
    """

    messages: list[Message]

    def filter(self, mfilter: MessageFilter) -> None:
        """
        Use 'Message.match(mfilter)' to remove all messages that don't
        fulfill the filter requirements.
        """
        self.messages = [m for m in self.messages if m.match(mfilter)]

    def print(self, dump: bool = False) -> None:
        """
        Print the chat.

        With 'dump=True' the whole object (including all messages) is
        pretty-printed, sized to the terminal width. Without it, nothing is
        printed yet (see the TODO below).
        """
        if dump:
            pp(self)
            return
        # TODO: regular (non-dump) output is not implemented yet. The draft
        # below indexes messages like dicts and references names that are not
        # defined in this module ('source_code', 'display_source_code').
        # for message in self.messages:
        #     text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
        #     if source_code:
        #         display_source_code(message['content'])
        #         continue
        #     if message['role'] == 'user':
        #         print('-' * terminal_width())
        #     if text_too_long:
        #         print(f"{message['role'].upper()}:")
        #         print(message['content'])
        #     else:
        #         print(f"{message['role'].upper()}: {message['content']}")


@dataclass
class ChatDir(Chat):
    """
    A Chat class that is bound to a given directory. Supports reading and
    writing messages from / to that directory.
    """

    default_file_suffix: ClassVar[str] = '.txt'

    directory: pathlib.Path
    # a MessageFilter that all messages must match (if given)
    mfilter: Optional[MessageFilter] = None
    file_suffix: str = default_file_suffix
    # set containing all file names (bare names, not full paths) of the
    # current messages
    message_files: set[str] = field(default_factory=set)

    @classmethod
    def from_dir(cls: Type[ChatDirInst],
                 path: pathlib.Path,
                 glob: Optional[str] = None,
                 mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
        """
        Create a ChatDir instance from the given directory. If 'glob' is
        specified, files will be filtered using 'path.glob()', otherwise it
        uses 'path.iterdir()'. Messages are created using 'Message.from_file()'
        and the optional MessageFilter.

        Files that fail to parse ('MessageError') are reported on stdout and
        skipped; files for which 'Message.from_file()' returns a falsy value
        are skipped silently.
        """
        messages: list[Message] = []
        message_files: set[str] = set()
        file_iter = path.glob(glob) if glob else path.iterdir()
        for file_path in sorted(file_iter):
            if file_path.is_file():
                try:
                    message = Message.from_file(file_path, mfilter)
                    if message:
                        messages.append(message)
                        message_files.add(file_path.name)
                except MessageError as e:
                    print(f"Error processing message in '{file_path}': {str(e)}")
        return cls(messages, path, mfilter, cls.default_file_suffix, message_files)

    @classmethod
    def from_messages(cls: Type[ChatDirInst],
                      path: pathlib.Path,
                      messages: list[Message],
                      mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
        """
        Create a ChatDir instance from the given message list.
        Note that the next call to 'dump()' will write all files in order
        to synchronize the messages. 'update()' is not supported until after
        the first 'dump()'.
        """
        return cls(messages, path, mfilter)

    def get_next_fid(self) -> int:
        """
        Return the next file ID: the integer stored in the '.next' file of
        the bound directory plus one. Returns 1 if that file cannot be read
        or does not contain an integer.
        """
        next_fname = self.directory / '.next'
        try:
            with open(next_fname, 'r', encoding='utf-8') as f:
                return int(f.read()) + 1
        except (OSError, ValueError):
            # missing / unreadable / corrupt counter file: start from 1
            return 1

    def set_next_fid(self, fid: int) -> None:
        """
        Store 'fid' (the most recently used file ID) in the '.next' file of
        the bound directory, so that 'get_next_fid()' continues after it.
        """
        next_fname = self.directory / '.next'
        with open(next_fname, 'w', encoding='utf-8') as f:
            f.write(f'{fid}')

    def dump(self, force_all: bool = False) -> None:
        """
        Writes all messages to the bound directory.

        If a message has no file_path, it will create a new one. By default,
        only messages that have not been written (or read) before will be
        dumped. Use 'force_all' to force writing all message files.
        """
        # FIXME: write to 'db' subfolder or given folder
        for message in self.messages:
            # skip messages that we have already written (or read); the set
            # holds bare file names, so compare the name, not the full path
            if (not force_all
                    and message.file_path
                    and pathlib.Path(message.file_path).name in self.message_files):
                continue
            file_path = message.file_path
            if not file_path:
                fid = self.get_next_fid()
                file_path = self.directory / f"{fid:04d}{self.file_suffix}"
                self.set_next_fid(fid)
            message.to_file(file_path)
            # remember the written file so subsequent dumps can skip it
            # NOTE(review): skipping a newly created file on the next dump
            # also requires 'Message.to_file()' to record 'file_path' on the
            # message — confirm against the Message implementation.
            self.message_files.add(pathlib.Path(file_path).name)