From ee363d98942cadfa3750f86a8aa8f811d82cbe07 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 20 Oct 2023 13:43:31 +0200 Subject: [PATCH] Refactor message.Answer class in a way, that it can be constructed dynamically step by step, in preparation of using streaming API. --- chatmastermind/commands/question.py | 2 +- chatmastermind/message.py | 102 ++++++++++++++++++++++++---- tests/test_message.py | 2 +- 3 files changed, 89 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index cd31d54..ae96bac 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -101,7 +101,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code_file is not None and len(code_file) > 0: add_file_as_code(question_parts, code_file) - full_question = '\n\n'.join(question_parts) + full_question = '\n\n'.join([str(s) for s in question_parts]) message = Message(question=Question(full_question), tags=args.output_tags, diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 8e7a55d..97e3e3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,9 @@ import pathlib import yaml import tempfile import shutil +import io from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple +from typing import Generator, Iterator from typing import get_args as typing_get_args from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -142,30 +144,100 @@ class Answer(str): txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __init__(self, data: Union[str, Generator[str, None, None]]) -> None: + # Indicator of whether all of data has been processed + self.is_exhausted: bool = False + + # Initialize data + self.iterator: Iterator[str] = self._init_data(data) + + # Set up the buffer to hold the 'Answer' content + self.buffer: io.StringIO = io.StringIO() + + def _init_data(self, data: Union[str, Generator[str, None, None]]) -> Iterator[str]: """ - Make sure the answer string does not contain the header as a whole line. + Process input data (either a string or a string generator) """ - if cls.txt_header in string.split('\n'): - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") - instance = super().__new__(cls, string) - return instance + if isinstance(data, str): + yield data + else: + yield from data + + def __str__(self) -> str: + """ + Output all content when converted into a string + """ + # Ensure all data has been processed + for _ in self: + pass + # Return the 'Answer' content + return self.buffer.getvalue() + + def __repr__(self) -> str: + return repr(str(self)) + + def __iter__(self) -> Generator[str, None, None]: + """ + Allows the object to be iterable + """ + # Generate content if not all data has been processed + if not self.is_exhausted: + yield from self.generator_iter() + else: + yield self.buffer.getvalue() + + def generator_iter(self) -> Generator[str, None, None]: + """ + Main generator method to process data + """ + for piece in self.iterator: + # Write to buffer and yield piece for the iterator + self.buffer.write(piece) + yield piece + self.is_exhausted = True # Set the flag that all data has been processed + # If the header occurs in the 'Answer' content, raise an error + if f'\n{self.txt_header}' in self.buffer.getvalue() or self.buffer.getvalue().startswith(self.txt_header): + raise MessageError(f"Answer {repr(self.buffer.getvalue())} contains the header {repr(Answer.txt_header)}") + + def __eq__(self, other: object) -> bool: + """ + Comparing the object to a string or another object + """ + if isinstance(other, str): + return str(self) == other # Compare the string value of this object to the other string + # Default behavior for comparing non-string objects + return super().__eq__(other) + + def __hash__(self) -> int: + """ + Generate a hash for the object based on its string representation. + """ + return hash(str(self)) + + def __format__(self, format_spec: str) -> str: + """ + Return a formatted version of the string as per the format specification. + """ + return str(self).__format__(format_spec) @classmethod def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ - Build Question from a list of strings. Make sure strings do not contain the header. + Build Answer from a list of strings. Make sure strings do not contain the header. """ - if cls.txt_header in strings: - raise MessageError(f"Question contains the header '{cls.txt_header}'") - instance = super().__new__(cls, '\n'.join(strings).strip()) - return instance + def _gen() -> Generator[str, None, None]: + if len(strings) > 0: + yield strings[0] + for s in strings[1:]: + yield '\n' + yield s + return cls(_gen()) def source_code(self, include_delims: bool = False) -> list[str]: """ Extract and return all source code sections. """ - return source_code(self, include_delims) + return source_code(str(self), include_delims) class Question(str): @@ -441,7 +513,7 @@ class Message(): output.append(self.question) if self.answer: output.append(Answer.txt_header) - output.append(self.answer) + output.append(str(self.answer)) return '\n'.join(output) def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11 @@ -491,7 +563,7 @@ class Message(): temp_fd.write(f'{ModelLine.from_model(self.model)}\n') temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{str(self.answer)}\n') shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: @@ -560,7 +632,7 @@ class Message(): or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 - or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in str(self.answer))) # noqa: W503 or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 or (mfilter.model_state == 'available' and not self.model) # noqa: W503 diff --git a/tests/test_message.py b/tests/test_message.py index b79bcae..0a6c2de 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -91,7 +91,7 @@ class QuestionTestCase(unittest.TestCase): class AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer(f"{Answer.txt_header}\nno") + str(Answer(f"{Answer.txt_header}\nno")) def test_answer_with_legal_header(self) -> None: answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")