import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError


class FakeAI(AI):
    """
    A mocked version of the 'AI' class.

    It never contacts a real backend. 'request()' either returns canned
    answers ("Answer 0", "Answer 1", ...) or raises 'AIError' if the
    instance was created with 'error=True'.
    """
    ID: str
    name: str
    config: AIConfig

    def models(self) -> list[str]:
        """
        Not supported by the mock.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        """
        Return a fixed token count (123), regardless of 'data'.
        """
        return 123

    def print(self) -> None:
        """
        No-op in the mock.
        """
        pass

    def print_models(self) -> None:
        """
        No-op in the mock.
        """
        pass

    def __init__(self, ID: str, model: str, error: bool = False):
        """
        Args:
            ID: value written to the 'ai' field of every returned message.
            model: value written to the 'model' field of every returned message.
            error: if True, every call to 'request()' raises 'AIError'.
        """
        self.ID = ID
        self.model = model
        self.error = error

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function by either returning fake answers
        or raising an exception.

        The given 'question' is modified in place and becomes the first
        answer ("Answer 0"). For 'num_answers' > 1, further messages with the
        same question text and answers "Answer 1", "Answer 2", ... are
        appended. 'chat' is ignored.

        Args:
            question: message that receives the first fake answer (mutated).
            chat: unused by the mock.
            num_answers: total number of answer messages to return.
            otags: output tags. Every returned message gets its own copy
                (or None if 'otags' is None).

        Returns:
            AIResponse with the answer messages and Tokens(10, 10, 20).

        Raises:
            AIError: if this FakeAI was created with 'error=True'.
        """
        if self.error:
            raise AIError
        question.answer = Answer("Answer 0")
        # Copy the caller's set so later mutations of the message's tags
        # don't leak back into 'otags' (or into the other answers).
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.model
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   # each answer gets its own copy, just like the first one
                                   tags=set(otags) if otags is not None else None,
                                   ai=self.ID,
                                   model=self.model))
        return AIResponse(answers, Tokens(10, 10, 20))


class TestWithFakeAI(unittest.TestCase):
    """
    Base class for all tests that need to use the FakeAI.
    """
    def assert_msgs_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using Question, Answer and all metadata except for
        the file_path. Both lists must have the same length and are
        compared pairwise in order.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))

    def assert_msgs_all_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using Question, Answer and ALL metadata. Both lists
        must have the same length and are compared pairwise in order.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            self.assertTrue(m1.equals(m2, verbose=True))

    def assert_msgs_content_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using only Question and Answer (via '=='). Both
        lists must have the same length and are compared pairwise in order.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            self.assertEqual(m1, m2)

    def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance built from
        'args.AI' and 'args.model'. 'config' is ignored.
        """
        return FakeAI(args.AI, args.model)

    def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance (built from
        'args.AI' and 'args.model') whose 'request()' always raises
        'AIError'. 'config' is ignored.
        """
        return FakeAI(args.AI, args.model, error=True)