From d2d9d9231f08bd3e02b298ebb40b568aa0c3dadb Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:52:07 +0200 Subject: [PATCH] question_cmd: fixed '--ask' command --- chatmastermind/ai.py | 6 ++++++ chatmastermind/ais/openai.py | 19 ++++++++++++++----- chatmastermind/commands/question.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index e94de8e..b97b5f1 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -66,3 +66,9 @@ class AI(Protocol): and is not implemented for all AIs. """ raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. + """ + pass diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 14ce33f..1db4d20 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -43,16 +43,20 @@ class OpenAI(AI): n=num_answers, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - answers: list[Message] = [] - for choice in response['choices']: # type: ignore + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.name + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, ai=self.name, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt'], - response['usage']['completion'], - response['usage']['total'])) + return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) def models(self) -> list[str]: """ @@ -95,3 +99,8 @@ class OpenAI(AI): def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 756a051..fdabd62 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # create the correct AI instance ai: AI = create_ai(args, config) if args.ask: + ai.print() + chat.print(paged=False) response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME args.output_tags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) elif args.repeat: lmessage = chat.latest_message() assert lmessage