Compare commits

..

4 Commits

Author SHA1 Message Date
7c1c67f8ff Merge pull request 'Dynamic Answer class and OpenAI streaming API' (#19) from dynamic_answer into main
Introduces several changes with the main objective of enabling OpenAI's streaming API in the chatmastermind application. This allows for the retrieval of AI responses gradually as a stream, which can significantly improve the user experience in interactions that involve large result sets.

* Added tiktoken import in 'openai.py' and modifications to the OpenAI class to support streaming. This includes the addition of a new class OpenAIAnswer to handle streaming API responses.
* Modified request function in the OpenAI class: the stream=True flag is added to the openai.ChatCompletion.create method to enable streaming API.
* Modified 'question.py' to print the answer parts as they are streamed.
* Replaced the Answer class's string data type with a generator which supports str and Generator[str, None, None] data types. Modifications are made to the Answer class methods to handle both data types accordingly.
* Updated the tests in 'test_ais_openai.py' and 'test_message.py' to reflect and validate these changes.
2023-10-21 15:50:45 +02:00
Oleksandr Kozachuk
dbe72ff11c Activate and use OpenAI streaming API. 2023-10-21 14:21:48 +02:00
Oleksandr Kozachuk
bbc1ab5a0a Fix source_code function with the dynamic answer class. 2023-10-20 14:02:09 +02:00
Oleksandr Kozachuk
2aee018708 Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. 2023-10-20 13:43:31 +02:00
6 changed files with 192 additions and 51 deletions

View File

@ -2,7 +2,8 @@
Implements the OpenAI client classes and functions. Implements the OpenAI client classes and functions.
""" """
import openai import openai
from typing import Optional, Union import tiktoken
from typing import Optional, Union, Generator
from ..tags import Tag from ..tags import Tag
from ..message import Message, Answer from ..message import Message, Answer
from ..chat import Chat from ..chat import Chat
@ -12,6 +13,52 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]] ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI): class OpenAI(AI):
""" """
The OpenAI AI client. The OpenAI AI client.
@ -21,7 +68,6 @@ class OpenAI(AI):
self.ID = config.ID self.ID = config.ID
self.name = config.name self.name = config.name
self.config = config self.config = config
openai.api_key = config.api_key
def request(self, def request(self,
question: Message, question: Message,
@ -33,7 +79,10 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
oai_chat = self.openai_chat(chat, self.config.system, question) self.encoding = tiktoken.encoding_for_model(self.config.model)
openai.api_key = self.config.api_key
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,
@ -41,22 +90,24 @@ class OpenAI(AI):
max_tokens=self.config.max_tokens, max_tokens=self.config.max_tokens,
top_p=self.config.top_p, top_p=self.config.top_p,
n=num_answers, n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content']) streams: dict[int, OpenAIAnswer] = {}
for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.tags = set(otags) if otags is not None else None question.tags = set(otags) if otags is not None else None
question.ai = self.ID question.ai = self.ID
question.model = self.config.model question.model = self.config.model
answers: list[Message] = [question] answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore for idx in range(1, num_answers):
answers.append(Message(question=question.question, answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']), answer=Answer(streams[idx].stream()),
tags=otags, tags=otags,
ai=self.ID, ai=self.ID,
model=self.config.model)) model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], return AIResponse(answers, tokens)
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]: def models(self) -> list[str]:
""" """
@ -83,24 +134,26 @@ class OpenAI(AI):
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str, def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType: question: Optional[Message] = None) -> tuple[ChatType, int]:
""" """
Create a chat history with system message in OpenAI format. Create a chat history with system message in OpenAI format.
Optionally append a new question. Optionally append a new question.
""" """
oai_chat: ChatType = [] oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> None: def append(role: str, content: str) -> int:
oai_chat.append({'role': role, 'content': content.replace("''", "'")}) oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
append('system', system) prompt_tokens += append('system', system)
for message in chat.messages: for message in chat.messages:
if message.answer: if message.answer:
append('user', message.question) prompt_tokens += append('user', message.question)
append('assistant', message.answer) prompt_tokens += append('assistant', message.answer)
if question: if question:
append('user', question.question) prompt_tokens += append('user', question.question)
return oai_chat return oai_chat, prompt_tokens
def tokens(self, data: Union[Message, Chat]) -> int: def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError raise NotImplementedError

View File

@ -101,7 +101,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code_file is not None and len(code_file) > 0: if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file) add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join(question_parts) full_question = '\n\n'.join([str(s) for s in question_parts])
message = Message(question=Question(full_question), message = Message(question=Question(full_question),
tags=args.output_tags, tags=args.output_tags,
@ -129,13 +129,16 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
args.output_tags) args.output_tags)
# only write the response messages to the cache, # only write the response messages to the cache,
# don't add them to the internal list # don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages): for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===") print(f"=== ANSWER {idx+1} ===", flush=True)
print(msg.answer) if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
if response.tokens: if response.tokens:
print("===============") print("===============")
print(response.tokens) print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None: def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:

View File

@ -5,7 +5,9 @@ import pathlib
import yaml import yaml
import tempfile import tempfile
import shutil import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@ -49,7 +51,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = [] code_lines: list[str] = []
in_code_block = False in_code_block = False
for line in text.split('\n'): for line in str(text).split('\n'):
if line.strip().startswith('```'): if line.strip().startswith('```'):
if include_delims: if include_delims:
code_lines.append(line) code_lines.append(line)
@ -142,30 +144,100 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ====' txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer' yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
""" """
Make sure the answer string does not contain the header as a whole line. Process input data (either a string or a string generator)
""" """
if cls.txt_header in string.split('\n'): if isinstance(data, str):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") yield data
instance = super().__new__(cls, string) else:
return instance yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
@classmethod @classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Answer from a list of strings. Make sure strings do not contain the header.
""" """
if cls.txt_header in strings: def _gen() -> Generator[str, None, None]:
raise MessageError(f"Question contains the header '{cls.txt_header}'") if len(strings) > 0:
instance = super().__new__(cls, '\n'.join(strings).strip()) yield strings[0]
return instance for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
def source_code(self, include_delims: bool = False) -> list[str]: def source_code(self, include_delims: bool = False) -> list[str]:
""" """
Extract and return all source code sections. Extract and return all source code sections.
""" """
return source_code(self, include_delims) return source_code(str(self), include_delims)
class Question(str): class Question(str):
@ -441,7 +513,7 @@ class Message():
output.append(self.question) output.append(self.question)
if self.answer: if self.answer:
output.append(Answer.txt_header) output.append(Answer.txt_header)
output.append(self.answer) output.append(str(self.answer))
return '\n'.join(output) return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
@ -491,7 +563,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer: if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
shutil.move(temp_file_path, file_path) shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None: def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@ -560,7 +632,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503

View File

@ -2,3 +2,4 @@ openai
PyYAML PyYAML
argcomplete argcomplete
pytest pytest
tiktoken

View File

@ -16,26 +16,37 @@ class OpenAITest(unittest.TestCase):
openai = OpenAI(config) openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create # Set up the mock response from openai.ChatCompletion.create
mock_response = { mock_chunk1 = {
'choices': [ 'choices': [
{ {
'message': { 'index': 0,
'delta': {
'content': 'Answer 1' 'content': 'Answer 1'
} },
'finish_reason': None
}, },
{ {
'message': { 'index': 1,
'delta': {
'content': 'Answer 2' 'content': 'Answer 2'
} },
'finish_reason': None
} }
], ],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
} }
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
} }
mock_create.return_value = mock_response ],
}
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
# Create test data # Create test data
question = Message(Question('Question')) question = Message(Question('Question'))
@ -57,9 +68,9 @@ class OpenAITest(unittest.TestCase):
self.assertIsNotNone(response.tokens) self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens) self.assertIsInstance(response.tokens, Tokens)
assert response.tokens assert response.tokens
self.assertEqual(response.tokens.prompt, 10) self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 20) self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 30) self.assertEqual(response.tokens.total, 59)
# Assert the mock call to openai.ChatCompletion.create # Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with( mock_create.assert_called_once_with(
@ -76,6 +87,7 @@ class OpenAITest(unittest.TestCase):
max_tokens=config.max_tokens, max_tokens=config.max_tokens,
top_p=config.top_p, top_p=config.top_p,
n=2, n=2,
stream=True,
frequency_penalty=config.frequency_penalty, frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty presence_penalty=config.presence_penalty
) )

View File

@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase): class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") str(Answer(f"{Answer.txt_header}\nno"))
def test_answer_with_legal_header(self) -> None: def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")