Compare commits

..

3 Commits

2 changed files with 7 additions and 44 deletions

View File

@ -71,7 +71,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags,
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
# only write the new message to the cache,
@ -92,8 +92,8 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
args.num_answers, # FIXME
args.output_tags) # FIXME
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)

View File

@ -439,12 +439,15 @@ class TestQuestionCmd(TestQuestionCmdBase):
"""
Repeat a single question after an error.
"""
# 1. ask a question and provoke an error
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
@ -471,43 +474,3 @@ class TestQuestionCmd(TestQuestionCmdBase):
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question with new arguments.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_question = Message(question=Question(expected_question.question),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model)
expected_responses += self.mock_request(new_expected_question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)