Compare commits

..

8 Commits

3 changed files with 62 additions and 21 deletions

View File

@ -60,13 +60,17 @@ class Chat:
@dataclass @dataclass
class ChatDir(Chat): class ChatDir(Chat):
""" """
A Chat class that is bound to a given directory. Supports reading A 'Chat' class that is bound to a given directory structure. Supports reading
and writing messages from / to that directory. and writing messages from / to that structure. Such a structure consists of
two directories: a 'cache directory', where all messages are temporarily
stored, and a 'DB' directory, where selected messages can be stored
persistently.
""" """
default_file_suffix: ClassVar[str] = '.txt' default_file_suffix: ClassVar[str] = '.txt'
directory: pathlib.Path cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given) # a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix file_suffix: str = default_file_suffix
@ -75,17 +79,24 @@ class ChatDir(Chat):
@classmethod @classmethod
def from_dir(cls: Type[ChatDirInst], def from_dir(cls: Type[ChatDirInst],
path: pathlib.Path, cache_path: pathlib.Path,
db_path: pathlib.Path,
glob: Optional[str] = None, glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDirInst: mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
""" """
Create a ChatDir instance from the given directory. If 'glob' is specified, Create a 'ChatDir' instance from the given directory structure.
files will be filtered using 'path.glob()', otherwise it uses 'path.iterdir()'. Reads all messages from 'db_path' into the local message list.
Messages are created using 'Message.from_file()' and the optional MessageFilter. Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob' fs specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
""" """
messages: list[Message] = [] messages: list[Message] = []
message_files: set[str] = set() message_files: set[str] = set()
file_iter = path.glob(glob) if glob else path.iterdir() file_iter = db_path.glob(glob) if glob else db_path.iterdir()
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if file_path.is_file(): if file_path.is_file():
try: try:
@ -95,11 +106,12 @@ class ChatDir(Chat):
message_files.add(file_path.name) message_files.add(file_path.name)
except MessageError as e: except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}") print(f"Error processing message in '{file_path}': {str(e)}")
return cls(messages, path, mfilter, cls.default_file_suffix, message_files) return cls(messages, cache_path, db_path, mfilter, cls.default_file_suffix, message_files)
@classmethod @classmethod
def from_messages(cls: Type[ChatDirInst], def from_messages(cls: Type[ChatDirInst],
path: pathlib.Path, cache_path: pathlib.Path,
db_path: pathlib.Path,
messages: list[Message], messages: list[Message],
mfilter: Optional[MessageFilter]) -> ChatDirInst: mfilter: Optional[MessageFilter]) -> ChatDirInst:
""" """
@ -108,10 +120,10 @@ class ChatDir(Chat):
in order to synchronize the messages. 'update()' is not in order to synchronize the messages. 'update()' is not
supported until after the first 'dump()'. supported until after the first 'dump()'.
""" """
return cls(messages, path, mfilter) return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int: def get_next_fid(self) -> int:
next_fname = self.directory / '.next' next_fname = self.db_path / '.next'
try: try:
with open(next_fname, 'r') as f: with open(next_fname, 'r') as f:
return int(f.read()) + 1 return int(f.read()) + 1
@ -119,18 +131,16 @@ class ChatDir(Chat):
return 1 return 1
def set_next_fid(self, fid: int) -> None: def set_next_fid(self, fid: int) -> None:
next_fname = self.directory / '.next' next_fname = self.db_path / '.next'
with open(next_fname, 'w') as f: with open(next_fname, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def dump(self, force_all: bool = False) -> None: def dump(self, to_db: bool = False, force_all: bool = False) -> None:
""" """
Writes all messages to the bound directory. If a message has no file_path, Writes all messages to the 'cache_path' or 'db_path'. If a message has no file_path,
it will create a new one. By default, only messages that have not been it will create a new one. By default, only messages that have not been written
written (or read) before will be dumped. Use 'force_all' to force writing (or read) before will be dumped. Use 'force_all' to force writing all message files.
all message files.
""" """
# FIXME: write to 'db' subfolder or given folder
for message in self.messages: for message in self.messages:
# skip messages that we have already written (or read) # skip messages that we have already written (or read)
if message.file_path and message.file_path in self.message_files and not force_all: if message.file_path and message.file_path in self.message_files and not force_all:
@ -138,6 +148,7 @@ class ChatDir(Chat):
file_path = message.file_path file_path = message.file_path
if not file_path: if not file_path:
fid = self.get_next_fid() fid = self.get_next_fid()
file_path = self.directory / f"{fid:04d}{self.file_suffix}" fname = f"{fid:04d}{self.file_suffix}"
file_path = self.db_path / fname if to_db else self.cache_path / fname
self.set_next_fid(fid) self.set_next_fid(fid)
message.to_file(file_path) message.to_file(file_path)

View File

@ -187,7 +187,7 @@ class Message():
and a file path. and a file path.
""" """
question: Question question: Question
answer: Optional[Answer] = None # FIXME: support multiple answers answer: Optional[Answer] = None
tags: Optional[set[Tag]] = None tags: Optional[set[Tag]] = None
ai: Optional[str] = None ai: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None
@ -409,5 +409,15 @@ class Message():
return False return False
return True return True
def file_id(self) -> str:
"""
Returns an ID that is unique within the directory of this message.
Currently this is simply the file name.
"""
if self.file_path:
return self.file_path.name
else:
raise MessageError("Can't create file ID without a file path")
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -575,3 +575,23 @@ This is an answer.
def test_tags_from_file_yaml(self) -> None: def test_tags_from_file_yaml(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml) tags = Message.tags_from_file(self.file_path_yaml)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')})
class MessageFileIDTxtTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message_no_file_path = Message(Question('This is a question.'))
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_file_id_txt(self) -> None:
self.assertEqual(self.message.file_id(), self.file_path.name)
def test_file_id_txt_exception(self) -> None:
with self.assertRaises(MessageError):
self.message_no_file_path.file_id()