ChatMasterMind Application Refactor and Enhancement #8

Merged
juk0de merged 122 commits from restructurings into main 2023-09-12 07:36:07 +02:00
3 changed files with 30 additions and 10 deletions
Showing only changes of commit cf50818f28 - Show all commits

View File

@ -66,3 +66,9 @@ class AI(Protocol):
and is not implemented for all AIs.
"""
raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass

View File

@ -43,16 +43,20 @@ class OpenAI(AI):
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
answers: list[Message] = []
for choice in response['choices']: # type: ignore
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.ai = self.name
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion'],
response['usage']['total']))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]:
"""
@ -95,3 +99,8 @@ class OpenAI(AI):
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)

View File

@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance
ai: AI = create_ai(args, config)
if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
assert response
# TODO:
# * add answer to the message above (and create
# more messages for any additional answers)
pass
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat:
lmessage = chat.latest_message()
assert lmessage