added new module 'chat.py'

This commit is contained in:
juk0de 2023-08-24 16:49:54 +02:00
parent 07b8f955da
commit f5b185505e

154
chatmastermind/chat.py Normal file
View File

@ -0,0 +1,154 @@
"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
from pprint import PrettyPrinter
import pathlib
from dataclasses import dataclass, field
from typing import TypeVar, Type, Optional, ClassVar, Any
from .message import Message, MessageFilter, MessageError
# Type variables bound to the chat classes. 'ChatDirInst' is used by the
# alternate constructors of 'ChatDir' ('cls: Type[ChatDirInst]') so that they
# are typed as returning an instance of the class they were called on.
# NOTE(review): 'ChatInst' is not referenced in this module - presumably kept
# for the same purpose on 'Chat'; confirm before removing.
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDirInst = TypeVar('ChatDirInst', bound='ChatDir')
class ChatError(Exception):
    """Exception type for chat related errors."""
def terminal_width() -> int:
    """Return the number of columns reported by 'shutil.get_terminal_size()'."""
    columns, _lines = shutil.get_terminal_size()
    return columns
def pp(*args: Any, **kwargs: Any) -> None:
    """
    Pretty-print using a 'PrettyPrinter' whose line width equals the
    current terminal width. All arguments are forwarded to its 'pprint()'.
    """
    printer = PrettyPrinter(width=terminal_width())
    printer.pprint(*args, **kwargs)
@dataclass
class Chat:
    """
    Container holding a complete chat history as a list of messages.
    """

    messages: list[Message]

    def filter(self, mfilter: MessageFilter) -> None:
        """
        Use 'Message.match(mfilter)' to drop all messages that don't
        fulfill the filter requirements. The remaining messages keep
        their order and are stored in a new list.
        """
        kept: list[Message] = []
        for message in self.messages:
            if message.match(mfilter):
                kept.append(message)
        self.messages = kept

    def print(self, dump: bool = False) -> None:
        """
        Print the chat. If 'dump' is set, the whole instance is
        pretty-printed using 'pp()'. Otherwise nothing is printed yet.
        """
        if not dump:
            # TODO: formatted output is not implemented. The previous
            # (dict based) version printed a separator line of terminal
            # width before each 'user' message, followed by the upper-case
            # role and the content (content on its own line if the text did
            # not fit into the terminal width); source code was displayed
            # via 'display_source_code()'.
            return
        pp(self)
@dataclass
class ChatDir(Chat):
    """
    A 'Chat' class that is bound to a given directory structure. Supports reading
    and writing messages from / to that structure. Such a structure consists of
    two directories: a 'cache directory', where all messages are temporarily
    stored, and a 'DB' directory, where selected messages can be stored
    persistently.
    """

    default_file_suffix: ClassVar[str] = '.txt'
    # name of the file (inside 'db_path') that stores the last used file ID
    next_fid_file: ClassVar[str] = '.next'

    cache_path: pathlib.Path
    db_path: pathlib.Path
    # a MessageFilter that all messages must match (if given)
    mfilter: Optional[MessageFilter] = None
    file_suffix: str = default_file_suffix
    # set containing all file names (not full paths) of the current messages
    message_files: set[str] = field(default_factory=set)

    @classmethod
    def from_dir(cls: Type[ChatDirInst],
                 cache_path: pathlib.Path,
                 db_path: pathlib.Path,
                 glob: Optional[str] = None,
                 mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
        """
        Create a 'ChatDir' instance from the given directory structure.
        Reads all messages from 'db_path' into the local message list.
        Files are processed in sorted order. A file that raises a
        'MessageError' is reported on stdout and skipped.
        Parameters:
        * 'cache_path': path to the directory for temporary messages
        * 'db_path': path to the directory for persistent messages
        * 'glob': if specified, files will be filtered using 'path.glob()',
                  otherwise it uses 'path.iterdir()'.
        * 'mfilter': use with 'Message.from_file()' to filter messages
                     when reading them.
        """
        messages: list[Message] = []
        message_files: set[str] = set()
        file_iter = db_path.glob(glob) if glob else db_path.iterdir()
        for file_path in sorted(file_iter):
            # the ID counter file lives in 'db_path' too, but it is not a message
            if file_path.name == cls.next_fid_file or not file_path.is_file():
                continue
            try:
                message = Message.from_file(file_path, mfilter)
            except MessageError as e:
                print(f"Error processing message in '{file_path}': {str(e)}")
                continue
            if message:
                messages.append(message)
                message_files.add(file_path.name)
        return cls(messages, cache_path, db_path, mfilter, cls.default_file_suffix, message_files)

    @classmethod
    def from_messages(cls: Type[ChatDirInst],
                      cache_path: pathlib.Path,
                      db_path: pathlib.Path,
                      messages: list[Message],
                      mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
        """
        Create a ChatDir instance from the given message list.
        The set of known message files starts out empty, so the next
        call to 'dump()' will write all files in order to synchronize
        the messages. 'update()' is not supported until after the
        first 'dump()'.
        """
        return cls(messages, cache_path, db_path, mfilter)

    def get_next_fid(self) -> int:
        """
        Return the next free file ID: the number stored in the counter
        file inside 'db_path' plus one. Returns 1 if that file is missing,
        can't be read or does not contain an integer.
        """
        next_fname = self.db_path / self.next_fid_file
        try:
            return int(next_fname.read_text()) + 1
        except (OSError, ValueError):
            # no (valid) counter yet -> start counting from 1
            return 1

    def set_next_fid(self, fid: int) -> None:
        """
        Store 'fid' in the counter file inside 'db_path'. Since
        'get_next_fid()' adds one to the stored value, 'fid' is expected
        to be the ID that has been used last.
        """
        next_fname = self.db_path / self.next_fid_file
        next_fname.write_text(f'{fid}')

    def dump(self, to_db: bool = False, force_all: bool = False) -> None:
        """
        Writes all messages to the 'cache_path' or 'db_path'. If a message has no file_path,
        it will create a new one. By default, only messages that have not been written
        (or read) before will be dumped. Use 'force_all' to force writing all message files.
        Parameters:
        * 'to_db': create new message files in 'db_path' instead of 'cache_path'.
                   Messages that already have a file_path are written to that path.
        * 'force_all': also write messages that are already known.
        """
        for message in self.messages:
            file_path = message.file_path
            # 'message_files' contains bare file names, so compare by name
            # (comparing the path itself would never match)
            known = bool(file_path) and pathlib.Path(file_path).name in self.message_files
            # skip messages that we have already written (or read)
            if known and not force_all:
                continue
            if not file_path:
                fid = self.get_next_fid()
                fname = f"{fid:04d}{self.file_suffix}"
                target_dir = self.db_path if to_db else self.cache_path
                file_path = target_dir / fname
                self.set_next_fid(fid)
            message.to_file(file_path)
            # remember the file so that the next 'dump()' can skip it
            # NOTE(review): this only helps for new messages if 'to_file()'
            # stores the path in 'message.file_path' - confirm in 'Message'
            self.message_files.add(pathlib.Path(file_path).name)