"""
|
|
Implements the OpenAI client classes and functions.
|
|
"""
|
|
import openai
|
|
import tiktoken
|
|
from typing import Optional, Union, Generator
|
|
from ..tags import Tag
|
|
from ..message import Message, Answer
|
|
from ..chat import Chat
|
|
from ..ai import AI, AIResponse, Tokens
|
|
from ..configuration import OpenAIConfig
|
|
|
|
ChatType = list[dict[str, str]]
|
|
|
|
|
|
class OpenAIAnswer:
    """
    One answer (one choice index) of a streamed OpenAI chat completion.

    All answers of a single request share the same underlying 'response'
    iterator and the same 'streams' registry: whichever answer pulls the
    next chunk routes every choice delta in that chunk to the matching
    OpenAIAnswer via 'streams', so answers can be consumed in any order.
    """

    def __init__(self,
                 idx: int,
                 streams: dict[int, 'OpenAIAnswer'],
                 response: openai.ChatCompletion,
                 tokens: Tokens,
                 encoding: tiktoken.core.Encoding) -> None:
        self.idx = idx                  # choice index this answer represents
        self.streams = streams          # shared map: choice index -> answer
        self.response = response        # shared streaming chunk iterator
        self.position: int = 0          # next element of 'data' to yield
        self.encoding = encoding        # tokenizer for completion accounting
        self.data: list[str] = []       # text fragments received so far
        self.finished: bool = False     # True once 'response' is exhausted
        self.tokens = tokens            # shared usage counters (mutated here)

    def stream(self) -> Generator[str, None, None]:
        """Yield this answer's text fragments as they arrive."""
        while True:
            if not self.next():
                # The consumed chunk carried no choice for this index;
                # pull another chunk before checking for new data.
                continue
            if len(self.data) <= self.position:
                break
            yield self.data[self.position]
            self.position += 1

    def next(self) -> bool:
        """
        Consume one chunk from the shared response iterator and route its
        deltas to the owning answers. Return True if this answer's index
        appeared in the chunk (or the stream is finished), else False.
        """
        if self.finished:
            return True
        try:
            chunk = next(self.response)
        except StopIteration:
            self.finished = True
            return True
        found_choice = False
        for choice in chunk.choices:
            # BUG FIX: delta.content is None on role-only and final
            # chunks of the streaming API; the original appended None
            # to 'data' and crashed in encoding.encode(None).
            if not choice.finish_reason and choice.delta.content is not None:
                self.streams[choice.index].data.append(choice.delta.content)
                self.tokens.completion += len(self.encoding.encode(choice.delta.content))
                self.tokens.total = self.tokens.prompt + self.tokens.completion
            if choice.index == self.idx:
                found_choice = True
        return found_choice
|
|
|
|
|
|
class OpenAI(AI):
    """
    The OpenAI AI client.
    """

    def __init__(self, config: OpenAIConfig) -> None:
        self.ID = config.ID
        self.name = config.name
        self.config = config
        self.client = openai.OpenAI(api_key=self.config.api_key)

    def _completions(self, *args, **kw):  # type: ignore
        # Thin wrapper around the SDK call; kept separate so the network
        # boundary can be stubbed in tests.
        return self.client.chat.completions.create(*args, **kw)

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Make an AI request, asking the given question with the given
        chat history. The nr. of requested answers corresponds to the
        nr. of messages in the 'AIResponse'.

        The first answer is attached to 'question' itself; additional
        answers become fresh Message objects sharing the same question
        text. All answers stream lazily from one shared response.
        """
        self.encoding = tiktoken.encoding_for_model(self.config.model)
        oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
        # Completion count starts at zero and grows as chunks stream in.
        tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
        response = self._completions(
            model=self.config.model,
            messages=oai_chat,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            n=num_answers,
            stream=True,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty)
        # One OpenAIAnswer per choice index; they share 'streams' so any
        # of them can dispatch chunks to the right answer.
        streams: dict[int, OpenAIAnswer] = {}
        for n in range(num_answers):
            streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
        question.answer = Answer(streams[0].stream())
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.config.model
        answers: list[Message] = [question]
        for idx in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(streams[idx].stream()),
                                   tags=otags,
                                   ai=self.ID,
                                   model=self.config.model))
        return AIResponse(answers, tokens)

    def models(self) -> list[str]:
        """
        Return all models supported by this AI, sorted by model ID.
        """
        # FIX: the previous implementation sorted twice (sorted(..., key=...)
        # followed by ret.sort()); one sort of the IDs is sufficient.
        return sorted(model.id for model in self.client.models.list().data)

    def print_models(self) -> None:
        """
        Print all models supported by the current AI.
        """
        for model in self.models():
            print(model)

    def openai_chat(self, chat: Chat, system: str,
                    question: Optional[Message] = None) -> tuple[ChatType, int]:
        """
        Create a chat history with system message in OpenAI format.
        Optionally append a new question.

        Returns the message list and an estimated prompt token count.
        """
        # FIX: 'self.encoding' used to exist only after request() had run,
        # so calling openai_chat() first raised AttributeError.
        if not hasattr(self, 'encoding'):
            self.encoding = tiktoken.encoding_for_model(self.config.model)
        oai_chat: ChatType = []
        prompt_tokens: int = 0

        def append(role: str, content: str) -> int:
            # The doubled-quote unescaping mirrors how content is stored
            # upstream; the join below approximates the serialized prompt
            # for token counting.
            oai_chat.append({'role': role, 'content': content.replace("''", "'")})
            return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))

        prompt_tokens += append('system', system)
        for message in chat.messages:
            # Only fully-answered exchanges belong to the history.
            if message.answer:
                prompt_tokens += append('user', message.question)
                prompt_tokens += append('assistant', str(message.answer))
        if question:
            prompt_tokens += append('user', question.question)
        return oai_chat, prompt_tokens

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Return the token count of the given message or chat.
        Not implemented for this client.
        """
        raise NotImplementedError

    def print(self) -> None:
        """
        Print the configured model and system prompt.
        """
        print(f"MODEL: {self.config.model}")
        print("=== SYSTEM ===")
        print(self.config.system)