97 lines
3.3 KiB
Python
97 lines
3.3 KiB
Python
"""
|
|
Implements the OpenAI client classes and functions.
|
|
"""
|
|
import openai
|
|
from typing import Optional, Union
|
|
from ..tags import Tag
|
|
from ..message import Message, Answer
|
|
from ..chat import Chat
|
|
from ..ai import AI, AIResponse, Tokens
|
|
from ..configuration import OpenAIConfig
|
|
|
|
# Chat history in OpenAI format: a list of dicts, each holding
# a 'role' and a 'content' string (see 'OpenAI.openai_chat()').
ChatType = list[dict[str, str]]
|
|
|
|
|
|
class OpenAI(AI):
    """
    The OpenAI AI client.

    Sends chat completion requests through the 'openai' package, using
    the model and sampling parameters of the given 'OpenAIConfig'.
    """

    def __init__(self, config: OpenAIConfig) -> None:
        """
        Store the configuration and copy its 'ai_type' and 'name'
        onto the instance.
        """
        self.ai_type = config.ai_type
        self.name = config.name
        self.config = config

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Make an AI request, asking the given question with the given
        chat history. The nr. of requested answers corresponds to the
        nr. of messages in the 'AIResponse'.

        Each returned message repeats the text of the question, carries
        one of the answers, and is tagged with 'otags', the name of this
        AI and the configured model. The token counts (prompt, completion,
        total) are taken from the 'usage' section of the response.
        """
        oai_chat = self.openai_chat(chat, self.config.system, question)
        response = openai.ChatCompletion.create(
            model=self.config.model,
            messages=oai_chat,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            n=num_answers,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty)
        # One message per returned choice.
        answers: list[Message] = [
            Message(question=question.question,
                    answer=Answer(choice['message']['content']),
                    tags=otags,
                    ai=self.name,
                    model=self.config.model)
            for choice in response['choices']  # type: ignore
        ]
        # The usage object names its fields '*_tokens'; the short names
        # 'prompt' / 'completion' / 'total' do not exist and would raise
        # a KeyError.
        usage = response['usage']
        return AIResponse(answers, Tokens(usage['prompt_tokens'],
                                          usage['completion_tokens'],
                                          usage['total_tokens']))

    def models(self) -> list[str]:
        """
        Return all models supported by this AI.

        Not implemented yet: always raises 'NotImplementedError'.
        """
        raise NotImplementedError

    def print_models(self) -> None:
        """
        Print all models supported by the current AI.

        The engines are listed sorted by ID. The IDs of ready engines are
        printed one per line; the IDs of all other engines are printed
        afterwards on a single 'Not ready:' line (if there are any).
        """
        not_ready = []
        for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
            if engine['ready']:
                print(engine['id'])
            else:
                not_ready.append(engine['id'])
        if not_ready:
            print('\nNot ready: ' + ', '.join(not_ready))

    def openai_chat(self, chat: Chat, system: str,
                    question: Optional[Message] = None) -> ChatType:
        """
        Create a chat history with system message in OpenAI format.
        Optionally append a new question.

        The result starts with the system message, followed by a
        'user' / 'assistant' pair for every message of the chat that has
        an answer (messages without an answer are skipped entirely) and,
        if given, the new question as the last 'user' entry.
        """
        oai_chat: ChatType = []

        def append(role: str, content: str) -> None:
            # Doubled single quotes are collapsed into one before sending.
            # NOTE(review): presumably this undoes quote escaping applied
            # when messages are stored -- confirm.
            oai_chat.append({'role': role, 'content': content.replace("''", "'")})

        append('system', system)
        for message in chat.messages:
            if message.answer:
                append('user', message.question)
                append('assistant', message.answer)
        if question:
            append('user', question.question)
        return oai_chat

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Not implemented yet: always raises 'NotImplementedError'.
        """
        raise NotImplementedError
|