From 0dbb0d3c4da2a44194b64af8b323d509e7baeb94 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 16:00:24 +0200 Subject: [PATCH] message: improved robustness of Question and Answer content checks and tests --- chatmastermind/message.py | 48 +++++++++++++++++++++------------------ tests/test_message.py | 29 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 675ab3a..384fb96 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -128,29 +128,29 @@ class ModelLine(str): return cls(' '.join([cls.prefix, model])) -class Question(str): +class Answer(str): """ - A single question with a defined header. + A single answer with a defined header. """ - tokens: int = 0 # tokens used by this question - txt_header: ClassVar[str] = '=== QUESTION ===' - yaml_key: ClassVar[str] = 'question' + tokens: int = 0 # tokens used by this answer + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ - Make sure the question string does not contain the header. + Make sure the answer string does not contain the header as a whole line. """ - if cls.txt_header in string: - raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if cls.txt_header in string.split('\n'): + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance @@ -162,29 +162,33 @@ class Question(str): return source_code(self, include_delims) -class Answer(str): +class Question(str): """ - A single answer with a defined header. + A single question with a defined header. """ - tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' - yaml_key: ClassVar[str] = 'answer' + tokens: int = 0 # tokens used by this question + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ - Make sure the answer string does not contain the header. + Make sure the question string does not contain the header as a whole line + (also not that from 'Answer', so it's always clear where the answer starts). """ - if cls.txt_header in string: - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + string_lines = string.split('\n') + if cls.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if Answer.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance diff --git a/tests/test_message.py b/tests/test_message.py index 0d7953e..e01de66 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: + def test_question_with_header(self) -> None: with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") + Question(f"{Question.txt_header}\nWhat is your name?") - def test_question_without_prefix(self) -> None: + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. + """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: question = Question("What is your favorite color?") self.assertIsInstance(question, Question) self.assertEqual(question, "What is your favorite color?") class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: + def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") + Answer(f"{Answer.txt_header}\nno") - def test_answer_without_prefix(self) -> None: + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No")