Compare commits
3 Commits
820d4aed21
...
cad8d7ac71
| Author | SHA1 | Date | |
|---|---|---|---|
| cad8d7ac71 | |||
| fd69c07eb4 | |||
| dc0454a72c |
@ -71,7 +71,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
|||||||
full_question = '\n\n'.join(question_parts)
|
full_question = '\n\n'.join(question_parts)
|
||||||
|
|
||||||
message = Message(question=Question(full_question),
|
message = Message(question=Question(full_question),
|
||||||
tags=args.output_tags, # FIXME
|
tags=args.output_tags,
|
||||||
ai=args.AI,
|
ai=args.AI,
|
||||||
model=args.model)
|
model=args.model)
|
||||||
# only write the new message to the cache,
|
# only write the new message to the cache,
|
||||||
@ -92,8 +92,8 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
|||||||
print(message.to_str())
|
print(message.to_str())
|
||||||
response: AIResponse = ai.request(message,
|
response: AIResponse = ai.request(message,
|
||||||
chat,
|
chat,
|
||||||
args.num_answers, # FIXME
|
args.num_answers,
|
||||||
args.output_tags) # FIXME
|
args.output_tags)
|
||||||
# only write the response messages to the cache,
|
# only write the response messages to the cache,
|
||||||
# don't add them to the internal list
|
# don't add them to the internal list
|
||||||
chat.cache_write(response.messages)
|
chat.cache_write(response.messages)
|
||||||
|
|||||||
@ -439,15 +439,12 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
"""
|
"""
|
||||||
Repeat a single question after an error.
|
Repeat a single question after an error.
|
||||||
"""
|
"""
|
||||||
# 1. ask a question
|
# 1. ask a question and provoke an error
|
||||||
mock_create_ai.return_value = self.ai
|
mock_create_ai.return_value = self.ai
|
||||||
expected_question = self.input_message(self.args)
|
expected_question = self.input_message(self.args)
|
||||||
self.ai.request.side_effect = AIError
|
self.ai.request.side_effect = AIError
|
||||||
|
|
||||||
# execute the command
|
|
||||||
with self.assertRaises(AIError):
|
with self.assertRaises(AIError):
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
|
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
Path(self.db_dir.name))
|
Path(self.db_dir.name))
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
@ -474,3 +471,43 @@ class TestQuestionCmd(TestQuestionCmdBase):
|
|||||||
# also check that the file ID has not been changed
|
# also check that the file ID has not been changed
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
|
||||||
|
|
||||||
|
@mock.patch('chatmastermind.commands.question.create_ai')
|
||||||
|
def test_repeat_single_question_new_args(self, mock_create_ai: MagicMock) -> None:
|
||||||
|
"""
|
||||||
|
Repeat a single question with new arguments.
|
||||||
|
"""
|
||||||
|
# 1. ask a question
|
||||||
|
mock_create_ai.return_value = self.ai
|
||||||
|
expected_question = self.input_message(self.args)
|
||||||
|
expected_responses = self.mock_request(expected_question,
|
||||||
|
Chat([]),
|
||||||
|
self.args.num_answers,
|
||||||
|
self.args.output_tags).messages
|
||||||
|
question_cmd(self.args, self.config)
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
|
||||||
|
self.assert_messages_equal(cached_msg, expected_responses)
|
||||||
|
|
||||||
|
# 2. repeat the last question with new arguments (without overwriting)
|
||||||
|
# -> expect two messages with identical question and answer, but different metadata
|
||||||
|
self.args.ask = None
|
||||||
|
self.args.repeat = []
|
||||||
|
self.args.overwrite = False
|
||||||
|
self.args.output_tags = ['newtag']
|
||||||
|
self.args.AI = 'newai'
|
||||||
|
self.args.model = 'newmodel'
|
||||||
|
new_expected_question = Message(question=Question(expected_question.question),
|
||||||
|
tags=set(self.args.output_tags),
|
||||||
|
ai=self.args.AI,
|
||||||
|
model=self.args.model)
|
||||||
|
expected_responses += self.mock_request(new_expected_question,
|
||||||
|
Chat([]),
|
||||||
|
self.args.num_answers,
|
||||||
|
set(self.args.output_tags)).messages
|
||||||
|
question_cmd(self.args, self.config)
|
||||||
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
|
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||||
|
self.assert_messages_equal(cached_msg, expected_responses)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user