openai: more pseudo-implementation (brakes old tests)
This commit is contained in:
parent
f53ebdf372
commit
47cc1b2101
@ -2,9 +2,14 @@
|
|||||||
Implements the OpenAI client classes and functions.
|
Implements the OpenAI client classes and functions.
|
||||||
"""
|
"""
|
||||||
import openai
|
import openai
|
||||||
from ..message import Message
|
from typing import Optional
|
||||||
|
from ..tags import Tag
|
||||||
|
from ..message import Message, Answer
|
||||||
from ..chat import Chat
|
from ..chat import Chat
|
||||||
from ..ai import AI, AIResponse
|
from ..ai import AI, AIResponse, Tokens
|
||||||
|
from ..config import OpenAIConfig
|
||||||
|
|
||||||
|
ChatType = list[dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(AI):
|
class OpenAI(AI):
|
||||||
@ -12,21 +17,38 @@ class OpenAI(AI):
|
|||||||
The OpenAI AI client.
|
The OpenAI AI client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config: OpenAIConfig
|
||||||
|
|
||||||
def request(self,
|
def request(self,
|
||||||
question: Message,
|
question: Message,
|
||||||
context: Chat,
|
chat: Chat,
|
||||||
num_answers: int = 1) -> AIResponse:
|
num_answers: int = 1,
|
||||||
|
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||||
"""
|
"""
|
||||||
Make an AI request, asking the given question with the given
|
Make an AI request, asking the given question with the given
|
||||||
context (i. e. chat history). The nr. of requested answers
|
chat history. The nr. of requested answers corresponds to the
|
||||||
corresponds to the nr. of messages in the 'AIResponse'.
|
nr. of messages in the 'AIResponse'.
|
||||||
"""
|
"""
|
||||||
# TODO:
|
oai_chat = self.openai_chat(chat, self.config.system, question)
|
||||||
# * transform given message and chat context into OpenAI format
|
response = openai.ChatCompletion.create(
|
||||||
# * make request
|
model=self.config.model,
|
||||||
# * create a new Message for each answer and return them
|
messages=oai_chat,
|
||||||
# (writing Messages is done by the calles)
|
temperature=self.config.temperature,
|
||||||
raise NotImplementedError
|
max_tokens=self.config.max_tokens,
|
||||||
|
top_p=self.config.top_p,
|
||||||
|
n=num_answers,
|
||||||
|
frequency_penalty=self.config.frequency_penalty,
|
||||||
|
presence_penalty=self.config.presence_penalty)
|
||||||
|
answers: list[Message] = []
|
||||||
|
for choice in response['choices']: # type: ignore
|
||||||
|
answers.append(Message(question=question.question,
|
||||||
|
answer=Answer(choice['message']['content']),
|
||||||
|
tags=otags,
|
||||||
|
ai=self.name,
|
||||||
|
model=self.config.model))
|
||||||
|
return AIResponse(answers, Tokens(response['usage']['prompt'],
|
||||||
|
response['usage']['completion'],
|
||||||
|
response['usage']['total']))
|
||||||
|
|
||||||
def models(self) -> list[str]:
|
def models(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@ -46,3 +68,23 @@ class OpenAI(AI):
|
|||||||
not_ready.append(engine['id'])
|
not_ready.append(engine['id'])
|
||||||
if len(not_ready) > 0:
|
if len(not_ready) > 0:
|
||||||
print('\nNot ready: ' + ', '.join(not_ready))
|
print('\nNot ready: ' + ', '.join(not_ready))
|
||||||
|
|
||||||
|
def openai_chat(self, chat: Chat, system: str,
|
||||||
|
question: Optional[Message] = None) -> ChatType:
|
||||||
|
"""
|
||||||
|
Create a chat history with system message in OpenAI format.
|
||||||
|
Optionally append a new question.
|
||||||
|
"""
|
||||||
|
oai_chat: ChatType = []
|
||||||
|
|
||||||
|
def append(role: str, content: str) -> None:
|
||||||
|
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||||
|
|
||||||
|
append('system', system)
|
||||||
|
for message in chat.messages:
|
||||||
|
if message.answer:
|
||||||
|
append('user', message.question)
|
||||||
|
append('assistant', message.answer)
|
||||||
|
if question:
|
||||||
|
append('user', question.question)
|
||||||
|
return oai_chat
|
||||||
|
|||||||
@ -26,6 +26,7 @@ class OpenAIConfig(AIConfig):
|
|||||||
top_p: float
|
top_p: float
|
||||||
frequency_penalty: float
|
frequency_penalty: float
|
||||||
presence_penalty: float
|
presence_penalty: float
|
||||||
|
system: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
||||||
@ -40,7 +41,8 @@ class OpenAIConfig(AIConfig):
|
|||||||
temperature=float(source['temperature']),
|
temperature=float(source['temperature']),
|
||||||
top_p=float(source['top_p']),
|
top_p=float(source['top_p']),
|
||||||
frequency_penalty=float(source['frequency_penalty']),
|
frequency_penalty=float(source['frequency_penalty']),
|
||||||
presence_penalty=float(source['presence_penalty'])
|
presence_penalty=float(source['presence_penalty']),
|
||||||
|
system=str(source['system'])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -56,12 +58,13 @@ class Config:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
||||||
"""
|
"""
|
||||||
Create OpenAIConfig from a dict.
|
Create Config from a dict.
|
||||||
"""
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
system=str(source['system']),
|
system=str(source['system']),
|
||||||
db=str(source['db']),
|
db=str(source['db']),
|
||||||
openai=OpenAIConfig.from_dict(source['openai'])
|
# FIXME: move the 'system' parameter into the OpenAI section
|
||||||
|
openai=OpenAIConfig.from_dict(source['openai'].update({'system': source['system']}))
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user