From 70164a1d451b8ed90b2289c92dfb0e8250202c91 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 21 Oct 2023 14:21:48 +0200 Subject: [PATCH] Activate and use OpenAI streaming API. --- chatmastermind/ais/openai.py | 85 +++++++++++++++++++++++------ chatmastermind/commands/question.py | 9 ++- requirements.txt | 1 + tests/test_ais_openai.py | 40 +++++++++----- 4 files changed, 102 insertions(+), 33 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index d7bb12f..a8ceb34 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -2,7 +2,8 @@ Implements the OpenAI client classes and functions. """ import openai -from typing import Optional, Union +import tiktoken +from typing import Optional, Union, Generator from ..tags import Tag from ..message import Message, Answer from ..chat import Chat @@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig ChatType = list[dict[str, str]] +class OpenAIAnswer: + def __init__(self, + idx: int, + streams: dict[int, 'OpenAIAnswer'], + response: openai.ChatCompletion, + tokens: Tokens, + encoding: tiktoken.core.Encoding) -> None: + self.idx = idx + self.streams = streams + self.response = response + self.position: int = 0 + self.encoding = encoding + self.data: list[str] = [] + self.finished: bool = False + self.tokens = tokens + + def stream(self) -> Generator[str, None, None]: + while True: + if not self.next(): + continue + if len(self.data) <= self.position: + break + yield self.data[self.position] + self.position += 1 + + def next(self) -> bool: + if self.finished: + return True + try: + chunk = next(self.response) + except StopIteration: + self.finished = True + if not self.finished: + found_choice = False + for choice in chunk['choices']: + if not choice['finish_reason']: + self.streams[choice['index']].data.append(choice['delta']['content']) + self.tokens.completion += len(self.encoding.encode(choice['delta']['content'])) + self.tokens.total = self.tokens.prompt + self.tokens.completion + if choice['index'] == self.idx: + found_choice = True + if not found_choice: + return False + return True + + class OpenAI(AI): """ The OpenAI AI client. @@ -21,7 +68,6 @@ class OpenAI(AI): self.ID = config.ID self.name = config.name self.config = config - openai.api_key = config.api_key def request(self, question: Message, @@ -33,7 +79,10 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - oai_chat = self.openai_chat(chat, self.config.system, question) + self.encoding = tiktoken.encoding_for_model(self.config.model) + openai.api_key = self.config.api_key + oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question) + tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, @@ -41,22 +90,24 @@ class OpenAI(AI): max_tokens=self.config.max_tokens, top_p=self.config.top_p, n=num_answers, + stream=True, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - question.answer = Answer(response['choices'][0]['message']['content']) + streams: dict[int, OpenAIAnswer] = {} + for n in range(num_answers): + streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding) + question.answer = Answer(streams[0].stream()) question.tags = set(otags) if otags is not None else None question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] - for choice in response['choices'][1:]: # type: ignore + for idx in range(1, num_answers): answers.append(Message(question=question.question, - answer=Answer(choice['message']['content']), + answer=Answer(streams[idx].stream()), tags=otags, ai=self.ID, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], - response['usage']['completion_tokens'], - response['usage']['total_tokens'])) + return AIResponse(answers, tokens) def models(self) -> list[str]: """ @@ -83,24 +134,26 @@ class OpenAI(AI): print('\nNot ready: ' + ', '.join(not_ready)) def openai_chat(self, chat: Chat, system: str, - question: Optional[Message] = None) -> ChatType: + question: Optional[Message] = None) -> tuple[ChatType, int]: """ Create a chat history with system message in OpenAI format. Optionally append a new question. """ oai_chat: ChatType = [] + prompt_tokens: int = 0 - def append(role: str, content: str) -> None: + def append(role: str, content: str) -> int: oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']]))) - append('system', system) + prompt_tokens += append('system', system) for message in chat.messages: if message.answer: - append('user', message.question) - append('assistant', message.answer) + prompt_tokens += append('user', message.question) + prompt_tokens += append('assistant', message.answer) if question: - append('user', question.question) - return oai_chat + prompt_tokens += append('user', question.question) + return oai_chat, prompt_tokens def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index ae96bac..79c37da 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -129,13 +129,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac args.output_tags) # only write the response messages to the cache, # don't add them to the internal list - chat.cache_write(response.messages) for idx, msg in enumerate(response.messages): - print(f"=== ANSWER {idx+1} ===") - print(msg.answer) + print(f"=== ANSWER {idx+1} ===", flush=True) + if msg.answer: + for piece in msg.answer: + print(piece, end='', flush=True) + print() if response.tokens: print("===============") print(response.tokens) + chat.cache_write(response.messages) def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: diff --git a/requirements.txt b/requirements.txt index 0762ecf..00e89b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ openai PyYAML argcomplete pytest +tiktoken diff --git a/tests/test_ais_openai.py b/tests/test_ais_openai.py index b53a14d..eab84e6 100644 --- a/tests/test_ais_openai.py +++ b/tests/test_ais_openai.py @@ -16,26 +16,37 @@ class OpenAITest(unittest.TestCase): openai = OpenAI(config) # Set up the mock response from openai.ChatCompletion.create - mock_response = { + mock_chunk1 = { 'choices': [ { - 'message': { + 'index': 0, + 'delta': { 'content': 'Answer 1' - } + }, + 'finish_reason': None }, { - 'message': { + 'index': 1, + 'delta': { 'content': 'Answer 2' - } + }, + 'finish_reason': None } ], - 'usage': { - 'prompt_tokens': 10, - 'completion_tokens': 20, - 'total_tokens': 30 - } } - mock_create.return_value = mock_response + mock_chunk2 = { + 'choices': [ + { + 'index': 0, + 'finish_reason': 'stop' + }, + { + 'index': 1, + 'finish_reason': 'stop' + } + ], + } + mock_create.return_value = iter([mock_chunk1, mock_chunk2]) # Create test data question = Message(Question('Question')) @@ -57,9 +68,9 @@ class OpenAITest(unittest.TestCase): self.assertIsNotNone(response.tokens) self.assertIsInstance(response.tokens, Tokens) assert response.tokens - self.assertEqual(response.tokens.prompt, 10) - self.assertEqual(response.tokens.completion, 20) - self.assertEqual(response.tokens.total, 30) + self.assertEqual(response.tokens.prompt, 53) + self.assertEqual(response.tokens.completion, 6) + self.assertEqual(response.tokens.total, 59) # Assert the mock call to openai.ChatCompletion.create mock_create.assert_called_once_with( @@ -76,6 +87,7 @@ class OpenAITest(unittest.TestCase): max_tokens=config.max_tokens, top_p=config.top_p, n=2, + stream=True, frequency_penalty=config.frequency_penalty, presence_penalty=config.presence_penalty )