Activate and use OpenAI streaming API.
This commit is contained in:
parent
8f95a362d2
commit
70164a1d45
@ -2,7 +2,8 @@
|
|||||||
Implements the OpenAI client classes and functions.
|
Implements the OpenAI client classes and functions.
|
||||||
"""
|
"""
|
||||||
import openai
|
import openai
|
||||||
from typing import Optional, Union
|
import tiktoken
|
||||||
|
from typing import Optional, Union, Generator
|
||||||
from ..tags import Tag
|
from ..tags import Tag
|
||||||
from ..message import Message, Answer
|
from ..message import Message, Answer
|
||||||
from ..chat import Chat
|
from ..chat import Chat
|
||||||
@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig
|
|||||||
ChatType = list[dict[str, str]]
|
ChatType = list[dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAnswer:
|
||||||
|
def __init__(self,
|
||||||
|
idx: int,
|
||||||
|
streams: dict[int, 'OpenAIAnswer'],
|
||||||
|
response: openai.ChatCompletion,
|
||||||
|
tokens: Tokens,
|
||||||
|
encoding: tiktoken.core.Encoding) -> None:
|
||||||
|
self.idx = idx
|
||||||
|
self.streams = streams
|
||||||
|
self.response = response
|
||||||
|
self.position: int = 0
|
||||||
|
self.encoding = encoding
|
||||||
|
self.data: list[str] = []
|
||||||
|
self.finished: bool = False
|
||||||
|
self.tokens = tokens
|
||||||
|
|
||||||
|
def stream(self) -> Generator[str, None, None]:
|
||||||
|
while True:
|
||||||
|
if not self.next():
|
||||||
|
continue
|
||||||
|
if len(self.data) <= self.position:
|
||||||
|
break
|
||||||
|
yield self.data[self.position]
|
||||||
|
self.position += 1
|
||||||
|
|
||||||
|
def next(self) -> bool:
|
||||||
|
if self.finished:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
chunk = next(self.response)
|
||||||
|
except StopIteration:
|
||||||
|
self.finished = True
|
||||||
|
if not self.finished:
|
||||||
|
found_choice = False
|
||||||
|
for choice in chunk['choices']:
|
||||||
|
if not choice['finish_reason']:
|
||||||
|
self.streams[choice['index']].data.append(choice['delta']['content'])
|
||||||
|
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
|
||||||
|
self.tokens.total = self.tokens.prompt + self.tokens.completion
|
||||||
|
if choice['index'] == self.idx:
|
||||||
|
found_choice = True
|
||||||
|
if not found_choice:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class OpenAI(AI):
|
class OpenAI(AI):
|
||||||
"""
|
"""
|
||||||
The OpenAI AI client.
|
The OpenAI AI client.
|
||||||
@ -21,7 +68,6 @@ class OpenAI(AI):
|
|||||||
self.ID = config.ID
|
self.ID = config.ID
|
||||||
self.name = config.name
|
self.name = config.name
|
||||||
self.config = config
|
self.config = config
|
||||||
openai.api_key = config.api_key
|
|
||||||
|
|
||||||
def request(self,
|
def request(self,
|
||||||
question: Message,
|
question: Message,
|
||||||
@ -33,7 +79,10 @@ class OpenAI(AI):
|
|||||||
chat history. The nr. of requested answers corresponds to the
|
chat history. The nr. of requested answers corresponds to the
|
||||||
nr. of messages in the 'AIResponse'.
|
nr. of messages in the 'AIResponse'.
|
||||||
"""
|
"""
|
||||||
oai_chat = self.openai_chat(chat, self.config.system, question)
|
self.encoding = tiktoken.encoding_for_model(self.config.model)
|
||||||
|
openai.api_key = self.config.api_key
|
||||||
|
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
|
||||||
|
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model=self.config.model,
|
model=self.config.model,
|
||||||
messages=oai_chat,
|
messages=oai_chat,
|
||||||
@ -41,22 +90,24 @@ class OpenAI(AI):
|
|||||||
max_tokens=self.config.max_tokens,
|
max_tokens=self.config.max_tokens,
|
||||||
top_p=self.config.top_p,
|
top_p=self.config.top_p,
|
||||||
n=num_answers,
|
n=num_answers,
|
||||||
|
stream=True,
|
||||||
frequency_penalty=self.config.frequency_penalty,
|
frequency_penalty=self.config.frequency_penalty,
|
||||||
presence_penalty=self.config.presence_penalty)
|
presence_penalty=self.config.presence_penalty)
|
||||||
question.answer = Answer(response['choices'][0]['message']['content'])
|
streams: dict[int, OpenAIAnswer] = {}
|
||||||
|
for n in range(num_answers):
|
||||||
|
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
|
||||||
|
question.answer = Answer(streams[0].stream())
|
||||||
question.tags = set(otags) if otags is not None else None
|
question.tags = set(otags) if otags is not None else None
|
||||||
question.ai = self.ID
|
question.ai = self.ID
|
||||||
question.model = self.config.model
|
question.model = self.config.model
|
||||||
answers: list[Message] = [question]
|
answers: list[Message] = [question]
|
||||||
for choice in response['choices'][1:]: # type: ignore
|
for idx in range(1, num_answers):
|
||||||
answers.append(Message(question=question.question,
|
answers.append(Message(question=question.question,
|
||||||
answer=Answer(choice['message']['content']),
|
answer=Answer(streams[idx].stream()),
|
||||||
tags=otags,
|
tags=otags,
|
||||||
ai=self.ID,
|
ai=self.ID,
|
||||||
model=self.config.model))
|
model=self.config.model))
|
||||||
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
|
return AIResponse(answers, tokens)
|
||||||
response['usage']['completion_tokens'],
|
|
||||||
response['usage']['total_tokens']))
|
|
||||||
|
|
||||||
def models(self) -> list[str]:
|
def models(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@ -83,24 +134,26 @@ class OpenAI(AI):
|
|||||||
print('\nNot ready: ' + ', '.join(not_ready))
|
print('\nNot ready: ' + ', '.join(not_ready))
|
||||||
|
|
||||||
def openai_chat(self, chat: Chat, system: str,
|
def openai_chat(self, chat: Chat, system: str,
|
||||||
question: Optional[Message] = None) -> ChatType:
|
question: Optional[Message] = None) -> tuple[ChatType, int]:
|
||||||
"""
|
"""
|
||||||
Create a chat history with system message in OpenAI format.
|
Create a chat history with system message in OpenAI format.
|
||||||
Optionally append a new question.
|
Optionally append a new question.
|
||||||
"""
|
"""
|
||||||
oai_chat: ChatType = []
|
oai_chat: ChatType = []
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
|
||||||
def append(role: str, content: str) -> None:
|
def append(role: str, content: str) -> int:
|
||||||
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||||
|
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
|
||||||
|
|
||||||
append('system', system)
|
prompt_tokens += append('system', system)
|
||||||
for message in chat.messages:
|
for message in chat.messages:
|
||||||
if message.answer:
|
if message.answer:
|
||||||
append('user', message.question)
|
prompt_tokens += append('user', message.question)
|
||||||
append('assistant', message.answer)
|
prompt_tokens += append('assistant', message.answer)
|
||||||
if question:
|
if question:
|
||||||
append('user', question.question)
|
prompt_tokens += append('user', question.question)
|
||||||
return oai_chat
|
return oai_chat, prompt_tokens
|
||||||
|
|
||||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -129,13 +129,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
|||||||
args.output_tags)
|
args.output_tags)
|
||||||
# only write the response messages to the cache,
|
# only write the response messages to the cache,
|
||||||
# don't add them to the internal list
|
# don't add them to the internal list
|
||||||
chat.cache_write(response.messages)
|
|
||||||
for idx, msg in enumerate(response.messages):
|
for idx, msg in enumerate(response.messages):
|
||||||
print(f"=== ANSWER {idx+1} ===")
|
print(f"=== ANSWER {idx+1} ===", flush=True)
|
||||||
print(msg.answer)
|
if msg.answer:
|
||||||
|
for piece in msg.answer:
|
||||||
|
print(piece, end='', flush=True)
|
||||||
|
print()
|
||||||
if response.tokens:
|
if response.tokens:
|
||||||
print("===============")
|
print("===============")
|
||||||
print(response.tokens)
|
print(response.tokens)
|
||||||
|
chat.cache_write(response.messages)
|
||||||
|
|
||||||
|
|
||||||
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
|
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
|
||||||
|
|||||||
@ -2,3 +2,4 @@ openai
|
|||||||
PyYAML
|
PyYAML
|
||||||
argcomplete
|
argcomplete
|
||||||
pytest
|
pytest
|
||||||
|
tiktoken
|
||||||
|
|||||||
@ -16,26 +16,37 @@ class OpenAITest(unittest.TestCase):
|
|||||||
openai = OpenAI(config)
|
openai = OpenAI(config)
|
||||||
|
|
||||||
# Set up the mock response from openai.ChatCompletion.create
|
# Set up the mock response from openai.ChatCompletion.create
|
||||||
mock_response = {
|
mock_chunk1 = {
|
||||||
'choices': [
|
'choices': [
|
||||||
{
|
{
|
||||||
'message': {
|
'index': 0,
|
||||||
|
'delta': {
|
||||||
'content': 'Answer 1'
|
'content': 'Answer 1'
|
||||||
}
|
},
|
||||||
|
'finish_reason': None
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'message': {
|
'index': 1,
|
||||||
|
'delta': {
|
||||||
'content': 'Answer 2'
|
'content': 'Answer 2'
|
||||||
}
|
},
|
||||||
|
'finish_reason': None
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'usage': {
|
|
||||||
'prompt_tokens': 10,
|
|
||||||
'completion_tokens': 20,
|
|
||||||
'total_tokens': 30
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
mock_create.return_value = mock_response
|
mock_chunk2 = {
|
||||||
|
'choices': [
|
||||||
|
{
|
||||||
|
'index': 0,
|
||||||
|
'finish_reason': 'stop'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'index': 1,
|
||||||
|
'finish_reason': 'stop'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
|
||||||
|
|
||||||
# Create test data
|
# Create test data
|
||||||
question = Message(Question('Question'))
|
question = Message(Question('Question'))
|
||||||
@ -57,9 +68,9 @@ class OpenAITest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(response.tokens)
|
self.assertIsNotNone(response.tokens)
|
||||||
self.assertIsInstance(response.tokens, Tokens)
|
self.assertIsInstance(response.tokens, Tokens)
|
||||||
assert response.tokens
|
assert response.tokens
|
||||||
self.assertEqual(response.tokens.prompt, 10)
|
self.assertEqual(response.tokens.prompt, 53)
|
||||||
self.assertEqual(response.tokens.completion, 20)
|
self.assertEqual(response.tokens.completion, 6)
|
||||||
self.assertEqual(response.tokens.total, 30)
|
self.assertEqual(response.tokens.total, 59)
|
||||||
|
|
||||||
# Assert the mock call to openai.ChatCompletion.create
|
# Assert the mock call to openai.ChatCompletion.create
|
||||||
mock_create.assert_called_once_with(
|
mock_create.assert_called_once_with(
|
||||||
@ -76,6 +87,7 @@ class OpenAITest(unittest.TestCase):
|
|||||||
max_tokens=config.max_tokens,
|
max_tokens=config.max_tokens,
|
||||||
top_p=config.top_p,
|
top_p=config.top_p,
|
||||||
n=2,
|
n=2,
|
||||||
|
stream=True,
|
||||||
frequency_penalty=config.frequency_penalty,
|
frequency_penalty=config.frequency_penalty,
|
||||||
presence_penalty=config.presence_penalty
|
presence_penalty=config.presence_penalty
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user