From b4d1edcc73dec3461f495b4ee0d6c5177bbb2815 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH] chat: new possibilites for adding messages and better tests --- chatmastermind/chat.py | 75 ++++++++++++++++++++++++++++++----- tests/test_chat.py | 88 +++++++++++++++++++++++++++++++----------- 2 files changed, 132 insertions(+), 31 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..7e6df8f 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. """ write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,52 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..87a42dc 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -5,7 +5,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from .test_main import CmmTestCase @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase): self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. + """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) def tearDown(self) -> None: self.db_path.cleanup() @@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase): def test_chat_db_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.get_next_fid(), 1) - self.assertEqual(chat_db.get_next_fid(), 2) - self.assertEqual(chat_db.get_next_fid(), 3) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) with open(chat_db.next_fname, 'r') as f: - self.assertEqual(f.read(), '3') + self.assertEqual(f.read(), '7') def test_chat_db_write(self) -> None: # create a new ChatDB instance @@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) @@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase): self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) # check the timestamp of the files in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} # overwrite the messages in the db directory time.sleep(0.05) chat_db.write_db() # check if the written files are in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) # now rewrite them to the DB dir and check for modified paths chat_db.write_db() - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -333,15 +344,48 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) # make sure that the DB messages (and the new message) are still there self.assertEqual(len(chat_db.messages), 5) - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + # TODO: add testcase for "ChatDB.write_messages()"