Compare commits
3 Commits
820d4aed21
...
cad8d7ac71
| Author | SHA1 | Date | |
|---|---|---|---|
| cad8d7ac71 | |||
| fd69c07eb4 | |||
| dc0454a72c |
@ -71,7 +71,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
||||
full_question = '\n\n'.join(question_parts)
|
||||
|
||||
message = Message(question=Question(full_question),
|
||||
tags=args.output_tags, # FIXME
|
||||
tags=args.output_tags,
|
||||
ai=args.AI,
|
||||
model=args.model)
|
||||
# only write the new message to the cache,
|
||||
@ -92,8 +92,8 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
||||
print(message.to_str())
|
||||
response: AIResponse = ai.request(message,
|
||||
chat,
|
||||
args.num_answers, # FIXME
|
||||
args.output_tags) # FIXME
|
||||
args.num_answers,
|
||||
args.output_tags)
|
||||
# only write the response messages to the cache,
|
||||
# don't add them to the internal list
|
||||
chat.cache_write(response.messages)
|
||||
|
||||
@ -439,15 +439,12 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
||||
"""
|
||||
Repeat a single question after an error.
|
||||
"""
|
||||
# 1. ask a question
|
||||
# 1. ask a question and provoke an error
|
||||
mock_create_ai.return_value = self.ai
|
||||
expected_question = self.input_message(self.args)
|
||||
self.ai.request.side_effect = AIError
|
||||
|
||||
# execute the command
|
||||
with self.assertRaises(AIError):
|
||||
question_cmd(self.args, self.config)
|
||||
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
@ -474,3 +471,43 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
||||
# also check that the file ID has not been changed
|
||||
assert cached_msg[0].file_path
|
||||
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
||||
|
||||
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
|
||||
"""
|
||||
Repeat a single question with new arguments.
|
||||
"""
|
||||
# 1. ask a question
|
||||
mock_create_ai.return_value = self.ai
|
||||
expected_question = self.input_message(self.args)
|
||||
expected_responses = self.mock_request(expected_question,
|
||||
Chat([]),
|
||||
self.args.num_answers,
|
||||
self.args.output_tags).messages
|
||||
question_cmd(self.args, self.config)
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||
self.assert_messages_equal(cached_msg, expected_responses)
|
||||
|
||||
# 2. repeat the last question with new arguments (without overwriting)
|
||||
# -> expect two messages with identical question and answer, but different metadata
|
||||
self.args.ask = None
|
||||
self.args.repeat = []
|
||||
self.args.overwrite = False
|
||||
self.args.output_tags = ['newtag']
|
||||
self.args.AI = 'newai'
|
||||
self.args.model = 'newmodel'
|
||||
new_expected_question = Message(question=Question(expected_question.question),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model)
|
||||
expected_responses += self.mock_request(new_expected_question,
|
||||
Chat([]),
|
||||
self.args.num_answers,
|
||||
set(self.args.output_tags)).messages
|
||||
question_cmd(self.args, self.config)
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||
self.assert_messages_equal(cached_msg, expected_responses)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user