openai: more pseudo-implementation (brakes old tests)

This commit is contained in:
juk0de 2023-09-03 08:54:58 +02:00
parent f53ebdf372
commit 47cc1b2101
2 changed files with 60 additions and 15 deletions

View File

@ -2,9 +2,14 @@
Implements the OpenAI client classes and functions.
"""
import openai
from ..message import Message
from typing import Optional
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse
from ..ai import AI, AIResponse, Tokens
from ..config import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI):
@ -12,21 +17,38 @@ class OpenAI(AI):
The OpenAI AI client.
"""
config: OpenAIConfig
def request(self,
question: Message,
context: Chat,
num_answers: int = 1) -> AIResponse:
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
corresponds to the nr. of messages in the 'AIResponse'.
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
# TODO:
# * transform given message and chat context into OpenAI format
# * make request
# * create a new Message for each answer and return them
# (writing Messages is done by the calles)
raise NotImplementedError
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
answers: list[Message] = []
for choice in response['choices']: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion'],
response['usage']['total']))
def models(self) -> list[str]:
"""
@ -46,3 +68,23 @@ class OpenAI(AI):
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat

View File

@ -26,6 +26,7 @@ class OpenAIConfig(AIConfig):
top_p: float
frequency_penalty: float
presence_penalty: float
system: str
@classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@ -40,7 +41,8 @@ class OpenAIConfig(AIConfig):
temperature=float(source['temperature']),
top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty'])
presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
)
@ -56,12 +58,13 @@ class Config:
@classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
"""
Create OpenAIConfig from a dict.
Create Config from a dict.
"""
return cls(
system=str(source['system']),
db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai'])
# FIXME: move the 'system' parameter into the OpenAI section
openai=OpenAIConfig.from_dict(source['openai'].update({'system': source['system']}))
)
@classmethod