import unittest
import argparse
from typing import Union, Optional
from chatmastermind.configuration import Config, AIConfig
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Answer
from chatmastermind.chat import Chat
from chatmastermind.ai import AI, AIResponse, Tokens, AIError


class FakeAI(AI):
    """
    A mocked version of the 'AI' class.

    'request()' never contacts a real backend: it either fabricates
    deterministic answers ("Answer 0", "Answer 1", ...) or raises 'AIError',
    depending on the 'error' flag given to the constructor.
    """
    ID: str
    name: str
    config: AIConfig
    model: str
    error: bool

    def models(self) -> list[str]:
        # Not provided by the mock.
        raise NotImplementedError

    def tokens(self, data: Union[Message, Chat]) -> int:
        # Fixed value, independent of 'data'.
        return 123

    def print(self) -> None:
        pass

    def print_models(self) -> None:
        pass

    def __init__(self, ID: str, model: str, error: bool = False):
        """
        :param ID: AI identifier, copied into every generated answer
        :param model: model name, copied into every generated answer
        :param error: if True, 'request()' raises 'AIError' instead of answering
        """
        self.ID = ID
        self.model = model
        self.error = error

    def request(self,
                question: Message,
                chat: Chat,
                num_answers: int = 1,
                otags: Optional[set[Tag]] = None) -> AIResponse:
        """
        Mock the 'ai.request()' function by either returning fake answers
        or raising an exception.

        The given 'question' is modified in place (answer "Answer 0", tags,
        ai, model) and returned as the first answer. Additional answers
        ("Answer 1", ...) are new 'Message' objects with the same question
        text. At least one answer is always returned, even if
        'num_answers' < 1. 'chat' is ignored.

        :raises AIError: if this instance was created with 'error=True'
        """
        if self.error:
            raise AIError
        question.answer = Answer("Answer 0")
        question.tags = set(otags) if otags is not None else None
        question.ai = self.ID
        question.model = self.model
        answers: list[Message] = [question]
        for n in range(1, num_answers):
            # Give every message its own copy of the tags (as done for
            # 'question' above) instead of sharing one mutable set between
            # the answers and the caller.
            answers.append(Message(question=question.question,
                                   answer=Answer(f"Answer {n}"),
                                   tags=set(otags) if otags is not None else None,
                                   ai=self.ID,
                                   model=self.model))
        return AIResponse(answers, Tokens(10, 10, 20))


class TestWithFakeAI(unittest.TestCase):
    """
    Base test case providing helpers for tests that replace the real AI
    with 'FakeAI'.
    """

    def assert_messages_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None:
        """
        Compare messages using more than just Question and Answer.

        Both lists must have the same length. Messages are compared pairwise
        with 'Message.equals(..., file_path=False, verbose=True)'.
        """
        self.assertEqual(len(msg1), len(msg2))
        for m1, m2 in zip(msg1, msg2):
            # exclude the file_path, compare only Q, A and metadata
            self.assertTrue(m1.equals(m2, file_path=False, verbose=True))

    def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance built from
        'args.AI' and 'args.model' ('config' is ignored).
        """
        return FakeAI(args.AI, args.model)

    def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
        """
        Mocked 'create_ai' that returns a 'FakeAI' instance built from
        'args.AI' and 'args.model' whose 'request()' always raises 'AIError'
        ('config' is ignored).
        """
        return FakeAI(args.AI, args.model, error=True)