Compare commits

..

No commits in common. "7c1c67f8ff32df04addec5643bd0cbcaf35090cc" and "17c6fa2453d54f369e276ec56e674d84fa5de6f9" have entirely different histories.

6 changed files with 51 additions and 192 deletions

View File

@ -2,8 +2,7 @@
Implements the OpenAI client classes and functions.
"""
import openai
import tiktoken
from typing import Optional, Union, Generator
from typing import Optional, Union
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
@ -13,52 +12,6 @@ from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAIAnswer:
def __init__(self,
idx: int,
streams: dict[int, 'OpenAIAnswer'],
response: openai.ChatCompletion,
tokens: Tokens,
encoding: tiktoken.core.Encoding) -> None:
self.idx = idx
self.streams = streams
self.response = response
self.position: int = 0
self.encoding = encoding
self.data: list[str] = []
self.finished: bool = False
self.tokens = tokens
def stream(self) -> Generator[str, None, None]:
while True:
if not self.next():
continue
if len(self.data) <= self.position:
break
yield self.data[self.position]
self.position += 1
def next(self) -> bool:
if self.finished:
return True
try:
chunk = next(self.response)
except StopIteration:
self.finished = True
if not self.finished:
found_choice = False
for choice in chunk['choices']:
if not choice['finish_reason']:
self.streams[choice['index']].data.append(choice['delta']['content'])
self.tokens.completion += len(self.encoding.encode(choice['delta']['content']))
self.tokens.total = self.tokens.prompt + self.tokens.completion
if choice['index'] == self.idx:
found_choice = True
if not found_choice:
return False
return True
class OpenAI(AI):
"""
The OpenAI AI client.
@ -68,6 +21,7 @@ class OpenAI(AI):
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = config.api_key
def request(self,
question: Message,
@ -79,10 +33,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
self.encoding = tiktoken.encoding_for_model(self.config.model)
openai.api_key = self.config.api_key
oai_chat, prompt_tokens = self.openai_chat(chat, self.config.system, question)
tokens: Tokens = Tokens(prompt_tokens, 0, prompt_tokens)
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
@ -90,24 +41,22 @@ class OpenAI(AI):
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
stream=True,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
streams: dict[int, OpenAIAnswer] = {}
for n in range(num_answers):
streams[n] = OpenAIAnswer(n, streams, response, tokens, self.encoding)
question.answer = Answer(streams[0].stream())
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for idx in range(1, num_answers):
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(streams[idx].stream()),
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, tokens)
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]:
"""
@ -134,26 +83,24 @@ class OpenAI(AI):
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> tuple[ChatType, int]:
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
prompt_tokens: int = 0
def append(role: str, content: str) -> int:
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
return len(self.encoding.encode(', '.join(['role:', oai_chat[-1]['role'], 'content:', oai_chat[-1]['content']])))
prompt_tokens += append('system', system)
append('system', system)
for message in chat.messages:
if message.answer:
prompt_tokens += append('user', message.question)
prompt_tokens += append('assistant', message.answer)
append('user', message.question)
append('assistant', message.answer)
if question:
prompt_tokens += append('user', question.question)
return oai_chat, prompt_tokens
append('user', question.question)
return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError

View File

@ -101,7 +101,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
if code_file is not None and len(code_file) > 0:
add_file_as_code(question_parts, code_file)
full_question = '\n\n'.join([str(s) for s in question_parts])
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags,
@ -129,16 +129,13 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
args.output_tags)
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===", flush=True)
if msg.answer:
for piece in msg.answer:
print(piece, end='', flush=True)
print()
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
chat.cache_write(response.messages)
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:

View File

@ -5,9 +5,7 @@ import pathlib
import yaml
import tempfile
import shutil
import io
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import Generator, Iterator
from typing import get_args as typing_get_args
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@ -51,7 +49,7 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
code_lines: list[str] = []
in_code_block = False
for line in str(text).split('\n'):
for line in text.split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
@ -144,100 +142,30 @@ class Answer(str):
txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer'
def __init__(self, data: Union[str, Generator[str, None, None]]) -> None:
# Indicator of whether all of data has been processed
self.is_exhausted: bool = False
# Initialize data
self.iterator: Iterator[str] = self._init_data(data)
# Set up the buffer to hold the 'Answer' content
self.buffer: io.StringIO = io.StringIO()
def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]:
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Process input data (either a string or a string generator)
Make sure the answer string does not contain the header as a whole line.
"""
if isinstance(data, str):
yield data
else:
yield from data
def __str__(self) -> str:
"""
Output all content when converted into a string
"""
# Ensure all data has been processed
for _ in self:
pass
# Return the 'Answer' content
return self.buffer.getvalue()
def __repr__(self) -> str:
return repr(str(self))
def __iter__(self) -> Generator[str, None, None]:
"""
Allows the object to be iterable
"""
# Generate content if not all data has been processed
if not self.is_exhausted:
yield from self.generator_iter()
else:
yield self.buffer.getvalue()
def generator_iter(self) -> Generator[str, None, None]:
"""
Main generator method to process data
"""
for piece in self.iterator:
# Write to buffer and yield piece for the iterator
self.buffer.write(piece)
yield piece
self.is_exhausted = True # Set the flag that all data has been processed
# If the header occurs in the 'Answer' content, raise an error
if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header):
raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}")
def __eq__(self, other: object) -> bool:
"""
Comparing the object to a string or another object
"""
if isinstance(other, str):
return str(self) == other # Compare the string value of this object to the other string
# Default behavior for comparing non-string objects
return super().__eq__(other)
def __hash__(self) -> int:
"""
Generate a hash for the object based on its string representation.
"""
return hash(str(self))
def __format__(self, format_spec: str) -> str:
"""
Return a formatted version of the string as per the format specification.
"""
return str(self).__format__(format_spec)
if cls.txt_header in string.split('\n'):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
"""
Build Answer from a list of strings. Make sure strings do not contain the header.
Build Question from a list of strings. Make sure strings do not contain the header.
"""
def _gen() -> Generator[str, None, None]:
if len(strings) > 0:
yield strings[0]
for s in strings[1:]:
yield '\n'
yield s
return cls(_gen())
if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(str(self), include_delims)
return source_code(self, include_delims)
class Question(str):
@ -513,7 +441,7 @@ class Message():
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(str(self.answer))
output.append(self.answer)
return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
@ -563,7 +491,7 @@ class Message():
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n')
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
@ -632,7 +560,7 @@ class Message():
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503

View File

@ -2,4 +2,3 @@ openai
PyYAML
argcomplete
pytest
tiktoken

View File

@ -16,37 +16,26 @@ class OpenAITest(unittest.TestCase):
openai = OpenAI(config)
# Set up the mock response from openai.ChatCompletion.create
mock_chunk1 = {
mock_response = {
'choices': [
{
'index': 0,
'delta': {
'message': {
'content': 'Answer 1'
},
'finish_reason': None
}
},
{
'index': 1,
'delta': {
'message': {
'content': 'Answer 2'
},
'finish_reason': None
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 20,
'total_tokens': 30
}
mock_chunk2 = {
'choices': [
{
'index': 0,
'finish_reason': 'stop'
},
{
'index': 1,
'finish_reason': 'stop'
}
],
}
mock_create.return_value = iter([mock_chunk1, mock_chunk2])
mock_create.return_value = mock_response
# Create test data
question = Message(Question('Question'))
@ -68,9 +57,9 @@ class OpenAITest(unittest.TestCase):
self.assertIsNotNone(response.tokens)
self.assertIsInstance(response.tokens, Tokens)
assert response.tokens
self.assertEqual(response.tokens.prompt, 53)
self.assertEqual(response.tokens.completion, 6)
self.assertEqual(response.tokens.total, 59)
self.assertEqual(response.tokens.prompt, 10)
self.assertEqual(response.tokens.completion, 20)
self.assertEqual(response.tokens.total, 30)
# Assert the mock call to openai.ChatCompletion.create
mock_create.assert_called_once_with(
@ -87,7 +76,6 @@ class OpenAITest(unittest.TestCase):
max_tokens=config.max_tokens,
top_p=config.top_p,
n=2,
stream=True,
frequency_penalty=config.frequency_penalty,
presence_penalty=config.presence_penalty
)

View File

@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase):
class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
str(Answer(f"{Answer.txt_header}\nno"))
Answer(f"{Answer.txt_header}\nno")
def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")