Compare commits

..

8 Commits

2 changed files with 15 additions and 67 deletions

View File

@ -1,11 +1,9 @@
""" """
Module implementing various chat classes and functions for managing a chat history. Module implementing various chat classes and functions for managing a chat history.
""" """
import shutil
from pprint import PrettyPrinter
import pathlib import pathlib
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any from typing import TypeVar, Type, Optional
from .message import Message, MessageFilter, MessageError from .message import Message, MessageFilter, MessageError
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
@ -16,14 +14,6 @@ class ChatError(Exception):
pass pass
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
@dataclass @dataclass
class Chat: class Chat:
""" """
@ -39,23 +29,6 @@ class Chat:
""" """
self.messages = [m for m in self.messages if m.match(mfilter)] self.messages = [m for m in self.messages if m.match(mfilter)]
def print(self, dump: bool = False) -> None:
if dump:
pp(self)
return
# for message in self.messages:
# text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
# if source_code:
# display_source_code(message['content'])
# continue
# if message['role'] == 'user':
# print('-' * terminal_width())
# if text_too_long:
# print(f"{message['role'].upper()}:")
# print(message['content'])
# else:
# print(f"{message['role'].upper()}: {message['content']}")
@dataclass @dataclass
class ChatDir(Chat): class ChatDir(Chat):
@ -64,14 +37,11 @@ class ChatDir(Chat):
and writing messages from / to that directory. and writing messages from / to that directory.
""" """
default_file_suffix: ClassVar[str] = '.txt'
directory: pathlib.Path directory: pathlib.Path
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# set containing all file names of the current messages # set containing all file names of the current messages
message_files: set[str] = field(default_factory=set) message_files: set[str] = set()
@classmethod @classmethod
def from_dir(cls: Type[ChatDirInst], def from_dir(cls: Type[ChatDirInst],
@ -86,7 +56,7 @@ class ChatDir(Chat):
messages: list[Message] = [] messages: list[Message] = []
message_files: set[str] = set() message_files: set[str] = set()
file_iter = path.glob(glob) if glob else path.iterdir() file_iter = path.glob(glob) if glob else path.iterdir()
for file_path in sorted(file_iter): for file_path in file_iter:
if file_path.is_file(): if file_path.is_file():
try: try:
message = Message.from_file(file_path, mfilter) message = Message.from_file(file_path, mfilter)
@ -95,7 +65,7 @@ class ChatDir(Chat):
message_files.add(file_path.name) message_files.add(file_path.name)
except MessageError as e: except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}") print(f"Error processing message in '{file_path}': {str(e)}")
return cls(messages, path, mfilter, cls.default_file_suffix, message_files) return cls(messages, path, mfilter, message_files)
@classmethod @classmethod
def from_messages(cls: Type[ChatDirInst], def from_messages(cls: Type[ChatDirInst],
@ -110,34 +80,12 @@ class ChatDir(Chat):
""" """
return cls(messages, path, mfilter) return cls(messages, path, mfilter)
def get_next_fid(self) -> int: # def dump(self) -> None:
next_fname = self.directory / '.next' # """
try: # Writes all messages to the bound directory. If a message has no file_path,
with open(next_fname, 'r') as f: # it will create a new one.
return int(f.read()) + 1 # """
except Exception: # for message in self.messages:
return 1 # # TODO: determine file name if message does not have one
# if message.file_path.name() not in self.message_files:
def set_next_fid(self, fid: int) -> None: # message.to_file()
next_fname = self.directory / '.next'
with open(next_fname, 'w') as f:
f.write(f'{fid}')
def dump(self, force_all: bool = False) -> None:
"""
Writes all messages to the bound directory. If a message has no file_path,
it will create a new one. By default, only messages that have not been
written (or read) before will be dumped. Use 'force_all' to force writing
all message files.
"""
# FIXME: write to 'db' subfolder or given folder
for message in self.messages:
# skip messages that we have already written (or read)
if message.file_path and message.file_path in self.message_files and not force_all:
continue
file_path = message.file_path
if not file_path:
fid = self.get_next_fid()
file_path = self.directory / f"{fid:04d}{self.file_suffix}"
self.set_next_fid(fid)
message.to_file(file_path)

View File

@ -187,7 +187,7 @@ class Message():
and a file path. and a file path.
""" """
question: Question question: Question
answer: Optional[Answer] = None # FIXME: support multiple answers answer: Optional[Answer] = None
tags: Optional[set[Tag]] = None tags: Optional[set[Tag]] = None
ai: Optional[str] = None ai: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None