From 3245690d4d815ee59032f5ac8775d4db1695c00a Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 28 Sep 2023 07:51:56 +0200 Subject: [PATCH] chat: 'msg_gather()' now supports globbing --- chatmastermind/chat.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ad1cece..2640c8b 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -8,7 +8,7 @@ from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union from .configuration import default_config_file -from .message import Message, MessageFilter, MessageError, MessageFormat, message_in +from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -345,6 +345,8 @@ class ChatDB(Chat): """ Set message format for writing messages. """ + if mformat not in message_valid_formats: + raise ChatError(f"Message format '{mformat}' is not supported") self.mformat = mformat def msg_write(self, @@ -381,6 +383,7 @@ class ChatDB(Chat): def msg_gather(self, loc: msg_location, require_file_path: bool = False, + glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ Gather and return messages from the given locations: @@ -399,9 +402,9 @@ class ChatDB(Chat): else: loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))] if loc in ['cache', 'disk', 'all']: - loc_messages += read_dir(self.cache_path, mfilter=mfilter) + loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter) if loc in ['db', 'disk', 'all']: - loc_messages += read_dir(self.db_path, mfilter=mfilter) + loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter) # remove_duplicates and sort the list unique_messages: list[Message] = [] for m in loc_messages: