From e1414835c8c2cdc96d9c425b7d585afb3ffbb261 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 08:16:55 +0200 Subject: [PATCH] chat: added functions for finding and deleting messages --- chatmastermind/chat.py | 52 ++++++++++++++++++++++++++++++++---------- tests/test_chat.py | 22 ++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c631dab..4e8fb20 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -2,7 +2,7 @@ Module implementing various chat classes and functions for managing a chat history. """ import shutil -import pathlib +from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass @@ -30,7 +30,7 @@ def print_paged(text: str) -> None: pager(text) -def read_dir(dir_path: pathlib.Path, +def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path, return messages -def make_file_path(dir_path: pathlib.Path, +def make_file_path(dir_path: Path, file_suffix: str, - next_fid: Callable[[], int]) -> pathlib.Path: + next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. @@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path, return dir_path / f"{next_fid():04d}{file_suffix}" -def write_dir(dir_path: pathlib.Path, +def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: @@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) -def clear_dir(dir_path: pathlib.Path, +def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. @@ -139,6 +139,34 @@ class Chat: self.messages += messages self.sort() + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. + """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. @@ -192,8 +220,8 @@ class ChatDB(Chat): default_file_suffix: ClassVar[str] = '.txt' - cache_path: pathlib.Path - db_path: pathlib.Path + cache_path: Path + db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix @@ -209,8 +237,8 @@ class ChatDB(Chat): @classmethod def from_dir(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ @@ -230,8 +258,8 @@ class ChatDB(Chat): @classmethod def from_messages(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index f8302eb..d81a97a 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -62,6 +62,28 @@ class TestChat(CmmTestCase): tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2])