From e4e760b4a8eb44ceca77d6ad9aae1bcafe5bc1d8 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 15 Sep 2023 22:42:11 +0200 Subject: [PATCH] chat: added functions msg_in_cache() and msg_in_db(), also tests --- chatmastermind/chat.py | 26 +++++++++++++++++++++++++- tests/test_chat.py | 19 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index cb4855e..f030c5e 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -6,7 +6,7 @@ from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass -from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union from .configuration import default_config_file from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag @@ -466,6 +466,30 @@ class ChatDB(Chat): return m return None + def msg_in_cache(self, message: Union[Message, str]) -> bool: + """ + Return true if the given Message (or filename or Message.msg_id()) + is located in the cache directory. False otherwise. + """ + if isinstance(message, Message): + return (message.file_path is not None + and message.file_path.parent.samefile(self.cache_path) # noqa: W503 + and message.file_path.exists()) # noqa: W503 + else: + return len(self.msg_find([message], loc='cache')) > 0 + + def msg_in_db(self, message: Union[Message, str]) -> bool: + """ + Return true if the given Message (or filename or Message.msg_id()) + is located in the DB directory. False otherwise. + """ + if isinstance(message, Message): + return (message.file_path is not None + and message.file_path.parent.samefile(self.db_path) # noqa: W503 + and message.file_path.exists()) # noqa: W503 + else: + return len(self.msg_find([message], loc='db')) > 0 + def cache_read(self) -> None: """ Read messages from the cache directory. New ones are added to the internal list, diff --git a/tests/test_chat.py b/tests/test_chat.py index 7a0c94d..3421852 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -284,6 +284,25 @@ class TestChatDB(unittest.TestCase): with open(chat_db.next_path, 'r') as f: self.assertEqual(f.read(), '7') + def test_msg_in_db_or_cache(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertTrue(chat_db.msg_in_db(self.message1)) + self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path))) + self.assertTrue(chat_db.msg_in_db(self.message1.msg_id())) + self.assertFalse(chat_db.msg_in_cache(self.message1)) + self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path))) + self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id())) + # add new message to the cache dir + cache_message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.cache_add([cache_message]) + self.assertTrue(chat_db.msg_in_cache(cache_message)) + self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id())) + self.assertFalse(chat_db.msg_in_db(cache_message)) + self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path))) + def test_db_write(self) -> None: # create a new ChatDB instance chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),