Compare commits

..

2 Commits

Author SHA1 Message Date
be35a8ea9e added testcases for messages.py 2023-08-18 16:56:33 +02:00
2830b076f8 added new module 'message.py' 2023-08-18 16:56:33 +02:00
2 changed files with 53 additions and 2 deletions

View File

@ -4,6 +4,7 @@ Module implementing message related functions and classes.
from typing import Type, TypeVar
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
class MessageError(Exception):
@ -12,7 +13,7 @@ class MessageError(Exception):
def source_code(text: str, include_delims: bool = False) -> list[str]:
"""
Extracts all source code sections from the given text, i. e. all lines
Extract all source code sections from the given text, i. e. all lines
surrounded by lines tarting with '```'. If 'include_delims' is True,
the surrounding lines are included, otherwise they are omitted. The
result list contains every source code section as a single string.
@ -50,3 +51,31 @@ class Question(str):
raise MessageError(f"Question '{string}' contains the prefix '{cls.prefix}'")
instance = super().__new__(cls, string)
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
class Answer(str):
"""
A single answer with a defined prefix.
"""
prefix = '=== ANSWER ==='
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Make sure the answer string does not contain the prefix.
"""
if cls.prefix in string:
raise MessageError(f"Answer '{string}' contains the prefix '{cls.prefix}'")
instance = super().__new__(cls, string)
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)

View File

@ -8,7 +8,7 @@ from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from chatmastermind.tags import Tag, TagLine, TagError
from chatmastermind.message import source_code
from chatmastermind.message import source_code, MessageError, Question, Answer
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
@ -400,3 +400,25 @@ class SourceCodeTestCase(CmmTestCase):
expected_result: list[str] = []
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase):
def test_question_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Question("=== QUESTION === What is your name?")
def test_question_without_prefix(self) -> None:
question = Question("What is your favorite color?")
self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase):
def test_answer_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Answer("=== ANSWER === Yes")
def test_answer_without_prefix(self) -> None:
answer = Answer("No")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No")