diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ad1cece..2640c8b 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -8,7 +8,7 @@ from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from .configuration import default_config_file -from .message import Message, MessageFilter, MessageError, MessageFormat, message_in +from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -345,6 +345,8 @@ class ChatDB(Chat): """ Set message format for writing messages. """ + if mformat not in message_valid_formats: + raise ChatError(f"Message format '{mformat}' is not supported") self.mformat = mformat def msg_write(self, @@ -381,6 +383,7 @@ class ChatDB(Chat): def msg_gather(self, loc: msg_location, require_file_path: bool = False, + glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Gather and return messages from the given locations: @@ -399,9 +402,9 @@ class ChatDB(Chat): else: loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] if loc in ['cache', 'disk', 'all']: - loc_messages += read_dir(self.cache_path, mfilter=mfilter) + loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) if loc in ['db', 'disk', 'all']: - loc_messages += read_dir(self.db_path, mfilter=mfilter) + loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] for m in loc_messages: