question_cmd: added testclass for the 'question_cmd()' function
This commit is contained in:
parent
69d916f0cc
commit
0e00d25ef3
@ -3,10 +3,13 @@ import unittest
|
||||
import argparse
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from chatmastermind.commands.question import create_message
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, call
|
||||
from chatmastermind.configuration import Config
|
||||
from chatmastermind.commands.question import create_message, question_cmd
|
||||
from chatmastermind.message import Message, Question, Answer
|
||||
from chatmastermind.chat import ChatDB
|
||||
from chatmastermind.ai import AI, AIResponse, Tokens
|
||||
|
||||
|
||||
class TestMessageCreate(unittest.TestCase):
|
||||
@ -74,6 +77,7 @@ Aaaand again some text."""
|
||||
os.remove(self.source_file1.name)
|
||||
os.remove(self.source_file2.name)
|
||||
os.remove(self.source_file3.name)
|
||||
os.remove(self.source_file4.name)
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||
# exclude '.next'
|
||||
@ -193,3 +197,93 @@ This is embedded source code.
|
||||
It is embedded code
|
||||
```
|
||||
"""))
|
||||
|
||||
|
||||
class TestQuestionCmd(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
# create DB and cache
|
||||
self.db_path = tempfile.TemporaryDirectory()
|
||||
self.cache_path = tempfile.TemporaryDirectory()
|
||||
# create configuration
|
||||
self.config = Config()
|
||||
# create a mock argparse.Namespace
|
||||
self.args = argparse.Namespace(
|
||||
ask=['What is the meaning of life?'],
|
||||
num_answers=1,
|
||||
output_tags=['science'],
|
||||
AI='openai',
|
||||
model='gpt-3.5-turbo',
|
||||
or_tags=None,
|
||||
and_tags=None,
|
||||
exclude_tags=None,
|
||||
source_text=None,
|
||||
source_code=None,
|
||||
create=None,
|
||||
repeat=None,
|
||||
process=None
|
||||
)
|
||||
|
||||
def input_message(self, args: argparse.Namespace) -> Message:
|
||||
"""
|
||||
Create the expected input message for a question using the
|
||||
given arguments.
|
||||
"""
|
||||
# NOTE: we only use the first question from the "ask" list
|
||||
# -> message creation using "question.create_message()" is
|
||||
# tested above
|
||||
# the answer is always empty for the input message
|
||||
return Message(Question(args.ask[0]),
|
||||
tags=args.output_tags,
|
||||
ai=args.AI,
|
||||
model=args.model)
|
||||
|
||||
def response(self, args: argparse.Namespace) -> AIResponse:
|
||||
"""
|
||||
Create the expected AI response from the give arguments.
|
||||
"""
|
||||
input_msg = self.input_message(args)
|
||||
response = AIResponse(messages=[], tokens=Tokens(10, 10, 20))
|
||||
for n in range(args.num_answers):
|
||||
response_msg = Message(input_msg.question,
|
||||
Answer(f"Answer {n}"),
|
||||
tags=input_msg.tags,
|
||||
ai=input_msg.ai,
|
||||
model=input_msg.model)
|
||||
response.messages.append(response_msg)
|
||||
return response
|
||||
|
||||
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
|
||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||
def test_ask_single_answer(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
|
||||
|
||||
# FIXME: this mock is only neccessary because the cache dir is not
|
||||
# configurable in the configuration file
|
||||
chat = MagicMock(spec=ChatDB)
|
||||
mock_from_dir.return_value = chat
|
||||
|
||||
# create a mock AI instance
|
||||
ai = MagicMock(spec=AI)
|
||||
ai.request.return_value = self.response(self.args)
|
||||
mock_create_ai.return_value = ai
|
||||
expected_question = self.input_message(self.args)
|
||||
expected_responses = ai.request.return_value.messages
|
||||
|
||||
# execute the command
|
||||
question_cmd(self.args, self.config)
|
||||
|
||||
# check for correct request call
|
||||
ai.request.assert_called_once_with(expected_question,
|
||||
chat,
|
||||
self.args.num_answers,
|
||||
self.args.output_tags)
|
||||
|
||||
# check for the correct ChatDB calls:
|
||||
# - initial question has been written (prior to the actual request)
|
||||
# - responses have been written (after the request)
|
||||
chat.cache_write.assert_has_calls([call([expected_question]),
|
||||
call(expected_responses)],
|
||||
any_order=False)
|
||||
|
||||
# check that the messages have not been added to the internal message list
|
||||
chat.cache_add.assert_not_called()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user