question_cmd: fixed '--ask' command
This commit is contained in:
parent
864ab7aeb1
commit
faac42d3c2
@ -66,3 +66,9 @@ class AI(Protocol):
|
|||||||
and is not implemented for all AIs.
|
and is not implemented for all AIs.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def print(self) -> None:
|
||||||
|
"""
|
||||||
|
Print some info about the current AI, like system message.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@ -43,16 +43,20 @@ class OpenAI(AI):
|
|||||||
n=num_answers,
|
n=num_answers,
|
||||||
frequency_penalty=self.config.frequency_penalty,
|
frequency_penalty=self.config.frequency_penalty,
|
||||||
presence_penalty=self.config.presence_penalty)
|
presence_penalty=self.config.presence_penalty)
|
||||||
answers: list[Message] = []
|
question.answer = Answer(response['choices'][0]['message']['content'])
|
||||||
for choice in response['choices']: # type: ignore
|
question.tags = otags
|
||||||
|
question.ai = self.name
|
||||||
|
question.model = self.config.model
|
||||||
|
answers: list[Message] = [question]
|
||||||
|
for choice in response['choices'][1:]: # type: ignore
|
||||||
answers.append(Message(question=question.question,
|
answers.append(Message(question=question.question,
|
||||||
answer=Answer(choice['message']['content']),
|
answer=Answer(choice['message']['content']),
|
||||||
tags=otags,
|
tags=otags,
|
||||||
ai=self.name,
|
ai=self.name,
|
||||||
model=self.config.model))
|
model=self.config.model))
|
||||||
return AIResponse(answers, Tokens(response['usage']['prompt'],
|
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
|
||||||
response['usage']['completion'],
|
response['usage']['completion_tokens'],
|
||||||
response['usage']['total']))
|
response['usage']['total_tokens']))
|
||||||
|
|
||||||
def models(self) -> list[str]:
|
def models(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@ -95,3 +99,8 @@ class OpenAI(AI):
|
|||||||
|
|
||||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def print(self) -> None:
|
||||||
|
print(f"MODEL: {self.config.model}")
|
||||||
|
print("=== SYSTEM ===")
|
||||||
|
print(self.config.system)
|
||||||
|
|||||||
@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
# create the correct AI instance
|
# create the correct AI instance
|
||||||
ai: AI = create_ai(args, config)
|
ai: AI = create_ai(args, config)
|
||||||
if args.ask:
|
if args.ask:
|
||||||
|
ai.print()
|
||||||
|
chat.print(paged=False)
|
||||||
response: AIResponse = ai.request(message,
|
response: AIResponse = ai.request(message,
|
||||||
chat,
|
chat,
|
||||||
args.num_answers, # FIXME
|
args.num_answers, # FIXME
|
||||||
args.output_tags) # FIXME
|
args.output_tags) # FIXME
|
||||||
assert response
|
chat.update_messages([response.messages[0]])
|
||||||
# TODO:
|
chat.add_to_cache(response.messages[1:])
|
||||||
# * add answer to the message above (and create
|
for idx, msg in enumerate(response.messages):
|
||||||
# more messages for any additional answers)
|
print(f"=== ANSWER {idx+1} ===")
|
||||||
pass
|
print(msg.answer)
|
||||||
|
if response.tokens:
|
||||||
|
print("===============")
|
||||||
|
print(response.tokens)
|
||||||
elif args.repeat:
|
elif args.repeat:
|
||||||
lmessage = chat.latest_message()
|
lmessage = chat.latest_message()
|
||||||
assert lmessage
|
assert lmessage
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user