openai: more pseudo-implementation (brakes old tests)

This commit is contained in:
juk0de 2023-09-03 08:54:58 +02:00
parent f53ebdf372
commit 47cc1b2101
2 changed files with 60 additions and 15 deletions

View File

@ -2,9 +2,14 @@
Implements the OpenAI client classes and functions. Implements the OpenAI client classes and functions.
""" """
import openai import openai
from ..message import Message from typing import Optional
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat from ..chat import Chat
from ..ai import AI, AIResponse from ..ai import AI, AIResponse, Tokens
from ..config import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI): class OpenAI(AI):
@ -12,21 +17,38 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
config: OpenAIConfig
def request(self, def request(self,
question: Message, question: Message,
context: Chat, chat: Chat,
num_answers: int = 1) -> AIResponse: num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
""" """
Make an AI request, asking the given question with the given Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers chat history. The nr. of requested answers corresponds to the
corresponds to the nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
# TODO: oai_chat = self.openai_chat(chat, self.config.system, question)
# * transform given message and chat context into OpenAI format response = openai.ChatCompletion.create(
# * make request model=self.config.model,
# * create a new Message for each answer and return them messages=oai_chat,
# (writing Messages is done by the calles) temperature=self.config.temperature,
raise NotImplementedError max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
answers: list[Message] = []
for choice in response['choices']: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion'],
response['usage']['total']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@ -46,3 +68,23 @@ class OpenAI(AI):
not_ready.append(engine['id']) not_ready.append(engine['id'])
if len(not_ready) > 0: if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat

View File

@ -26,6 +26,7 @@ class OpenAIConfig(AIConfig):
top_p: float top_p: float
frequency_penalty: float frequency_penalty: float
presence_penalty: float presence_penalty: float
system: str
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@ -40,7 +41,8 @@ class OpenAIConfig(AIConfig):
temperature=float(source['temperature']), temperature=float(source['temperature']),
top_p=float(source['top_p']), top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
) )
@ -56,12 +58,13 @@ class Config:
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
""" """
Create OpenAIConfig from a dict. Create Config from a dict.
""" """
return cls( return cls(
system=str(source['system']), system=str(source['system']),
db=str(source['db']), db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai']) # FIXME: move the 'system' parameter into the OpenAI section
openai=OpenAIConfig.from_dict(source['openai'].update({'system': source['system']}))
) )
@classmethod @classmethod