""" Implements the OpenAI client classes and functions. """ import openai import tiktoken from typing import Optional, Union, Generator from ..tags import Tag from ..message import Message, Answer from ..chat import Chat from ..ai import AI, AIResponse, Tokens from ..configuration import OpenAIConfig ChatType = list[dict[str, str]] class OpenAIAnswer: def __init__(self, idx: int, streams: dict[int, 'OpenAIAnswer'], response: openai.ChatCompletion, tokens: Tokens, encoding: tiktoken.core.Encoding) -> None: self.idx = idx self.streams = streams self.response = response self.position: int = 0 self.encoding = encoding self.data: list[str] = [] self.finished: bool = False self.tokens = tokens def stream(self) -> Generator[str, None, None]: while True: if not self.next(): continue if len(self.data) <= self.position: break yield self.data[self.position] self.position += 1 def next(self) -> bool: if self.finished: return True try: chunk = next(self.response) except StopIteration: self.finished = True if not self.finished: found_choice = False for choice in chunk.choices: if not choice.finish_reason: self.streams[choice.index].data.append(choice.delta.content) self.tokens.completion += len(self.encoding.encode(choice.delta.content)) self.tokens.total = self.tokens.prompt + self.tokens.completion if choice.index == self.idx: found_choice = True if not found_choice: return False return True class OpenAI(AI): """ The OpenAI AI client. """ def __init__(self, config: OpenAIConfig) -> None: self.ID = config.ID self.name = config.name self.config = config self.client = openai.OpenAI(api_key=self.config.api_key) def _completions(self, *args, **kw): # type: ignore return self.client.chat.completions.create(*args, **kw) def request(self, question: Message, chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ Make an AI request, asking the given question with the given chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ self.encoding = tiktoken.encoding_for_model(self.config.model) oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question) tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens) response = self._completions( model=self.config.model, messages=oai_chat, temperature=self.config.temperature, max_tokens=self.config.max_tokens, top_p=self.config.top_p, n=num_answers, stream=True, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) streams: dict[int, OpenAIAnswer] = {} for n in range(num_answers): streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding) question.answer = Answer(streams[0].stream()) question.tags = set(otags) if otags is not None else None question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] for idx in range(1, num_answers): answers.append(Message(question=question.question, answer=Answer(streams[idx].stream()), tags=otags, ai=self.ID, model=self.config.model)) return AIResponse(answers, tokens) def models(self) -> list[str]: """ Return all models supported by this AI. """ ret = [] for engine in sorted(self.client.models.list().data, key=lambda x: x.id): ret.append(engine.id) ret.sort() return ret def print_models(self) -> None: """ Print all models supported by the current AI. """ for model in self.models(): print(model) def openai_chat(self, chat: Chat, system: str, question: Optional[Message] = None) -> tuple[ChatType, int]: """ Create a chat history with system message in OpenAI format. Optionally append a new question. """ oai_chat: ChatType = [] prompt_tokens: int = 0 def append(role: str, content: str) -> int: oai_chat.append({'role': role, 'content': content.replace("''", "'")}) return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']]))) prompt_tokens += append('system', system) for message in chat.messages: if message.answer: prompt_tokens += append('user', message.question) prompt_tokens += append('assistant', str(message.answer)) if question: prompt_tokens += append('user', question.question) return oai_chat, prompt_tokens def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError def print(self) -> None: print(f"MODEL: {self.config.model}") print("=== SYSTEM ===") print(self.config.system)