diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 17e5c38..63a5e7f 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -6,9 +6,9 @@ from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass -from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union +from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from .configuration import default_config_file -from .message import Message, MessageFilter, MessageError, message_in +from .message import Message, MessageFilter, MessageError, MessageFormat, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -17,6 +17,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') db_next_file = '.next' ignored_files = [db_next_file, default_config_file] msg_location = Literal['mem', 'disk', 'cache', 'db', 'all'] +msg_suffix = Message.file_suffix_write class ChatError(Exception): @@ -52,7 +53,7 @@ def read_dir(dir_path: Path, for file_path in sorted(file_iter): if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 - and file_path.suffix in Message.file_suffixes): # noqa: W503 + and file_path.suffix in Message.file_suffixes_read): # noqa: W503 try: message = Message.from_file(file_path, mfilter) if message: @@ -63,22 +64,20 @@ def read_dir(dir_path: Path, def make_file_path(dir_path: Path, - file_suffix: str, next_fid: Callable[[], int]) -> Path: """ - Create a file_path for the given directory using the - given file_suffix and ID generator function. + Create a file_path for the given directory using the given ID generator function. """ - file_path = dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{msg_suffix}" while file_path.exists(): - file_path = dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{msg_suffix}" return file_path def write_dir(dir_path: Path, messages: list[Message], - file_suffix: str, - next_fid: Callable[[], int]) -> None: + next_fid: Callable[[], int], + mformat: MessageFormat = Message.default_format) -> None: """ Write all messages to the given directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified @@ -86,18 +85,17 @@ def write_dir(dir_path: Path, Parameters: * 'dir_path': destination directory * 'messages': list of messages to write - * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] * 'next_fid': callable that returns the next file ID """ for message in messages: file_path = message.file_path # message has no file_path: create one if not file_path: - file_path = make_file_path(dir_path, file_suffix, next_fid) + file_path = make_file_path(dir_path, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name - message.to_file(file_path) + message.to_file(file_path, mformat=mformat) def clear_dir(dir_path: Path, @@ -109,7 +107,7 @@ def clear_dir(dir_path: Path, for file_path in file_iter: if (file_path.is_file() and file_path.name not in ignored_files # noqa: W503 - and file_path.suffix in Message.file_suffixes): # noqa: W503 + and file_path.suffix in Message.file_suffixes_read): # noqa: W503 file_path.unlink(missing_ok=True) @@ -146,7 +144,7 @@ class Chat: Matching is True if: * 'name' matches the full 'file_path' * 'name' matches 'file_path.name' (i. e. including the suffix) - * 'name' matches 'file_path.stem' (i. e. without a suffix) + * 'name' matches 'file_path.stem' (i. e. without the suffix) """ return Path(name) == file_path or name == file_path.name or name == file_path.stem @@ -281,13 +279,10 @@ class ChatDB(Chat): persistently. """ - default_file_suffix: ClassVar[str] = '.txt' - cache_path: Path db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None - file_suffix: str = default_file_suffix # the glob pattern for all messages glob: Optional[str] = None @@ -317,8 +312,7 @@ class ChatDB(Chat): when reading them. """ messages = read_dir(db_path, glob, mfilter) - return cls(messages, cache_path, db_path, mfilter, - cls.default_file_suffix, glob) + return cls(messages, cache_path, db_path, mfilter, glob) @classmethod def from_messages(cls: Type[ChatDBInst], @@ -345,7 +339,9 @@ class ChatDB(Chat): with open(self.next_path, 'w') as f: f.write(f'{fid}') - def msg_write(self, messages: Optional[list[Message]] = None) -> None: + def msg_write(self, + messages: Optional[list[Message]] = None, + mformat: MessageFormat = Message.default_format) -> None: """ Write either the given messages or the internal ones to their CURRENT file_path. If messages are given, they all must have a valid file_path. When writing the @@ -356,7 +352,7 @@ class ChatDB(Chat): raise ChatError("Can't write files without a valid file_path") msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): - m.to_file() + m.to_file(mformat=mformat) def msg_update(self, messages: list[Message], write: bool = True) -> None: """ @@ -518,7 +514,6 @@ class ChatDB(Chat): """ write_dir(self.cache_path, messages if messages else self.messages, - self.file_suffix, self.get_next_fid) def cache_add(self, messages: list[Message], write: bool = True) -> None: @@ -531,11 +526,10 @@ class ChatDB(Chat): if write: write_dir(self.cache_path, messages, - self.file_suffix, self.get_next_fid) else: for m in messages: - m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + m.file_path = make_file_path(self.cache_path, self.get_next_fid) self.messages += messages self.msg_sort() @@ -585,7 +579,6 @@ class ChatDB(Chat): """ write_dir(self.db_path, messages if messages else self.messages, - self.file_suffix, self.get_next_fid) def db_add(self, messages: list[Message], write: bool = True) -> None: @@ -598,11 +591,10 @@ class ChatDB(Chat): if write: write_dir(self.db_path, messages, - self.file_suffix, self.get_next_fid) else: for m in messages: - m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + m.file_path = make_file_path(self.db_path, self.get_next_fid) self.messages += messages self.msg_sort()