From 30027bcfa1b1d96e2e04cc960be4cf9bebd40dd3 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH] chat: new possibilites for adding messages --- chatmastermind/chat.py | 56 +++++++++++++++++++++++++++++++++++------- tests/test_chat.py | 52 +++++++++++++++++++++++++++++++-------- 2 files changed, 89 insertions(+), 19 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..de1850d 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. """ write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,33 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], do_write: bool = True) -> None: + """ + Adds the given messages and sets the file_path to the DB directory. + """ + if do_write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], do_write: bool = True) -> None: + """ + Adds the given messages and sets the file_path to the cache directory. + """ + if do_write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..e91a8f9 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -333,7 +333,7 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() @@ -345,3 +345,35 @@ class TestChatDB(CmmTestCase): self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*')) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 1) + self.assertIsNotNone(chat_db.messages[4].file_path) + self.assertEqual(chat_db.messages[4].file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*.[ty]*')) + self.assertEqual(len(db_dir_files), 5) + self.assertIsNotNone(chat_db.messages[5].file_path) + self.assertEqual(chat_db.messages[5].file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + + next_fname.unlink()