question_cmd: added testclass for the 'question_cmd()' function

This commit is contained in:
juk0de 2023-09-18 09:57:19 +02:00
parent dc0454a72c
commit fd69c07eb4

View File

@ -3,23 +3,39 @@ import unittest
import argparse import argparse
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock from unittest import mock
from chatmastermind.commands.question import create_message from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class TestMessageCreate(unittest.TestCase): class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestMessageCreate(TestQuestionCmdBase):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
""" """
def setUp(self) -> None: def setUp(self) -> None:
# create ChatDB structure # create ChatDB structure
self.db_path = tempfile.TemporaryDirectory() self.db_dir = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_path.name)) db_path=Path(self.db_dir.name))
# create some messages # create some messages
self.message_text = Message(Question("What is this?"), self.message_text = Message(Question("What is this?"),
Answer("It is pure text")) Answer("It is pure text"))
@ -74,6 +90,7 @@ Aaaand again some text."""
os.remove(self.source_file1.name) os.remove(self.source_file1.name)
os.remove(self.source_file2.name) os.remove(self.source_file2.name)
os.remove(self.source_file3.name) os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
@ -81,10 +98,10 @@ Aaaand again some text."""
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args) create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0]) message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
@ -193,3 +210,192 @@ This is embedded source code.
It is embedded code It is embedded code
``` ```
""")) """))
class TestQuestionCmd(TestQuestionCmdBase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='openai',
model='gpt-3.5-turbo',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None,
overwrite=None
)
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version)
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)