"""
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional, Union
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig

# An OpenAI-format chat history: a list of {'role': ..., 'content': ...} dicts.
ChatType = list[dict[str, str]]


class OpenAI(AI):
    """
    The OpenAI AI client.

    Wraps the `openai` package's ChatCompletion API, translating between
    the project's `Chat`/`Message` types and OpenAI's message-dict format.
    """

    def __init__(self, name: str, config: OpenAIConfig) -> None:
        # name: identifier recorded on every Message this client produces.
        # config: model name and sampling parameters for API requests.
        self.name = name
        self.config = config

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Make an AI request, asking the given question with the given
        chat history. The nr. of requested answers corresponds to the
        nr. of messages in the 'AIResponse'.

        :param question: the new question to append to the history.
        :param chat: prior conversation supplying context.
        :param num_answers: how many completions to request (API 'n').
        :param otags: tags to attach to each returned Message.
        :return: an AIResponse with one Message per completion choice
                 plus the token usage reported by the API.
        """
        # FIXME: use real 'system' message (store in OpenAIConfig)
        oai_chat = self.openai_chat(chat, "system", question)
        response = openai.ChatCompletion.create(
            model=self.config.model,
            messages=oai_chat,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            n=num_answers,
            frequency_penalty=self.config.frequency_penalty,
            presence_penalty=self.config.presence_penalty)
        answers: list[Message] = []
        for choice in response['choices']:  # type: ignore
            answers.append(Message(question=question.question,
                                   answer=Answer(choice['message']['content']),
                                   tags=otags,
                                   ai=self.name,
                                   model=self.config.model))
        # BUG FIX: the ChatCompletion 'usage' object uses the keys
        # 'prompt_tokens', 'completion_tokens' and 'total_tokens' —
        # the previous 'prompt'/'completion'/'total' keys raised KeyError.
        usage = response['usage']  # type: ignore
        return AIResponse(answers, Tokens(usage['prompt_tokens'],
                                          usage['completion_tokens'],
                                          usage['total_tokens']))

    def models(self) -> list[str]:
        """
        Return all models supported by this AI.
        """
        raise NotImplementedError

    def print_models(self) -> None:
        """
        Print all models supported by the current AI.

        Ready engines are printed one per line, sorted by id; engines
        that are not ready are listed on a trailing summary line.
        """
        not_ready = []
        for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
            if engine['ready']:
                print(engine['id'])
            else:
                not_ready.append(engine['id'])
        if len(not_ready) > 0:
            print('\nNot ready: ' + ', '.join(not_ready))

    def openai_chat(self, chat: Chat, system: str,
                    question: Optional[Message] = None) -> ChatType:
        """
        Create a chat history with system message in OpenAI format.
        Optionally append a new question.

        Only messages that already have an answer contribute a
        user/assistant pair; unanswered history entries are skipped.

        :param chat: the conversation to convert.
        :param system: content of the leading 'system' message.
        :param question: optional new question appended as a final
                         'user' message.
        :return: the OpenAI-format message list.
        """
        oai_chat: ChatType = []

        def append(role: str, content: str) -> None:
            # Collapse doubled single-quotes (un-escape stored text).
            oai_chat.append({'role': role, 'content': content.replace("''", "'")})

        append('system', system)
        for message in chat.messages:
            if message.answer:
                append('user', message.question)
                append('assistant', message.answer)
        if question:
            append('user', question.question)
        return oai_chat

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Return the token count for the given message or chat.

        Not implemented for this client.
        """
        raise NotImplementedError