"""
Unit tests for the 'Chat' and 'ChatDB' classes of 'chatmastermind.chat'.
"""
import unittest
import pathlib
import tempfile
import time
import yaml
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError, msg_location


# suffix used when writing message files (used to build expected file names)
msg_suffix: str = Message.file_suffix_write


def msg_to_file_force_suffix(msg: Message) -> None:
    """
    Force writing a message file with illegal suffixes.

    Temporarily sets 'Message.file_suffix_write' to the suffix of
    'msg.file_path', writes the message and restores the previous
    value afterwards (also if writing the file raises).
    """
    assert msg.file_path
    def_suffix = Message.file_suffix_write
    Message.file_suffix_write = msg.file_path.suffix
    try:
        msg.to_file()
    finally:
        # always restore the class attribute, otherwise a failed write
        # would leak the forced suffix into all following tests
        Message.file_suffix_write = def_suffix


class TestChatBase(unittest.TestCase):
    """
    Common base class providing a message list comparison helper.
    """

    def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using more than just Question and Answer.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))


class TestChat(TestChatBase):
    """
    Tests for the in-memory 'Chat' class.
    """

    def setUp(self) -> None:
        self.chat = Chat([])
        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('atag1'), Tag('btag2')},
                                ai='FakeAI',
                                model='FakeModel',
                                file_path=pathlib.Path(f'0001{msg_suffix}'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('btag2')},
                                ai='FakeAI',
                                model='FakeModel',
                                file_path=pathlib.Path(f'0002{msg_suffix}'))
        self.maxDiff = None

    def test_unique_id(self) -> None:
        # test with two identical messages
        self.chat.msg_add([self.message1, self.message1])
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
        self.chat.msg_unique_id()
        self.assert_messages_equal(self.chat.messages, [self.message1])
        # test with two different messages
        self.chat.msg_add([self.message2])
        self.chat.msg_unique_id()
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])

    def test_unique_content(self) -> None:
        # test with two identical messages
        self.chat.msg_add([self.message1, self.message1])
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
        self.chat.msg_unique_content()
        self.assert_messages_equal(self.chat.messages, [self.message1])
        # test with two different messages
        self.chat.msg_add([self.message2])
        self.chat.msg_unique_content()
        self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])

    def test_filter(self) -> None:
        self.chat.msg_add([self.message1, self.message2])
        self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
        self.assertEqual(len(self.chat.messages), 1)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')

    def test_sort(self) -> None:
        self.chat.msg_add([self.message2, self.message1])
        self.chat.msg_sort()
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')
        self.chat.msg_sort(reverse=True)
        self.assertEqual(self.chat.messages[0].question, 'Question 2')
        self.assertEqual(self.chat.messages[1].question, 'Question 1')

    def test_clear(self) -> None:
        self.chat.msg_add([self.message1])
        self.chat.msg_clear()
        self.assertEqual(len(self.chat.messages), 0)

    def test_add_messages(self) -> None:
        self.chat.msg_add([self.message1, self.message2])
        self.assertEqual(len(self.chat.messages), 2)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')

    def test_tags(self) -> None:
        self.chat.msg_add([self.message1, self.message2])
        tags_all = self.chat.msg_tags()
        self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
        tags_pref = self.chat.msg_tags(prefix='a')
        self.assertSetEqual(tags_pref, {Tag('atag1')})
        tags_cont = self.chat.msg_tags(contain='2')
        self.assertSetEqual(tags_cont, {Tag('btag2')})

    def test_tags_frequency(self) -> None:
        self.chat.msg_add([self.message1, self.message2])
        tags_freq = self.chat.msg_tags_frequency()
        self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})

    def test_find_remove_messages(self) -> None:
        self.chat.msg_add([self.message1, self.message2])
        msgs = self.chat.msg_find(['0001'])
        self.assertListEqual(msgs, [self.message1])
        msgs = self.chat.msg_find(['0001', '0002'])
        self.assertListEqual(msgs, [self.message1, self.message2])
        # add new Message with full path
        message3 = Message(Question('Question 2'),
                           Answer('Answer 2'),
                           {Tag('btag2')},
                           file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
        self.chat.msg_add([message3])
        # find new Message by full path
        msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
        self.assertListEqual(msgs, [message3])
        # find Message with full path only by filename
        msgs = self.chat.msg_find([f'0003{msg_suffix}'])
        self.assertListEqual(msgs, [message3])
        # remove last message
        self.chat.msg_remove(['0003'])
        self.assertListEqual(self.chat.messages, [self.message1, self.message2])

    def test_latest_message(self) -> None:
        self.assertIsNone(self.chat.msg_latest())
        self.chat.msg_add([self.message1])
        self.assertEqual(self.chat.msg_latest(), self.message1)
        self.chat.msg_add([self.message2])
        self.assertEqual(self.chat.msg_latest(), self.message2)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print(self, mock_stdout: StringIO) -> None:
        self.chat.msg_add([self.message1, self.message2])
        self.chat.print(paged=False, tight=True)
        # NOTE(review): expected layout is one field per line without blank
        # lines between the messages - confirm against 'Chat.print(tight=True)'
        expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print_with_metadata(self, mock_stdout: StringIO) -> None:
        self.chat.msg_add([self.message1, self.message2])
        self.chat.print(paged=False, with_metadata=True, tight=True)
        # NOTE(review): expected layout is one metadata field per line without
        # blank lines between the messages - confirm against 'Chat.print()'
        expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002{msg_suffix}
AI: FakeAI
MODEL: FakeModel
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)


class TestChatDB(TestChatBase):
    """
    Tests for the file based 'ChatDB' class (using temporary DB and cache directories).
    """

    def setUp(self) -> None:
        self.db_path = tempfile.TemporaryDirectory()
        self.cache_path = tempfile.TemporaryDirectory()

        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('tag1')})
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('tag2')})
        self.message3 = Message(Question('Question 3'),
                                Answer('Answer 3'),
                                {Tag('tag3')})
        self.message4 = Message(Question('Question 4'),
                                Answer('Answer 4'),
                                {Tag('tag4')})

        self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
        self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
        self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
        self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
        # make the next FID match the current state
        next_fname = pathlib.Path(self.db_path.name) / '.next'
        with open(next_fname, 'w') as f:
            f.write('4')
        # add some "trash" in order to test if it's correctly handled / ignored
        self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
        for file in self.trash_files:
            with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
                f.write('test trash')
        # also create a file with actual yaml content
        with open(pathlib.Path(self.db_path.name) / 'content.yaml', 'w') as f:
            yaml.dump({'key': 'value'}, f)
        self.trash_files.append('content.yaml')
        self.maxDiff = None

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
        """
        List all Message files in the given TemporaryDirectory.
        """
        # exclude '.next'
        return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*')
                if f.name not in self.trash_files]

    def tearDown(self) -> None:
        self.db_path.cleanup()
        self.cache_path.cleanup()

    def test_validate(self) -> None:
        duplicate_message = Message(Question('Question 4'),
                                    Answer('Answer 4'),
                                    {Tag('tag4')},
                                    file_path=pathlib.Path(self.db_path.name, '0004.txt'))
        msg_to_file_force_suffix(duplicate_message)
        with self.assertRaises(ChatError) as cm:
            ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                            pathlib.Path(self.db_path.name),
                            glob='*')
        # must be checked after leaving the 'with' block (the exception
        # is only available once the context manager has caught it)
        self.assertEqual(str(cm.exception), "Validation failed")

    def test_file_path_ID_exists(self) -> None:
        """
        Tests if the ChatDB chooses another ID if a file path with
        the given one exists.
        """
        # create a new and empty ChatDB
        db_path = tempfile.TemporaryDirectory()
        self.addCleanup(db_path.cleanup)
        cache_path = tempfile.TemporaryDirectory()
        self.addCleanup(cache_path.cleanup)
        chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
                                  pathlib.Path(db_path.name))
        # add a message file
        message = Message(Question('What?'),
                          file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
        message.to_file()
        message1 = Message(Question('Where?'))
        chat_db.cache_write([message1])
        self.assertEqual(message1.msg_id(), '0002')

    def test_from_dir(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        # check that the files are sorted
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))

    def test_from_dir_glob(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  glob='*1.*')
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))

    def test_from_dir_filter_tags(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(tags_or={Tag('tag1')}))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))

    def test_from_dir_filter_tags_empty(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(tags_or=set(),
                                                        tags_and=set(),
                                                        tags_not=set()))
        self.assertEqual(len(chat_db.messages), 0)

    def test_from_dir_filter_answer(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(answer_contains='Answer 2'))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[0].answer, 'Answer 2')

    def test_from_messages(self) -> None:
        chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
                                       pathlib.Path(self.db_path.name),
                                       messages=[self.message1, self.message2,
                                                 self.message3, self.message4])
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))

    def test_fids(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # setUp() wrote '4' to the '.next' file, so counting continues at 5
        self.assertEqual(chat_db.get_next_fid(), 5)
        self.assertEqual(chat_db.get_next_fid(), 6)
        self.assertEqual(chat_db.get_next_fid(), 7)
        with open(chat_db.next_path, 'r') as f:
            self.assertEqual(f.read(), '7')

    def test_msg_in_db_or_cache(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertTrue(chat_db.msg_in_db(self.message1))
        self.assertTrue(chat_db.msg_in_db(str(self.message1.file_path)))
        self.assertTrue(chat_db.msg_in_db(self.message1.msg_id()))
        self.assertFalse(chat_db.msg_in_cache(self.message1))
        self.assertFalse(chat_db.msg_in_cache(str(self.message1.file_path)))
        self.assertFalse(chat_db.msg_in_cache(self.message1.msg_id()))
        # add new message to the cache dir
        cache_message = Message(question=Question("Question 1"),
                                answer=Answer("Answer 1"))
        chat_db.cache_add([cache_message])
        self.assertTrue(chat_db.msg_in_cache(cache_message))
        self.assertTrue(chat_db.msg_in_cache(cache_message.msg_id()))
        self.assertFalse(chat_db.msg_in_db(cache_message))
        self.assertFalse(chat_db.msg_in_db(str(cache_message.file_path)))

    def test_db_write(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # check that Message.file_path is correct
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))

        # write the messages to the cache directory
        chat_db.cache_write()
        # check if the written files are in the cache directory
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 4)
        self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
        # check that Message.file_path has been correctly updated
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))

        # check the timestamp of the files in the DB directory
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
        # overwrite the messages in the db directory
        # (sleep a bit so that the rewritten files get a newer mtime)
        time.sleep(0.05)
        chat_db.db_write()
        # check if the written files are in the DB directory
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
        # check if all files in the DB dir have actually been overwritten
        for file in db_dir_files:
            self.assertGreater(file.stat().st_mtime, old_timestamps[file])
        # check that Message.file_path has been correctly updated (again)
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))

    def test_db_read(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(len(chat_db.messages), 4)

        # create 2 new files in the DB directory
        new_message1 = Message(Question('Question 5'),
                               Answer('Answer 5'),
                               {Tag('tag5')})
        new_message2 = Message(Question('Question 6'),
                               Answer('Answer 6'),
                               {Tag('tag6')})
        new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
        new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
        # read and check them
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 6)
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))

        # create 2 new files in the cache directory
        new_message3 = Message(Question('Question 7'),
                               Answer('Answer 7'),
                               {Tag('tag7')})
        new_message4 = Message(Question('Question 8'),
                               Answer('Answer 8'),
                               {Tag('tag8')})
        new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
        new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
        # read and check them
        chat_db.cache_read()
        self.assertEqual(len(chat_db.messages), 8)
        # check that the new message have the cache dir path
        self.assertEqual(chat_db.messages[6].file_path,
                         pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
        self.assertEqual(chat_db.messages[7].file_path,
                         pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
        # and the old ones keep their path (since they have not been replaced)
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))

        # now overwrite two messages in the DB directory
        new_message1.question = Question('New Question 1')
        new_message2.question = Question('New Question 2')
        new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
        new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
        # read from the DB dir and check if the modified messages have been updated
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 8)
        self.assertEqual(chat_db.messages[4].question, 'New Question 1')
        self.assertEqual(chat_db.messages[5].question, 'New Question 2')
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))

        # now write the messages from the cache to the DB directory
        new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
        new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
        # read and check them
        chat_db.db_read()
        self.assertEqual(len(chat_db.messages), 8)
        # check that they now have the DB path
        self.assertEqual(chat_db.messages[6].file_path,
                         pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
        self.assertEqual(chat_db.messages[7].file_path,
                         pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))

    def test_cache_clear(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # check that Message.file_path is correct
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))

        # write the messages to the cache directory
        chat_db.cache_write()
        # check if the written files are in the cache directory
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 4)

        # now rewrite them to the DB dir and check for modified paths
        chat_db.db_write()
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)

        # add a new message with empty file_path
        message_empty = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You don't belong here!"))
        # and one for the cache dir
        message_cache = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You're a creep!"),
                                file_path=pathlib.Path(self.cache_path.name, '0005'))
        chat_db.msg_add([message_empty, message_cache])

        # clear the cache and check the cache dir
        chat_db.cache_clear()
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)
        # make sure that the DB messages (and the new message) are still there
        self.assertEqual(len(chat_db.messages), 5)
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        # but not the message with the cache dir path
        self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))

    def test_add(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)

        # add new messages to the cache dir
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        chat_db.cache_add([message1])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message1.file_path)
        self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)

        # add new messages to the DB dir
        message2 = Message(question=Question("Question 2"),
                           answer=Answer("Answer 2"))
        chat_db.db_add([message2])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message2.file_path)
        self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 5)

        with self.assertRaises(ChatError):
            chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])

    def test_msg_write(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        # try to write a message without a valid file_path
        message = Message(question=Question("Question 1"),
                          answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.msg_write([message])

        # write a message with a valid file_path
        message.file_path = pathlib.Path(self.cache_path.name) / '123456'
        chat_db.msg_write([message])
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)
        self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)

    def test_msg_update(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        message = chat_db.messages[0]
        message.answer = Answer("New answer")
        # update message without writing
        chat_db.msg_update([message], write=False)
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # re-read the message and check for old content
        chat_db.db_read()
        self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
        # now check with writing (message should be overwritten)
        chat_db.msg_update([message], write=True)
        chat_db.db_read()
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # test without file_path -> expect error
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.msg_update([message1])

    def test_msg_find(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # search for a DB file in memory
        self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc=msg_location.MEM),
                         [self.message1])
        self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc=msg_location.MEM),
                         [self.message1])
        self.assertEqual(chat_db.msg_find([f'0001{msg_suffix}'], loc=msg_location.MEM),
                         [self.message1])
        self.assertEqual(chat_db.msg_find(['0001'], loc=msg_location.MEM),
                         [self.message1])
        # and on disk
        self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc=msg_location.DB),
                         [self.message2])
        self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc=msg_location.DB),
                         [self.message2])
        self.assertEqual(chat_db.msg_find([f'0002{msg_suffix}'], loc=msg_location.DB),
                         [self.message2])
        self.assertEqual(chat_db.msg_find(['0002'], loc=msg_location.DB),
                         [self.message2])
        # now search the cache -> expect empty result
        self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find([f'0003{msg_suffix}'], loc=msg_location.CACHE), [])
        self.assertEqual(chat_db.msg_find(['0003'], loc=msg_location.CACHE), [])
        # search for multiple messages
        # -> search one twice, expect result to be unique
        search_names = ['0001', f'0002{msg_suffix}', self.message3.msg_id(), str(self.message3.file_path)]
        expected_result = [self.message1, self.message2, self.message3]
        result = chat_db.msg_find(search_names, loc=msg_location.ALL)
        self.assert_messages_equal(result, expected_result)

    def test_msg_latest(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), self.message4)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), self.message4)
        # the cache is currently empty:
        self.assertIsNone(chat_db.msg_latest(loc=msg_location.CACHE))
        # add new messages to the cache dir
        new_message = Message(question=Question("New Question"),
                              answer=Answer("New Answer"))
        chat_db.cache_add([new_message])
        self.assertEqual(chat_db.msg_latest(loc=msg_location.CACHE), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.MEM), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DISK), new_message)
        self.assertEqual(chat_db.msg_latest(loc=msg_location.ALL), new_message)
        # the DB does not contain the new message
        self.assertEqual(chat_db.msg_latest(loc=msg_location.DB), self.message4)

    def test_msg_gather(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        all_messages = [self.message1, self.message2, self.message3, self.message4]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # add a new message, but only to the internal list
        new_message = Message(Question("What?"))
        all_messages_mem = all_messages + [new_message]
        chat_db.msg_add([new_message])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages_mem)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages_mem)
        # the nr. of messages on disk did not change -> expect old result
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # test with MessageFilter
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL,
                                                      mfilter=MessageFilter(tags_or={Tag('tag1')})),
                                   [self.message1])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK,
                                                      mfilter=MessageFilter(tags_or={Tag('tag2')})),
                                   [self.message2])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE,
                                                      mfilter=MessageFilter(tags_or={Tag('tag3')})),
                                   [])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM,
                                                      mfilter=MessageFilter(question_contains="What")),
                                   [new_message])

    def test_msg_move_and_gather(self) -> None:
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        all_messages = [self.message1, self.message2, self.message3, self.message4]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        # move first message to the cache
        chat_db.cache_move(self.message1)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [self.message1])
        self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB),
                                   [self.message2, self.message3, self.message4])
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.ALL), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DISK), all_messages)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.MEM), all_messages)
        # now move first message back to the DB
        chat_db.db_move(self.message1)
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.CACHE), [])
        self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        self.assert_messages_equal(chat_db.msg_gather(loc=msg_location.DB), all_messages)