From eac93cb667418ca7b5443103b1c3c2c308bd0b6b Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 18 Sep 2023 09:57:19 +0200 Subject: [PATCH] question_cmd: added testclass for the 'question_cmd()' function --- tests/test_question_cmd.py | 226 +++++++++++++++++++++++++++++++++++-- 1 file changed, 216 insertions(+), 10 deletions(-) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 1c6c958..b809567 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -3,23 +3,39 @@ import unittest import argparse import tempfile from pathlib import Path -from unittest.mock import MagicMock -from chatmastermind.commands.question import create_message +from unittest import mock +from unittest.mock import MagicMock, call, ANY +from typing import Optional +from chatmastermind.configuration import Config +from chatmastermind.commands.question import create_message, question_cmd +from chatmastermind.tags import Tag from chatmastermind.message import Message, Question, Answer -from chatmastermind.chat import ChatDB +from chatmastermind.chat import Chat, ChatDB +from chatmastermind.ai import AI, AIResponse, Tokens, AIError -class TestMessageCreate(unittest.TestCase): +class TestQuestionCmdBase(unittest.TestCase): + def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: + """ + Compare messages using more than just Question and Answer. + """ + self.assertEqual(len(msg1), len(msg2)) + for m1, m2 in zip(msg1, msg2): + # exclude the file_path, compare only Q, A and metadata + self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) + + +class TestMessageCreate(TestQuestionCmdBase): """ Test if messages created by the 'question' command have the correct format. """ def setUp(self) -> None: # create ChatDB structure - self.db_path = tempfile.TemporaryDirectory() - self.cache_path = tempfile.TemporaryDirectory() - self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), - db_path=Path(self.db_path.name)) + self.db_dir = tempfile.TemporaryDirectory() + self.cache_dir = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name), + db_path=Path(self.db_dir.name)) # create some messages self.message_text = Message(Question("What is this?"), Answer("It is pure text")) @@ -74,6 +90,7 @@ Aaaand again some text.""" os.remove(self.source_file1.name) os.remove(self.source_file2.name) os.remove(self.source_file3.name) + os.remove(self.source_file4.name) def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: # exclude '.next' @@ -81,10 +98,10 @@ Aaaand again some text.""" def test_message_file_created(self) -> None: self.args.ask = ["What is this?"] - cache_dir_files = self.message_list(self.cache_path) + cache_dir_files = self.message_list(self.cache_dir) self.assertEqual(len(cache_dir_files), 0) create_message(self.chat, self.args) - cache_dir_files = self.message_list(self.cache_path) + cache_dir_files = self.message_list(self.cache_dir) self.assertEqual(len(cache_dir_files), 1) message = Message.from_file(cache_dir_files[0]) self.assertIsInstance(message, Message) @@ -193,3 +210,192 @@ This is embedded source code. It is embedded code ``` """)) + + +class TestQuestionCmd(TestQuestionCmdBase): + + def setUp(self) -> None: + # create DB and cache + self.db_dir = tempfile.TemporaryDirectory() + self.cache_dir = tempfile.TemporaryDirectory() + # create configuration + self.config = Config() + self.config.cache = self.cache_dir.name + self.config.db = self.db_dir.name + # create a mock argparse.Namespace + self.args = argparse.Namespace( + ask=['What is the meaning of life?'], + num_answers=1, + output_tags=['science'], + AI='openai', + model='gpt-3.5-turbo', + or_tags=None, + and_tags=None, + exclude_tags=None, + source_text=None, + source_code=None, + create=None, + repeat=None, + process=None, + overwrite=None + ) + # create a mock AI instance + self.ai = MagicMock(spec=AI) + self.ai.request.side_effect = self.mock_request + + def input_message(self, args: argparse.Namespace) -> Message: + """ + Create the expected input message for a question using the + given arguments. + """ + # NOTE: we only use the first question from the "ask" list + # -> message creation using "question.create_message()" is + # tested above + # the answer is always empty for the input message + return Message(Question(args.ask[0]), + tags=args.output_tags, + ai=args.AI, + model=args.model) + + def mock_request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Mock the 'ai.request()' function + """ + question.answer = Answer("Answer 0") + question.tags = set(otags) if otags else None + question.ai = 'FakeAI' + question.model = 'FakeModel' + answers: list[Message] = [question] + for n in range(1, num_answers): + answers.append(Message(question=question.question, + answer=Answer(f"Answer {n}"), + tags=otags, + ai='FakeAI', + model='FakeModel')) + return AIResponse(answers, Tokens(10, 10, 20)) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: + """ + Test single answer with no errors + """ + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + + # execute the command + question_cmd(self.args, self.config) + + # check for correct request call + self.ai.request.assert_called_once_with(expected_question, + ANY, + self.args.num_answers, + self.args.output_tags) + # check for the expected message files + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + + @mock.patch('chatmastermind.commands.question.ChatDB.from_dir') + @mock.patch('chatmastermind.commands.question.create_ai') + def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: + """ + Test single answer with no errors (mocked ChatDB version) + """ + chat = MagicMock(spec=ChatDB) + mock_from_dir.return_value = chat + + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + + # execute the command + question_cmd(self.args, self.config) + + # check for correct request call + self.ai.request.assert_called_once_with(expected_question, + chat, + self.args.num_answers, + self.args.output_tags) + + # check for the correct ChatDB calls: + # - initial question has been written (prior to the actual request) + # - responses have been written (after the request) + chat.cache_write.assert_has_calls([call([expected_question]), + call(expected_responses)], + any_order=False) + + # check that the messages have not been added to the internal message list + chat.cache_add.assert_not_called() + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_ask_with_error(self, mock_create_ai: MagicMock) -> None: + """ + Provoke an error during the AI request and verify that the question + has been correctly stored in the cache. + """ + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + self.ai.request.side_effect = AIError + + # execute the command + with self.assertRaises(AIError): + question_cmd(self.args, self.config) + + # check for correct request call + self.ai.request.assert_called_once_with(expected_question, + ANY, + self.args.num_answers, + self.args.output_tags) + # check for the expected message files + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, [expected_question]) + + @mock.patch('chatmastermind.commands.question.create_ai') + def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None: + """ + Repeat a single question + """ + # 1. ask a question + mock_create_ai.return_value = self.ai + expected_question = self.input_message(self.args) + expected_responses = self.mock_request(expected_question, + Chat([]), + self.args.num_answers, + self.args.output_tags).messages + question_cmd(self.args, self.config) + chat = ChatDB.from_dir(Path(self.cache_dir.name), + Path(self.db_dir.name)) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 1) + self.assert_messages_equal(cached_msg, expected_responses) + + # 2. repeat the last question (without overwriting) + # -> expect two identical messages (except for the file_path) + self.args.ask = None + self.args.repeat = [] + self.args.overwrite = False + expected_responses += expected_responses + question_cmd(self.args, self.config) + cached_msg = chat.msg_gather(loc='cache') + self.assertEqual(len(self.message_list(self.cache_dir)), 2) + self.assert_messages_equal(cached_msg, expected_responses)