import os
import pathlib
import tempfile
import time
from io import StringIO
from unittest.mock import patch

from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width
from .test_main import CmmTestCase


class TestChat(CmmTestCase):
    """Tests for the in-memory `Chat` message container."""

    def setUp(self) -> None:
        self.chat = Chat([])
        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('atag1'), Tag('btag2')},
                                file_path=pathlib.Path('0001.txt'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('btag2')},
                                file_path=pathlib.Path('0002.txt'))

    def test_filter(self) -> None:
        """Filtering on answer content keeps only the matching message."""
        self.chat.add_msgs([self.message1, self.message2])
        self.chat.filter(MessageFilter(answer_contains='Answer 1'))
        self.assertEqual(len(self.chat.messages), 1)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')

    def test_sort(self) -> None:
        """`sort()` puts message 1 before message 2; `reverse=True` flips that."""
        self.chat.add_msgs([self.message2, self.message1])
        self.chat.sort()
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')
        self.chat.sort(reverse=True)
        self.assertEqual(self.chat.messages[0].question, 'Question 2')
        self.assertEqual(self.chat.messages[1].question, 'Question 1')

    def test_clear(self) -> None:
        """`clear()` removes all messages."""
        self.chat.add_msgs([self.message1])
        self.chat.clear()
        self.assertEqual(len(self.chat.messages), 0)

    def test_add_msgs(self) -> None:
        """`add_msgs()` appends messages in the given order."""
        self.chat.add_msgs([self.message1, self.message2])
        self.assertEqual(len(self.chat.messages), 2)
        self.assertEqual(self.chat.messages[0].question, 'Question 1')
        self.assertEqual(self.chat.messages[1].question, 'Question 2')

    def test_tags(self) -> None:
        """`tags()` returns all tags, optionally filtered by prefix or substring."""
        self.chat.add_msgs([self.message1, self.message2])
        tags_all = self.chat.tags()
        self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
        tags_pref = self.chat.tags(prefix='a')
        self.assertSetEqual(tags_pref, {Tag('atag1')})
        tags_cont = self.chat.tags(contain='2')
        self.assertSetEqual(tags_cont, {Tag('btag2')})

    def test_tags_frequency(self) -> None:
        """`tags_frequency()` counts how many messages carry each tag."""
        self.chat.add_msgs([self.message1, self.message2])
        tags_freq = self.chat.tags_frequency()
        self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})

    @patch('sys.stdout', new_callable=StringIO)
    def test_print(self, mock_stdout: StringIO) -> None:
        """Unpaged printing emits each message after a separator line."""
        self.chat.add_msgs([self.message1, self.message2])
        self.chat.print(paged=False)
        expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)

    @patch('sys.stdout', new_callable=StringIO)
    def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
        """With tags and files enabled, each message also shows its tag line and file name."""
        self.chat.add_msgs([self.message1, self.message2])
        self.chat.print(paged=False, with_tags=True, with_files=True)
        expected_output = f"""{'-'*terminal_width()}
{TagLine.prefix} atag1 btag2
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
        self.assertEqual(mock_stdout.getvalue(), expected_output)


class TestChatDB(CmmTestCase):
    """Tests for `ChatDB`, which manages messages in a DB and a cache directory."""

    def setUp(self) -> None:
        # Fresh temporary directories per test: the DB dir is pre-filled with
        # four messages (mixed .txt / .yaml), the cache dir starts empty.
        self.db_path = tempfile.TemporaryDirectory()
        self.cache_path = tempfile.TemporaryDirectory()
        self.message1 = Message(Question('Question 1'),
                                Answer('Answer 1'),
                                {Tag('tag1')},
                                file_path=pathlib.Path('0001.txt'))
        self.message2 = Message(Question('Question 2'),
                                Answer('Answer 2'),
                                {Tag('tag2')},
                                file_path=pathlib.Path('0002.yaml'))
        self.message3 = Message(Question('Question 3'),
                                Answer('Answer 3'),
                                {Tag('tag3')},
                                file_path=pathlib.Path('0003.txt'))
        self.message4 = Message(Question('Question 4'),
                                Answer('Answer 4'),
                                {Tag('tag4')},
                                file_path=pathlib.Path('0004.yaml'))
        self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
        self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
        self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
        self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))

    def tearDown(self) -> None:
        self.db_path.cleanup()
        self.cache_path.cleanup()

    def test_chat_db_from_dir(self) -> None:
        """`from_dir()` loads every DB file, sorted by file name."""
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        # check that the files are sorted
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, '0003.txt'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, '0004.yaml'))

    def test_chat_db_from_dir_glob(self) -> None:
        """`from_dir(glob=...)` loads only the matching files."""
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  glob='*.txt')
        self.assertEqual(len(chat_db.messages), 2)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, '0003.txt'))

    def test_chat_db_filter(self) -> None:
        """`from_dir(mfilter=...)` keeps only messages passing the filter."""
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name),
                                  mfilter=MessageFilter(answer_contains='Answer 2'))
        self.assertEqual(len(chat_db.messages), 1)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[0].answer, 'Answer 2')

    def test_chat_db_from_messages(self) -> None:
        """`from_messages()` builds a ChatDB from an explicit message list."""
        chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
                                       pathlib.Path(self.db_path.name),
                                       messages=[self.message1, self.message2,
                                                 self.message3, self.message4])
        self.assertEqual(len(chat_db.messages), 4)
        self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
        self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))

    def test_chat_db_fids(self) -> None:
        """`get_next_fid()` hands out consecutive IDs and stores the last one in `next_fname`."""
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(chat_db.get_next_fid(), 1)
        self.assertEqual(chat_db.get_next_fid(), 2)
        self.assertEqual(chat_db.get_next_fid(), 3)
        with open(chat_db.next_fname, 'r') as f:
            self.assertEqual(f.read(), '3')

    def test_chat_db_write(self) -> None:
        """`write_cache()` / `write_db()` write all messages and update their file paths."""
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # check that Message.file_path is correct
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, '0003.txt'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, '0004.yaml'))

        # write the messages to the cache directory
        chat_db.write_cache()
        # check if the written files are in the cache directory
        cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
        self.assertEqual(len(cache_dir_files), 4)
        self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
        self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
        # check that Message.file_path has been correctly updated
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.cache_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.cache_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.cache_path.name, '0003.txt'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.cache_path.name, '0004.yaml'))

        # check the timestamp of the files in the DB directory
        db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
        self.assertEqual(len(db_dir_files), 4)
        # Backdate the DB files so the "overwritten" check below does not
        # depend on the file system's mtime resolution (a short sleep is not
        # enough on file systems with 1-2 s timestamp granularity).
        past = time.time() - 3600
        for file in db_dir_files:
            os.utime(file, (past, past))
        old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}

        # overwrite the messages in the db directory
        chat_db.write_db()
        # check if the written files are in the DB directory
        db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
        self.assertEqual(len(db_dir_files), 4)
        self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
        # check if all files in the DB dir have actually been overwritten
        for file in db_dir_files:
            self.assertGreater(file.stat().st_mtime, old_timestamps[file])
        # check that Message.file_path has been correctly updated (again)
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, '0003.txt'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, '0004.yaml'))

    def test_chat_db_read(self) -> None:
        """`read_db()` / `read_cache()` pick up new and modified files from each directory."""
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        self.assertEqual(len(chat_db.messages), 4)

        # create 2 new files in the DB directory
        new_message1 = Message(Question('Question 5'), Answer('Answer 5'), {Tag('tag5')})
        new_message2 = Message(Question('Question 6'), Answer('Answer 6'), {Tag('tag6')})
        new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
        new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
        # read and check them
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 6)
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, '0005.txt'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, '0006.yaml'))

        # create 2 new files in the cache directory
        new_message3 = Message(Question('Question 7'), Answer('Answer 7'), {Tag('tag7')})
        new_message4 = Message(Question('Question 8'), Answer('Answer 8'), {Tag('tag8')})
        new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
        new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
        # read and check them
        chat_db.read_cache()
        self.assertEqual(len(chat_db.messages), 8)
        # check that the new messages have the cache dir path
        self.assertEqual(chat_db.messages[6].file_path,
                         pathlib.Path(self.cache_path.name, '0007.txt'))
        self.assertEqual(chat_db.messages[7].file_path,
                         pathlib.Path(self.cache_path.name, '0008.yaml'))
        # and the old ones keep their path (since they have not been replaced)
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, '0005.txt'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, '0006.yaml'))

        # now overwrite two messages in the DB directory
        new_message1.question = Question('New Question 1')
        new_message2.question = Question('New Question 2')
        new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
        new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
        # read from the DB dir and check if the modified messages have been updated
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 8)
        self.assertEqual(chat_db.messages[4].question, 'New Question 1')
        self.assertEqual(chat_db.messages[5].question, 'New Question 2')
        self.assertEqual(chat_db.messages[4].file_path,
                         pathlib.Path(self.db_path.name, '0005.txt'))
        self.assertEqual(chat_db.messages[5].file_path,
                         pathlib.Path(self.db_path.name, '0006.yaml'))

        # now write the messages from the cache to the DB directory
        new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
        new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
        # read and check them
        chat_db.read_db()
        self.assertEqual(len(chat_db.messages), 8)
        # check that they now have the DB path
        self.assertEqual(chat_db.messages[6].file_path,
                         pathlib.Path(self.db_path.name, '0007.txt'))
        self.assertEqual(chat_db.messages[7].file_path,
                         pathlib.Path(self.db_path.name, '0008.yaml'))

    def test_chat_db_clear(self) -> None:
        """`clear_cache()` empties the cache dir and drops cache-backed messages only."""
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))
        # check that Message.file_path is correct
        self.assertEqual(chat_db.messages[0].file_path,
                         pathlib.Path(self.db_path.name, '0001.txt'))
        self.assertEqual(chat_db.messages[1].file_path,
                         pathlib.Path(self.db_path.name, '0002.yaml'))
        self.assertEqual(chat_db.messages[2].file_path,
                         pathlib.Path(self.db_path.name, '0003.txt'))
        self.assertEqual(chat_db.messages[3].file_path,
                         pathlib.Path(self.db_path.name, '0004.yaml'))

        # write the messages to the cache directory
        chat_db.write_cache()
        # check if the written files are in the cache directory
        cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
        self.assertEqual(len(cache_dir_files), 4)

        # now rewrite them to the DB dir and check for modified paths
        chat_db.write_db()
        db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
        self.assertEqual(len(db_dir_files), 4)
        self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
        self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)

        # add a new message with empty file_path
        message_empty = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You don't belong here!"))
        # and one for the cache dir
        message_cache = Message(question=Question("What the hell am I doing here?"),
                                answer=Answer("You're a creep!"),
                                file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
        chat_db.add_msgs([message_empty, message_cache])

        # clear the cache and check the cache dir
        chat_db.clear_cache()
        cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
        self.assertEqual(len(cache_dir_files), 0)
        # make sure that the DB messages (and the new message) are still there
        self.assertEqual(len(chat_db.messages), 5)
        db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
        self.assertEqual(len(db_dir_files), 4)
        # but not the message with the cache dir path
        self.assertFalse(any(m.file_path == message_cache.file_path
                             for m in chat_db.messages))