""" Implements the OpenAI client classes and functions. """

import openai

from typing import Optional, Union

from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig

# OpenAI chat-completion wire format: a list of {'role': ..., 'content': ...} dicts.
ChatType = list[dict[str, str]]


class OpenAI(AI):
    """ The OpenAI AI client. """

    def __init__(self, config: OpenAIConfig) -> None:
        """ Initialize the client from an 'OpenAIConfig'.

        NOTE(review): this sets the module-level 'openai.api_key', which is
        shared process-wide — the most recently constructed instance wins.
        """
        self.ID = config.ID
        self.name = config.name
        self.config = config
        openai.api_key = config.api_key

    def request(self, question: Message, chat: Chat, num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """ Make an AI request, asking the given question with the given
        chat history. The nr. of requested answers corresponds to the nr.
        of messages in the 'AIResponse'.

        The passed-in 'question' Message is mutated in place: its answer,
        tags, ai and model fields are filled from the first choice. Extra
        choices (when num_answers > 1) become new Message objects.
        """
        oai_chat = self.openai_chat(chat, self.config.system, question)
        response = openai.ChatCompletion.create(
            model=self.config.model,
            messages=oai_chat,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            n=num_answers,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty)

        # First choice is written back onto the caller's Message.
        question.answer = Answer(response['choices'][0]['message']['content'])
        question.tags = otags
        question.ai = self.ID
        question.model = self.config.model
        answers: list[Message] = [question]

        # Remaining choices (if any) become sibling Messages sharing the
        # same question text, tags, AI id and model.
        for choice in response['choices'][1:]:  # type: ignore
            answers.append(Message(question=question.question,
                                   answer=Answer(choice['message']['content']),
                                   tags=otags,
                                   ai=self.ID,
                                   model=self.config.model))

        return AIResponse(answers,
                          Tokens(response['usage']['prompt_tokens'],
                                 response['usage']['completion_tokens'],
                                 response['usage']['total_tokens']))

    def models(self) -> list[str]:
        """ Return all models supported by this AI. """
        raise NotImplementedError

    def print_models(self) -> None:
        """ Print all models supported by the current AI.

        Ready engines are printed one per line in id order; engines that
        are not ready are collected and listed at the end.
        """
        not_ready = []
        for engine in sorted(openai.Engine.list()['data'],
                             key=lambda x: x['id']):
            if engine['ready']:
                print(engine['id'])
            else:
                not_ready.append(engine['id'])
        if not_ready:
            print('\nNot ready: ' + ', '.join(not_ready))

    def openai_chat(self, chat: Chat, system: str,
                    question: Optional[Message] = None) -> ChatType:
        """ Create a chat history with system message in OpenAI format.
        Optionally append a new question.

        Only answered messages from the history are included; an
        unanswered message contributes nothing. The optional 'question'
        is appended last as a 'user' turn.
        """
        oai_chat: ChatType = []

        def append(role: str, content: str) -> None:
            # Collapse doubled single-quotes back to one — presumably
            # undoing an escaping applied upstream; TODO confirm where
            # the "''" sequences originate.
            oai_chat.append({'role': role,
                             'content': content.replace("''", "'")})

        append('system', system)
        for message in chat.messages:
            if message.answer:
                append('user', message.question)
                # NOTE(review): message.answer is passed where a str is
                # expected ('.replace' is called on it) — assumes Answer
                # is str-like; verify against the Answer type.
                append('assistant', message.answer)
        if question:
            append('user', question.question)
        return oai_chat

    def tokens(self, data: Union[Message, Chat]) -> int:
        """ Return the token count for the given message or chat. """
        raise NotImplementedError

    def print(self) -> None:
        """ Print the configured model and system prompt. """
        print(f"MODEL: {self.config.model}")
        print("=== SYSTEM ===")
        print(self.config.system)