diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..cb68eff --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,234 @@ +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width +from .test_main import CmmTestCase + + +class TestChat(CmmTestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_msgs([self.message2, self.message1]) + self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_msgs([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_msgs(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def test_tags(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_file=True) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{TagLine.prefix} atag1 +FILE: 0001.txt +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +{TagLine.prefix} btag2 +FILE: 0002.txt +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(CmmTestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_filter(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messges(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 1) + self.assertEqual(chat_db.get_next_fid(), 2) + self.assertEqual(chat_db.get_next_fid(), 3) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '3') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))