""" Module implementing various chat classes and functions for managing a chat history. """ import shutil from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union from .configuration import default_config_file from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] msg_location = Literal['mem', 'disk', 'cache', 'db', 'all'] class ChatError(Exception): pass def terminal_width() -> int: return shutil.get_terminal_size().columns def pp(*args: Any, **kwargs: Any) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) def print_paged(text: str) -> None: pager(text) def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Reads the messages from the given folder. Parameters: * 'dir_path': source directory * 'glob': if specified, files will be filtered using 'path.glob()', otherwise it uses 'path.iterdir()'. * 'mfilter': use with 'Message.from_file()' to filter messages when reading them. """ messages: list[Message] = [] file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in sorted(file_iter): if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 and file_path.suffix in Message.file_suffixes): # noqa: W503 try: message = Message.from_file(file_path, mfilter) if message: messages.append(message) except MessageError as e: print(f"WARNING: Skipping message in '{file_path}': {str(e)}") return messages def make_file_path(dir_path: Path, file_suffix: str, next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. """ file_path = dir_path / f"{next_fid():04d}{file_suffix}" while file_path.exists(): file_path = dir_path / f"{next_fid():04d}{file_suffix}" return file_path def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: """ Write all messages to the given directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the given directory. Parameters: * 'dir_path': destination directory * 'messages': list of messages to write * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] * 'next_fid': callable that returns the next file ID """ for message in messages: file_path = message.file_path # message has no file_path: create one if not file_path: file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name message.to_file(file_path) def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. """ file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in file_iter: if file_path.is_file() and file_path.suffix in Message.file_suffixes: file_path.unlink(missing_ok=True) @dataclass class Chat: """ A class containing a complete chat history. """ messages: list[Message] def __post_init__(self) -> None: self.validate() def validate(self) -> None: """ Validate this Chat instance. """ def msg_paths(stem: str) -> list[str]: return [str(fp) for fp in file_paths if fp.stem == stem] file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None} file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None] error = False for fp in file_paths: if file_stems.count(fp.stem) > 1: print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}") error = True if error: raise ChatError("Validation failed") def msg_name_matches(self, file_path: Path, name: str) -> bool: """ Return True if the given name matches the given file_path. Matching is True if: * 'name' matches the full 'file_path' * 'name' matches 'file_path.name' (i. e. including the suffix) * 'name' matches 'file_path.stem' (i. e. without a suffix) """ return Path(name) == file_path or name == file_path.name or name == file_path.stem def msg_filter(self, mfilter: MessageFilter) -> None: """ Use 'Message.match(mfilter) to remove all messages that don't fulfill the filter requirements. """ self.messages = [m for m in self.messages if m.match(mfilter)] def msg_sort(self, reverse: bool = False) -> None: """ Sort the messages according to 'Message.msg_id()'. """ try: # the message may not have an ID if it doesn't have a file_path self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) except MessageError: pass def msg_unique_id(self) -> None: """ Remove duplicates from the internal messages, based on the msg_id (i. e. file_path). Messages without a file_path are kept. """ old_msgs = self.messages.copy() self.messages = [] for m in old_msgs: if not message_in(m, self.messages): self.messages.append(m) self.msg_sort() def msg_unique_content(self) -> None: """ Remove duplicates from the internal messages, based on the content (i. e. question + answer). """ self.messages = list(set(self.messages)) self.msg_sort() def msg_clear(self) -> None: """ Delete all messages. """ self.messages = [] def msg_add(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ self.messages += messages self.msg_sort() def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]: """ Return the last added message (according to the file ID) that matches the given filter. When containing messages without a valid file_path, it returns the latest message in the internal list. """ if len(self.messages) > 0: self.msg_sort() for m in reversed(self.messages): if mfilter is None or m.match(mfilter): return m return None def msg_find(self, msg_names: list[str]) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Messages that can't be found are ignored (i. e. the caller should check the result if they require all messages). """ return [m for m in self.messages if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] def msg_remove(self, msg_names: list[str]) -> None: """ Remove the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). """ self.messages = [m for m in self.messages if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] self.msg_sort() def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. """ tags: set[Tag] = set() for m in self.messages: tags |= m.filter_tags(prefix, contain) return set(sorted(tags)) def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: """ Get the frequency of all tags of all messages, optionally filtered by prefix or substring. """ tags: list[Tag] = [] for m in self.messages: tags += [tag for tag in m.filter_tags(prefix, contain)] return {tag: tags.count(tag) for tag in sorted(tags)} def tokens(self) -> int: """ Returns the nr. of AI language tokens used by all messages in this chat. If unknown, 0 is returned. """ return sum(m.tokens() for m in self.messages) def print(self, source_code_only: bool = False, with_tags: bool = False, with_files: bool = False, paged: bool = True) -> None: output: list[str] = [] for message in self.messages: if source_code_only: output.append(message.to_str(source_code_only=True)) continue output.append(message.to_str(with_tags, with_files)) if paged: print_paged('\n'.join(output)) else: print(*output, sep='\n') @dataclass class ChatDB(Chat): """ A 'Chat' class that is bound to a given directory structure. Supports reading and writing messages from / to that structure. Such a structure consists of two directories: a 'cache directory', where all messages are temporarily stored, and a 'DB' directory, where selected messages can be stored persistently. """ default_file_suffix: ClassVar[str] = '.txt' cache_path: Path db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix # the glob pattern for all messages glob: Optional[str] = None def __post_init__(self) -> None: # contains the latest message ID self.next_path = self.db_path / db_next_file # make all paths absolute self.cache_path = self.cache_path.absolute() self.db_path = self.db_path.absolute() self.validate() @classmethod def from_dir(cls: Type[ChatDBInst], cache_path: Path, db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ Create a 'ChatDB' instance from the given directory structure. Reads all messages from 'db_path' into the local message list. Parameters: * 'cache_path': path to the directory for temporary messages * 'db_path': path to the directory for persistent messages * 'glob': if specified, files will be filtered using 'path.glob()', otherwise it uses 'path.iterdir()'. * 'mfilter': use with 'Message.from_file()' to filter messages when reading them. """ messages = read_dir(db_path, glob, mfilter) return cls(messages, cache_path, db_path, mfilter, cls.default_file_suffix, glob) @classmethod def from_messages(cls: Type[ChatDBInst], cache_path: Path, db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ Create a ChatDB instance from the given message list. """ return cls(messages, cache_path, db_path, mfilter) def get_next_fid(self) -> int: try: with open(self.next_path, 'r') as f: next_fid = int(f.read()) + 1 self.set_next_fid(next_fid) return next_fid except Exception: self.set_next_fid(1) return 1 def set_next_fid(self, fid: int) -> None: with open(self.next_path, 'w') as f: f.write(f'{fid}') def msg_write(self, messages: Optional[list[Message]] = None) -> None: """ Write either the given messages or the internal ones to their CURRENT file_path. If messages are given, they all must have a valid file_path. When writing the internal messages, the ones with a valid file_path are written, the others are ignored. """ if messages and any(m.file_path is None for m in messages): raise ChatError("Can't write files without a valid file_path") msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() def msg_update(self, messages: list[Message], write: bool = True) -> None: """ Update EXISTING messages. A message is determined as 'existing' if a message with the same base filename (i. e. 'file_path.name') is already in the list. Only accepts existing messages. """ if any(not message_in(m, self.messages) for m in messages): raise ChatError("Can't update messages that are not in the internal list") # remove old versions and add new ones self.messages = [m for m in self.messages if not message_in(m, messages)] self.messages += messages self.msg_sort() # write the UPDATED messages if requested if write: self.msg_write(messages) def msg_gather(self, loc: msg_location, require_file_path: bool = False, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Gather and return messages from the given locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') If 'require_file_path' is True, return only files with a valid file_path. """ loc_messages: list[Message] = [] if loc in ['mem', 'all']: if require_file_path: loc_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))] else: loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] if loc in ['cache', 'disk', 'all']: loc_messages += read_dir(self.cache_path, mfilter=mfilter) if loc in ['db', 'disk', 'all']: loc_messages += read_dir(self.db_path, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] for m in loc_messages: if not message_in(m, unique_messages): unique_messages.append(m) try: unique_messages.sort(key=lambda m: m.msg_id()) # messages in 'mem' can have an empty file_path except MessageError: pass return unique_messages def msg_find(self, msg_names: list[str], loc: msg_location = 'mem', ) -> list[Message]: """ Search and return the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Messages that can't be found are ignored (i. e. the caller should check the result if they require all messages). Searches one of the following locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') """ loc_messages = self.msg_gather(loc, require_file_path=True) return [m for m in loc_messages if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)] def msg_remove(self, msg_names: list[str], loc: msg_location = 'mem') -> None: """ Remove the messages with the given names. Names can either be filenames (with or without suffix), full paths or Message.msg_id(). Also deletes the files of all given messages with a valid file_path. Delete files from one of the following locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') """ if loc != 'mem': # delete the message files first rm_messages = self.msg_find(msg_names, loc=loc) for m in rm_messages: if (m.file_path): m.file_path.unlink() # then remove them from the internal list super().msg_remove(msg_names) def msg_latest(self, mfilter: Optional[MessageFilter] = None, loc: msg_location = 'mem') -> Optional[Message]: """ Return the last added message (according to the file ID) that matches the given filter. Only consider messages with a valid file_path (except if loc is 'mem'). Searches one of the following locations: * 'mem' : messages currently in memory * 'disk' : messages on disk (cache + DB directory), but not in memory * 'cache': messages in the cache directory * 'db' : messages in the DB directory * 'all' : all messages ('mem' + 'disk') """ # only consider messages with a valid file_path so they can be sorted loc_messages = self.msg_gather(loc, require_file_path=True) loc_messages.sort(key=lambda m: m.msg_id(), reverse=True) for m in loc_messages: if mfilter is None or m.match(mfilter): return m return None def msg_in_cache(self, message: Union[Message, str]) -> bool: """ Return true if the given Message (or filename or Message.msg_id()) is located in the cache directory. False otherwise. """ if isinstance(message, Message): return (message.file_path is not None and message.file_path.parent.samefile(self.cache_path) # noqa: W503 and message.file_path.exists()) # noqa: W503 else: return len(self.msg_find([message], loc='cache')) > 0 def msg_in_db(self, message: Union[Message, str]) -> bool: """ Return true if the given Message (or filename or Message.msg_id()) is located in the DB directory. False otherwise. """ if isinstance(message, Message): return (message.file_path is not None and message.file_path.parent.samefile(self.db_path) # noqa: W503 and message.file_path.exists()) # noqa: W503 else: return len(self.msg_find([message], loc='db')) > 0 def cache_read(self) -> None: """ Read messages from the cache directory. New ones are added to the internal list, existing ones are replaced. A message is determined as 'existing' if a message with the same base filename (i. e. 'file_path.name') is already in the list. """ new_messages = read_dir(self.cache_path, self.glob, self.mfilter) # remove all messages from self.messages that are in the new list self.messages = [m for m in self.messages if not message_in(m, new_messages)] # copy the messages from the temporary list to self.messages and sort them self.messages += new_messages self.msg_sort() def cache_write(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. Does NOT add the messages to the internal list (use 'cache_add()' for that)! """ write_dir(self.cache_path, messages if messages else self.messages, self.file_suffix, self.get_next_fid) def cache_add(self, messages: list[Message], write: bool = True) -> None: """ Add NEW messages and set the file_path to the cache directory. Only accepts messages without a file_path. """ if any(m.file_path is not None for m in messages): raise ChatError("Can't add new messages with existing file_path") if write: write_dir(self.cache_path, messages, self.file_suffix, self.get_next_fid) else: for m in messages: m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) self.messages += messages self.msg_sort() def cache_clear(self) -> None: """ Delete all message files from the cache dir and remove them from the internal list. """ clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] def cache_move(self, message: Message) -> None: """ Moves the given messages to the cache directory. """ # remember the old path (if any) old_path: Optional[Path] = None if message.file_path: old_path = message.file_path # write message to the new destination self.cache_write([message]) # remove the old one (if any) if old_path: self.msg_remove([str(old_path)], loc='db') # (re)add it to the internal list self.msg_add([message]) def db_read(self) -> None: """ Read messages from the DB directory. New ones are added to the internal list, existing ones are replaced. A message is determined as 'existing' if a message with the same base filename (i. e. 'file_path.name') is already in the list. """ new_messages = read_dir(self.db_path, self.glob, self.mfilter) # remove all messages from self.messages that are in the new list self.messages = [m for m in self.messages if not message_in(m, new_messages)] # copy the messages from the temporary list to self.messages and sort them self.messages += new_messages self.msg_sort() def db_write(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. Does NOT add the messages to the internal list (use 'db_add()' for that)! """ write_dir(self.db_path, messages if messages else self.messages, self.file_suffix, self.get_next_fid) def db_add(self, messages: list[Message], write: bool = True) -> None: """ Add NEW messages and set the file_path to the DB directory. Only accepts messages without a file_path. """ if any(m.file_path is not None for m in messages): raise ChatError("Can't add new messages with existing file_path") if write: write_dir(self.db_path, messages, self.file_suffix, self.get_next_fid) else: for m in messages: m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) self.messages += messages self.msg_sort() def db_move(self, message: Message) -> None: """ Moves the given messages to the db directory. """ # remember the old path (if any) old_path: Optional[Path] = None if message.file_path: old_path = message.file_path # write message to the new destination self.db_write([message]) # remove the old one (if any) if old_path: self.msg_remove([str(old_path)], loc='cache') # (re)add it to the internal list self.msg_add([message])