408 lines
15 KiB
Python

"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable

from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
# Type variable bound to 'Chat' (and its subclasses).
ChatInst = TypeVar("ChatInst", bound="Chat")
# Type variable bound to 'ChatDB'; annotates the return type of
# ChatDB's alternate constructors ('from_dir', 'from_messages').
ChatDBInst = TypeVar("ChatDBInst", bound="ChatDB")
class ChatError(Exception):
    """
    Raised by the chat classes of this module when a requested operation on
    messages is not allowed (e.g. adding messages that already have a
    file_path, or updating messages that are not in the internal list).
    """
def terminal_width() -> int:
    """
    Return the width of the terminal in columns, as reported by
    'shutil.get_terminal_size()'.
    """
    columns, _lines = shutil.get_terminal_size()
    return columns
def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print each positional argument, one after another.

    Keyword arguments are passed to the 'PrettyPrinter' constructor
    (e.g. 'indent', 'depth', 'width', 'sort_dicts'). If 'width' is not
    given, the current terminal width is used.
    """
    # 'PrettyPrinter.pprint()' accepts exactly one object and no keywords,
    # so options go to the constructor and every argument is printed separately.
    if 'width' not in kwargs:
        kwargs['width'] = terminal_width()
    printer = PrettyPrinter(**kwargs)
    for obj in args:
        printer.pprint(obj)
def print_paged(text: str) -> None:
    """
    Display 'text' through 'pydoc.pager()'.
    """
    show = pager
    show(text)
def read_dir(dir_path: Path,
             glob: Optional[str] = None,
             mfilter: Optional[MessageFilter] = None) -> list[Message]:
    """
    Load the message files found in a directory.

    Parameters:
    * 'dir_path': directory to read from
    * 'glob': optional pattern; when given, candidates come from
      'dir_path.glob(glob)', otherwise from 'dir_path.iterdir()'
    * 'mfilter': forwarded to 'Message.from_file()'; files for which it
      returns nothing are skipped

    Candidates are visited in sorted order and only regular files whose suffix
    is listed in 'Message.file_suffixes' are read. A file that raises a
    'MessageError' is reported on stdout and skipped.
    """
    candidates = sorted(dir_path.glob(glob) if glob else dir_path.iterdir())
    result: list[Message] = []
    for path in candidates:
        if not path.is_file() or path.suffix not in Message.file_suffixes:
            continue
        try:
            msg = Message.from_file(path, mfilter)
            if msg:
                result.append(msg)
        except MessageError as err:
            print(f"Error processing message in '{path}': {str(err)}")
    return result
def make_file_path(dir_path: Path,
                   file_suffix: str,
                   next_fid: Callable[[], int]) -> Path:
    """
    Build a path for a new message file inside 'dir_path'.

    The file name is the next ID from 'next_fid()', zero-padded to at least
    four digits, followed by 'file_suffix'. IDs whose file already exists are
    skipped by asking 'next_fid()' again.
    """
    while True:
        candidate = dir_path / f"{next_fid():04d}{file_suffix}"
        if not candidate.exists():
            return candidate
def write_dir(dir_path: Path,
              messages: list[Message],
              file_suffix: str,
              next_fid: Callable[[], int]) -> None:
    """
    Write every message in 'messages' into 'dir_path'.

    Parameters:
    * 'dir_path': destination directory
    * 'messages': list of messages to write
    * 'file_suffix': suffix for newly created message files ['.txt'|'.yaml']
    * 'next_fid': callable that returns the next file ID

    A message without a file_path gets a fresh one from 'make_file_path()'.
    A message whose file_path lies in another directory is written under the
    same file name into 'dir_path'. Each message is written with
    'message.to_file(<destination>)'.
    """
    def _destination(current: Optional[Path]) -> Path:
        # no path yet: allocate a new file name in 'dir_path'
        if not current:
            return make_file_path(dir_path, file_suffix, next_fid)
        # already located in 'dir_path': keep it
        if current.parent.samefile(dir_path):
            return current
        # located elsewhere: keep the name, relocate into 'dir_path'
        return dir_path / current.name

    for msg in messages:
        msg.to_file(_destination(msg.file_path))
def clear_dir(dir_path: Path,
              glob: Optional[str] = None) -> None:
    """
    Delete the message files in 'dir_path'.

    Only regular files whose suffix is listed in 'Message.file_suffixes' are
    removed. If 'glob' is given, only files matching 'dir_path.glob(glob)'
    are considered; otherwise all entries of 'dir_path.iterdir()'.
    Files that vanish before deletion are ignored.
    """
    candidates = dir_path.glob(glob) if glob else dir_path.iterdir()
    victims = (p for p in candidates
               if p.is_file() and p.suffix in Message.file_suffixes)
    for victim in victims:
        victim.unlink(missing_ok=True)
@dataclass
class Chat:
    """
    A class containing a complete chat history.
    """
    messages: list[Message]

    @staticmethod
    def _name_matches(message: Message, msg_names: list[str]) -> bool:
        """
        Return True if 'message' has a file_path that equals one of 'msg_names'
        (compared as full paths) or whose base filename equals one of them.
        Messages without a file_path never match.
        """
        file_path = message.file_path
        if not file_path:
            return False
        return any(file_path == Path(mn) or file_path.name == mn for mn in msg_names)

    def filter(self, mfilter: MessageFilter) -> None:
        """
        Use 'Message.match(mfilter)' to remove all messages that
        don't fulfill the filter requirements.
        """
        self.messages = [m for m in self.messages if m.match(mfilter)]

    def sort(self, reverse: bool = False) -> None:
        """
        Sort the messages according to 'Message.msg_id()'.
        Sorting is best effort: if any message can't provide an ID, the
        resulting 'MessageError' is swallowed and no sorting is applied.
        """
        try:
            # the message may not have an ID if it doesn't have a file_path
            self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
        except MessageError:
            pass

    def clear(self) -> None:
        """
        Delete all messages.
        """
        self.messages = []

    def add_messages(self, messages: list[Message]) -> None:
        """
        Add new messages and sort them if possible.
        """
        self.messages += messages
        self.sort()

    def latest_message(self) -> Optional[Message]:
        """
        Returns the last added message (according to the file ID),
        or None if there are no messages.
        """
        if not self.messages:
            return None
        self.sort()
        return self.messages[-1]

    def find_messages(self, msg_names: list[str]) -> list[Message]:
        """
        Search and return the messages with the given names. Names can either be filenames
        (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
        caller should check the result if they require all messages).
        """
        return [m for m in self.messages if self._name_matches(m, msg_names)]

    def remove_messages(self, msg_names: list[str]) -> None:
        """
        Remove the messages with the given names. Names can either be filenames
        (incl. the suffix) or full paths.
        """
        self.messages = [m for m in self.messages if not self._name_matches(m, msg_names)]
        self.sort()

    def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
        """
        Get the set of tags of all messages, optionally filtered by prefix or substring
        (see 'Message.filter_tags()').
        """
        tags: set[Tag] = set()
        for m in self.messages:
            tags |= m.filter_tags(prefix, contain)
        # a set has no order, so sorting before returning would be wasted work
        return tags

    def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
        """
        Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
        The returned dict is ordered by tag (ascending).
        """
        counts: Counter[Tag] = Counter()
        for m in self.messages:
            counts.update(m.filter_tags(prefix, contain))
        # one counting pass instead of 'list.count()' per element (was O(n^2))
        return {tag: counts[tag] for tag in sorted(counts)}

    def tokens(self) -> int:
        """
        Returns the nr. of AI language tokens used by all messages in this chat.
        If unknown, 0 is returned.
        """
        return sum(m.tokens() for m in self.messages)

    def print(self, source_code_only: bool = False,
              with_tags: bool = False, with_files: bool = False,
              paged: bool = True) -> None:
        """
        Print all messages.

        Parameters:
        * 'source_code_only': print only 'message.to_str(source_code_only=True)'
          for each message, without separator lines
        * 'with_tags', 'with_files': passed on to 'message.to_str()'
        * 'paged': show the output through the pager instead of printing it
          directly

        Unless 'source_code_only' is set, each message is followed by a
        horizontal line as wide as the terminal.
        """
        output: list[str] = []
        for message in self.messages:
            if source_code_only:
                output.append(message.to_str(source_code_only=True))
                continue
            output.append(message.to_str(with_tags, with_files))
            output.append('\n' + ('-' * terminal_width()) + '\n')
        if paged:
            print_paged('\n'.join(output))
        else:
            print(*output, sep='\n')
@dataclass
class ChatDB(Chat):
    """
    A 'Chat' class that is bound to a given directory structure. Supports reading
    and writing messages from / to that structure. Such a structure consists of
    two directories: a 'cache directory', where all messages are temporarily
    stored, and a 'DB' directory, where selected messages can be stored
    persistently.
    """
    default_file_suffix: ClassVar[str] = '.txt'
    cache_path: Path
    db_path: Path
    # a MessageFilter that all messages must match (if given)
    mfilter: Optional[MessageFilter] = None
    file_suffix: str = default_file_suffix
    # the glob pattern for all messages
    glob: Optional[str] = None

    def __post_init__(self) -> None:
        # make all paths absolute
        self.cache_path = self.cache_path.absolute()
        self.db_path = self.db_path.absolute()
        # file containing the latest message ID; derived from the absolute
        # 'db_path' so it keeps pointing to the same file if the working
        # directory changes later on
        self.next_fname = self.db_path / '.next'

    @classmethod
    def from_dir(cls: Type[ChatDBInst],
                 cache_path: Path,
                 db_path: Path,
                 glob: Optional[str] = None,
                 mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
        """
        Create a 'ChatDB' instance from the given directory structure.
        Reads all messages from 'db_path' into the local message list.

        Parameters:
        * 'cache_path': path to the directory for temporary messages
        * 'db_path': path to the directory for persistent messages
        * 'glob': if specified, files will be filtered using 'path.glob()',
          otherwise it uses 'path.iterdir()'.
        * 'mfilter': use with 'Message.from_file()' to filter messages
          when reading them.
        """
        messages = read_dir(db_path, glob, mfilter)
        return cls(messages, cache_path, db_path, mfilter,
                   cls.default_file_suffix, glob)

    @classmethod
    def from_messages(cls: Type[ChatDBInst],
                      cache_path: Path,
                      db_path: Path,
                      messages: list[Message],
                      mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
        """
        Create a ChatDB instance from the given message list.
        """
        return cls(messages, cache_path, db_path, mfilter)

    def get_next_fid(self) -> int:
        """
        Return the next file ID and store it in 'next_fname'.
        The next ID is the stored value + 1. If the ID file is missing or its
        content can't be read as an integer, counting restarts at 1.
        """
        try:
            with open(self.next_fname, 'r') as f:
                next_fid = int(f.read()) + 1
        except (OSError, ValueError):
            # no (valid) ID file yet: start from the beginning
            next_fid = 1
        # errors while persisting the ID are not swallowed
        self.set_next_fid(next_fid)
        return next_fid

    def set_next_fid(self, fid: int) -> None:
        """
        Store 'fid' as the latest file ID in 'next_fname'.
        """
        with open(self.next_fname, 'w') as f:
            f.write(f'{fid}')

    def _merge_dir(self, dir_path: Path) -> None:
        """
        Read the messages from 'dir_path' (using 'glob' and 'mfilter') and merge
        them into the internal list: messages with the same base filename are
        replaced, new ones are added. Sorts the result if possible.
        """
        new_messages = read_dir(dir_path, self.glob, self.mfilter)
        # remove all messages from self.messages that are in the new list
        self.messages = [m for m in self.messages if not message_in(m, new_messages)]
        # copy the messages from the temporary list to self.messages and sort them
        self.messages += new_messages
        self.sort()

    def read_db(self) -> None:
        """
        Reads new messages from the DB directory. New ones are added to the internal list,
        existing ones are replaced. A message is determined as 'existing' if a message with
        the same base filename (i. e. 'file_path.name') is already in the list.
        """
        self._merge_dir(self.db_path)

    def read_cache(self) -> None:
        """
        Reads new messages from the cache directory. New ones are added to the internal list,
        existing ones are replaced. A message is determined as 'existing' if a message with
        the same base filename (i. e. 'file_path.name') is already in the list.
        """
        self._merge_dir(self.cache_path)

    def write_db(self, messages: Optional[list[Message]] = None) -> None:
        """
        Write messages to the DB directory. If a message has no file_path, a new one
        will be created. If message.file_path exists, it will be modified to point
        to the DB directory. If 'messages' is None, the internal messages are
        written; an empty list writes nothing.
        """
        write_dir(self.db_path,
                  messages if messages is not None else self.messages,
                  self.file_suffix,
                  self.get_next_fid)

    def write_cache(self, messages: Optional[list[Message]] = None) -> None:
        """
        Write messages to the cache directory. If a message has no file_path, a new one
        will be created. If message.file_path exists, it will be modified to point to
        the cache directory. If 'messages' is None, the internal messages are
        written; an empty list writes nothing.
        """
        write_dir(self.cache_path,
                  messages if messages is not None else self.messages,
                  self.file_suffix,
                  self.get_next_fid)

    def clear_cache(self) -> None:
        """
        Deletes all Message files from the cache dir and removes those messages from
        the internal list.
        """
        clear_dir(self.cache_path, self.glob)
        # only keep messages from DB dir (or those that have not yet been written)
        self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]

    def _add_to_dir(self, dir_path: Path, messages: list[Message], write: bool) -> None:
        """
        Give the new 'messages' a file_path inside 'dir_path' (writing them
        there if 'write' is set), then add them to the internal list.
        Raises 'ChatError' if any message already has a file_path.
        """
        if any(m.file_path is not None for m in messages):
            raise ChatError("Can't add new messages with existing file_path")
        if write:
            write_dir(dir_path,
                      messages,
                      self.file_suffix,
                      self.get_next_fid)
        else:
            # use the instance's suffix, consistent with the 'write' branch
            for m in messages:
                m.file_path = make_file_path(dir_path, self.file_suffix, self.get_next_fid)
        self.messages += messages
        self.sort()

    def add_to_db(self, messages: list[Message], write: bool = True) -> None:
        """
        Add the given new messages and set the file_path to the DB directory.
        Only accepts messages without a file_path.
        """
        self._add_to_dir(self.db_path, messages, write)

    def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
        """
        Add the given new messages and set the file_path to the cache directory.
        Only accepts messages without a file_path.
        """
        self._add_to_dir(self.cache_path, messages, write)

    def write_messages(self, messages: Optional[list[Message]] = None) -> None:
        """
        Write either the given messages or the internal ones to their current file_path.
        If messages are given, they all must have a valid file_path (otherwise a
        'ChatError' is raised). When writing the internal messages, the ones with a
        valid file_path are written, the others are ignored. An empty list writes nothing.
        """
        if messages is not None and any(m.file_path is None for m in messages):
            raise ChatError("Can't write files without a valid file_path")
        for m in (messages if messages is not None else self.messages):
            # internal messages that were never assigned a file_path are skipped
            if m.file_path is not None:
                m.to_file()

    def update_messages(self, messages: list[Message], write: bool = True) -> None:
        """
        Update existing messages. A message is determined as 'existing' if a message with
        the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
        existing messages (otherwise a 'ChatError' is raised).
        """
        if any(not message_in(m, self.messages) for m in messages):
            raise ChatError("Can't update messages that are not in the internal list")
        # remove old versions and add new ones
        self.messages = [m for m in self.messages if not message_in(m, messages)]
        self.messages += messages
        self.sort()
        # write the UPDATED messages if requested
        if write:
            self.write_messages(messages)