"""Unit tests for the ``Chat`` and ``ChatDB`` classes of ``chatmastermind.chat``."""
import os
import pathlib
import tempfile
import time
from io import StringIO
from unittest.mock import patch

from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError

from .test_main import CmmTestCase


class TestChat(CmmTestCase):
    """Tests for ``Chat`` using two messages that are never written to disk."""

    def setUp(self) -> None:
        """Create an empty chat and two messages with overlapping tags."""
        self.chat = Chat([])
        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('atag1'), Tag('btag2')},
                                file_path=pathlib.Path('0001.txt'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('btag2')},
                                file_path=pathlib.Path('0002.txt'))

    def test_filter(self) -> None:
        self.chat.add_messages([self.message1, self.message2])
        self.chat.filter(MessageFilter(answer_contains='Answer 1'))
        self.assertEqual(len(self.chat.messages), 1)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')

    def test_sort(self) -> None:
        # add the messages in reverse order so that sort() has work to do
        self.chat.add_messages([self.message2, self.message1])
        self.chat.sort()
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')
        self.chat.sort(reverse=True)
        self.assertEqual(self.chat.messages[0].question, 'Question 2')
        self.assertEqual(self.chat.messages[1].question, 'Question 1')

    def test_clear(self) -> None:
        self.chat.add_messages([self.message1])
        self.chat.clear()
        self.assertEqual(len(self.chat.messages), 0)

    def test_add_messages(self) -> None:
        self.chat.add_messages([self.message1, self.message2])
        self.assertEqual(len(self.chat.messages), 2)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')

    def test_tags(self) -> None:
        self.chat.add_messages([self.message1, self.message2])
        tags_all = self.chat.tags()
        self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
        tags_pref = self.chat.tags(prefix='a')
        self.assertSetEqual(tags_pref, {Tag('atag1')})
        tags_cont = self.chat.tags(contain='2')
        self.assertSetEqual(tags_cont, {Tag('btag2')})

    def test_tags_frequency(self) -> None:
        self.chat.add_messages([self.message1, self.message2])
        tags_freq = self.chat.tags_frequency()
        self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})

    def test_find_remove_messages(self) -> None:
        self.chat.add_messages([self.message1, self.message2])
        msgs = self.chat.find_messages(['0001.txt'])
        self.assertListEqual(msgs, [self.message1])
        msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
        self.assertListEqual(msgs, [self.message1, self.message2])
        # add new Message with full path
        message3 = Message(Question('Question 2'),
                           Answer('Answer 2'),
                           {Tag('btag2')},
                           file_path=pathlib.Path('/foo/bla/0003.txt'))
        self.chat.add_messages([message3])
        # find new Message by full path
        msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
        self.assertListEqual(msgs, [message3])
        # find Message with full path only by filename
        msgs = self.chat.find_messages(['0003.txt'])
        self.assertListEqual(msgs, [message3])
        # remove last message
        self.chat.remove_messages(['0003.txt'])
        self.assertListEqual(self.chat.messages, [self.message1, self.message2])

    @patch('sys.stdout', new_callable=StringIO)
    def test_print(self, mock_stdout: StringIO) -> None:
        self.chat.add_messages([self.message1, self.message2])
        self.chat.print(paged=False)
        # NOTE(review): the line breaks inside this literal were reconstructed
        # (one newline per item); confirm the exact layout, in particular any
        # blank lines around the separator, against the output of Chat.print().
        expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
        self.chat.add_messages([self.message1, self.message2])
        self.chat.print(paged=False, with_tags=True, with_files=True)
        # NOTE(review): the line breaks inside this literal were reconstructed
        # (one newline per item); confirm the exact layout, in particular any
        # blank lines around the separator, against the output of Chat.print().
        expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)


class TestChatDB(CmmTestCase):
    """Tests for ``ChatDB`` using a temporary DB and a temporary cache directory."""

    # names of the message files that setUp() writes to the DB directory
    _DB_FILES = ('0001.txt', '0002.yaml', '0003.txt', '0004.yaml')

    def setUp(self) -> None:
        """Create both directories and write four messages to the DB directory."""
        # Register the cleanups right away: unittest does not call tearDown()
        # when setUp() raises, but it always runs functions added here.
        self.db_path = tempfile.TemporaryDirectory()
        self.addCleanup(self.db_path.cleanup)
        self.cache_path = tempfile.TemporaryDirectory()
        self.addCleanup(self.cache_path.cleanup)

        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('tag1')},
                                file_path=pathlib.Path('0001.txt'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('tag2')},
                                file_path=pathlib.Path('0002.yaml'))
        self.message3 = Message(Question('Question 3'),
                                Answer('Answer 3'),
                                {Tag('tag3')},
                                file_path=pathlib.Path('0003.txt'))
        self.message4 = Message(Question('Question 4'),
                                Answer('Answer 4'),
                                {Tag('tag4')},
                                file_path=pathlib.Path('0004.yaml'))
        messages = (self.message1, self.message2, self.message3, self.message4)
        for message, fname in zip(messages, self._DB_FILES):
            message.to_file(pathlib.Path(self.db_path.name, fname))
        # make the next FID match the current state
        next_fname = pathlib.Path(self.db_path.name) / '.next'
        with open(next_fname, 'w') as f:
            f.write('4')

    def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
        """
        List all Message files in the given TemporaryDirectory.
        """
        # exclude '.next'
        return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))

    def tearDown(self) -> None:
        """Remove both temporary directories (calling cleanup() again later is harmless)."""
        self.db_path.cleanup()
        self.cache_path.cleanup()

    def _db_dir(self) -> pathlib.Path:
        """Return the temporary DB directory as a Path."""
        return pathlib.Path(self.db_path.name)

    def _cache_dir(self) -> pathlib.Path:
        """Return the temporary cache directory as a Path."""
        return pathlib.Path(self.cache_path.name)

    def _open_db(self) -> ChatDB:
        """Create a ChatDB from the temporary cache and DB directories."""
        return ChatDB.from_dir(self._cache_dir(), self._db_dir())

    def _assert_file_paths(self, chat_db: ChatDB, directory: pathlib.Path,
                           names: tuple[str, ...], first: int = 0) -> None:
        """Assert that messages[first:first + len(names)] have the paths directory / name."""
        for offset, name in enumerate(names):
            self.assertEqual(chat_db.messages[first + offset].file_path, directory / name)

    def _assert_dir_contains(self, tmp_dir: tempfile.TemporaryDirectory,
                             names: tuple[str, ...]) -> None:
        """Assert that the message files in tmp_dir are exactly the given names."""
        files = self.message_list(tmp_dir)
        self.assertEqual(len(files), len(names))
        for name in names:
            self.assertIn(pathlib.Path(tmp_dir.name, name), files)

    def test_chat_db_from_dir(self) -> None:
        chat_db = self._open_db()
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, self._cache_dir())
        self.assertEqual(chat_db.db_path, self._db_dir())
        # check that the files are sorted
        self._assert_file_paths(chat_db, self._db_dir(), self._DB_FILES)

    def test_chat_db_from_dir_glob(self) -> None:
        chat_db = ChatDB.from_dir(self._cache_dir(),
                                  self._db_dir(),
                                  glob='*.txt')
        self.assertEqual(len(chat_db.messages), 2)
        self.assertEqual(chat_db.cache_path, self._cache_dir())
        self.assertEqual(chat_db.db_path, self._db_dir())
        self._assert_file_paths(chat_db, self._db_dir(), ('0001.txt', '0003.txt'))

    def test_chat_db_filter(self) -> None:
        chat_db = ChatDB.from_dir(self._cache_dir(),
                                  self._db_dir(),
                                  mfilter=MessageFilter(answer_contains='Answer 2'))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, self._cache_dir())
        self.assertEqual(chat_db.db_path, self._db_dir())
        self._assert_file_paths(chat_db, self._db_dir(), ('0002.yaml',))
        self.assertEqual(chat_db.messages[0].answer, 'Answer 2')

    def test_chat_db_from_messges(self) -> None:
        chat_db = ChatDB.from_messages(self._cache_dir(),
                                       self._db_dir(),
                                       messages=[self.message1, self.message2,
                                                 self.message3, self.message4])
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, self._cache_dir())
        self.assertEqual(chat_db.db_path, self._db_dir())

    def test_chat_db_fids(self) -> None:
        chat_db = self._open_db()
        # setUp() wrote '4' to the '.next' file
        self.assertEqual(chat_db.get_next_fid(), 5)
        self.assertEqual(chat_db.get_next_fid(), 6)
        self.assertEqual(chat_db.get_next_fid(), 7)
        with open(chat_db.next_fname, 'r') as f:
            self.assertEqual(f.read(), '7')

    def test_chat_db_write(self) -> None:
        # create a new ChatDB instance
        chat_db = self._open_db()
        # check that Message.file_path is correct
        self._assert_file_paths(chat_db, self._db_dir(), self._DB_FILES)

        # write the messages to the cache directory
        chat_db.write_cache()
        # check if the written files are in the cache directory
        self._assert_dir_contains(self.cache_path, self._DB_FILES)
        # check that Message.file_path has been correctly updated
        self._assert_file_paths(chat_db, self._cache_dir(), self._DB_FILES)

        # check the timestamp of the files in the DB directory
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        # Back-date the files instead of sleeping: a rewrite is then detectable
        # regardless of the timestamp resolution of the file system.
        past = time.time() - 3600
        for file in db_dir_files:
            os.utime(file, (past, past))
        old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
        # overwrite the messages in the db directory
        chat_db.write_db()
        # check if the written files are in the DB directory
        self._assert_dir_contains(self.db_path, self._DB_FILES)
        # check if all files in the DB dir have actually been overwritten
        for file in self.message_list(self.db_path):
            self.assertGreater(file.stat().st_mtime, old_timestamps[file])
        # check that Message.file_path has been correctly updated (again)
        self._assert_file_paths(chat_db, self._db_dir(), self._DB_FILES)

    def test_chat_db_read(self) -> None:
        db_dir = self._db_dir()
        cache_dir = self._cache_dir()
        # create a new ChatDB instance
        chat_db = self._open_db()
        self.assertEqual(len(chat_db.messages), 4)

        # create 2 new files in the DB directory
        new_message1 = Message(Question('Question 5'),
                               Answer('Answer 5'),
                               {Tag('tag5')})
        new_message2 = Message(Question('Question 6'),
                               Answer('Answer 6'),
                               {Tag('tag6')})
        new_message1.to_file(db_dir / '0005.txt')
        new_message2.to_file(db_dir / '0006.yaml')
        # read and check them
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 6)
        self._assert_file_paths(chat_db, db_dir, ('0005.txt', '0006.yaml'), first=4)

        # create 2 new files in the cache directory
        new_message3 = Message(Question('Question 7'),
                               Answer('Answer 5'),
                               {Tag('tag7')})
        new_message4 = Message(Question('Question 8'),
                               Answer('Answer 6'),
                               {Tag('tag8')})
        new_message3.to_file(cache_dir / '0007.txt')
        new_message4.to_file(cache_dir / '0008.yaml')
        # read and check them
        chat_db.read_cache()
        self.assertEqual(len(chat_db.messages), 8)
        # check that the new messages have the cache dir path
        self._assert_file_paths(chat_db, cache_dir, ('0007.txt', '0008.yaml'), first=6)
        # and the old ones keep their path (since they have not been replaced)
        self._assert_file_paths(chat_db, db_dir, ('0005.txt', '0006.yaml'), first=4)

        # now overwrite two messages in the DB directory
        new_message1.question = Question('New Question 1')
        new_message2.question = Question('New Question 2')
        new_message1.to_file(db_dir / '0005.txt')
        new_message2.to_file(db_dir / '0006.yaml')
        # read from the DB dir and check if the modified messages have been updated
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 8)
        self.assertEqual(chat_db.messages[4].question, 'New Question 1')
        self.assertEqual(chat_db.messages[5].question, 'New Question 2')
        self._assert_file_paths(chat_db, db_dir, ('0005.txt', '0006.yaml'), first=4)

        # now write the messages from the cache to the DB directory
        new_message3.to_file(db_dir / '0007.txt')
        new_message4.to_file(db_dir / '0008.yaml')
        # read and check them
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 8)
        # check that they now have the DB path
        self._assert_file_paths(chat_db, db_dir, ('0007.txt', '0008.yaml'), first=6)

    def test_chat_db_clear(self) -> None:
        # create a new ChatDB instance
        chat_db = self._open_db()
        # check that Message.file_path is correct
        self._assert_file_paths(chat_db, self._db_dir(), self._DB_FILES)

        # write the messages to the cache directory
        chat_db.write_cache()
        # check if the written files are in the cache directory
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 4)

        # now rewrite them to the DB dir and check for modified paths
        chat_db.write_db()
        self._assert_dir_contains(self.db_path, self._DB_FILES)

        # add a new message with empty file_path
        message_empty = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You don't belong here!"))
        # and one for the cache dir
        message_cache = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You're a creep!"),
                                file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
        chat_db.add_messages([message_empty, message_cache])

        # clear the cache and check the cache dir
        chat_db.clear_cache()
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)
        # make sure that the DB messages (and the new message) are still there
        self.assertEqual(len(chat_db.messages), 5)
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        # but not the message with the cache dir path
        self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))

    def test_chat_db_add(self) -> None:
        # create a new ChatDB instance
        chat_db = self._open_db()

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)

        # add new messages to the cache dir
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        chat_db.add_to_cache([message1])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message1.file_path)
        self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)

        # add new messages to the DB dir
        message2 = Message(question=Question("Question 2"),
                           answer=Answer("Answer 2"))
        chat_db.add_to_db([message2])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message2.file_path)
        self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 5)

        # a message whose file_path is just 'foo' must be rejected
        with self.assertRaises(ChatError):
            chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])

    def test_chat_db_write_messages(self) -> None:
        # create a new ChatDB instance
        chat_db = self._open_db()

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        # try to write a message without a valid file_path
        message = Message(question=Question("Question 1"),
                          answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.write_messages([message])

        # write a message with a valid file_path
        message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
        chat_db.write_messages([message])
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)
        self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)