From b2021ba36b75832b0927998446044acd2d7b7d38 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 23 Sep 2023 09:03:20 +0200 Subject: [PATCH] tests: moved 'FakeAI' and common functions to 'test_common.py' --- tests/test_common.py | 81 +++++++++++++++++++++++++++++++ tests/test_question_cmd.py | 99 +++++--------------------------------- 2 files changed, 93 insertions(+), 87 deletions(-) create mode 100644 tests/test_common.py diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..7283ffa --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,81 @@ +import unittest +import argparse +from typing import Union, Optional +from chatmastermind.configuration import Config, AIConfig +from chatmastermind.tags import Tag +from chatmastermind.message import Message, Answer +from chatmastermind.chat import Chat +from chatmastermind.ai import AI, AIResponse, Tokens, AIError + + +class FakeAI(AI): + """ + A mocked version of the 'AI' class. + """ + ID: str + name: str + config: AIConfig + + def models(self) -> list[str]: + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + return 123 + + def print(self) -> None: + pass + + def print_models(self) -> None: + pass + + def __init__(self, ID: str, model: str, error: bool = False): + self.ID = ID + self.model = model + self.error = error + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Mock the 'ai.request()' function by either returning fake + answers or raising an exception. + """ + if self.error: + raise AIError + question.answer = Answer("Answer 0") + question.tags = set(otags) if otags is not None else None + question.ai = self.ID + question.model = self.model + answers: list[Message] = [question] + for n in range(1, num_answers): + answers.append(Message(question=question.question, + answer=Answer(f"Answer {n}"), + tags=otags, + ai=self.ID, + model=self.model)) + return AIResponse(answers, Tokens(10, 10, 20)) + + +class TestWithFakeAI(unittest.TestCase): + def assert_messages_equal_except_file_path(self, msg1: list[Message], msg2: list[Message]) -> None: + """ + Compare messages using more than just Question and Answer. + """ + self.assertEqual(len(msg1), len(msg2)) + for m1, m2 in zip(msg1, msg2): + # exclude the file_path, compare only Q, A and metadata + self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) + + def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI: + """ + Mocked 'create_ai' that returns a 'FakeAI' instance. + """ + return FakeAI(args.AI, args.model) + + def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI: + """ + Mocked 'create_ai' that returns a 'FakeAI' instance. + """ + return FakeAI(args.AI, args.model, error=True) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 8e55b8f..28e3155 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -1,94 +1,19 @@ import os -import unittest import argparse import tempfile from copy import copy from pathlib import Path from unittest import mock from unittest.mock import MagicMock, call -from typing import Optional, Union -from chatmastermind.configuration import Config, AIConfig +from chatmastermind.configuration import Config from chatmastermind.commands.question import create_message, question_cmd -from chatmastermind.tags import Tag from chatmastermind.message import Message, Question, Answer from chatmastermind.chat import Chat, ChatDB -from chatmastermind.ai import AI, AIResponse, Tokens, AIError +from chatmastermind.ai import AIError +from .test_common import TestWithFakeAI -class FakeAI(AI): - """ - A mocked version of the 'AI' class. - """ - ID: str - name: str - config: AIConfig - - def models(self) -> list[str]: - raise NotImplementedError - - def tokens(self, data: Union[Message, Chat]) -> int: - return 123 - - def print(self) -> None: - pass - - def print_models(self) -> None: - pass - - def __init__(self, ID: str, model: str, error: bool = False): - self.ID = ID - self.model = model - self.error = error - - def request(self, - question: Message, - chat: Chat, - num_answers: int = 1, - otags: Optional[set[Tag]] = None) -> AIResponse: - """ - Mock the 'ai.request()' function by either returning fake - answers or raising an exception. - """ - if self.error: - raise AIError - question.answer = Answer("Answer 0") - question.tags = set(otags) if otags is not None else None - question.ai = self.ID - question.model = self.model - answers: list[Message] = [question] - for n in range(1, num_answers): - answers.append(Message(question=question.question, - answer=Answer(f"Answer {n}"), - tags=otags, - ai=self.ID, - model=self.model)) - return AIResponse(answers, Tokens(10, 10, 20)) - - -class TestQuestionCmdBase(unittest.TestCase): - def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: - """ - Compare messages using more than just Question and Answer. - """ - self.assertEqual(len(msg1), len(msg2)) - for m1, m2 in zip(msg1, msg2): - # exclude the file_path, compare only Q, A and metadata - self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) - - def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI: - """ - Mocked 'create_ai' that returns a 'FakeAI' instance. - """ - return FakeAI(args.AI, args.model) - - def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI: - """ - Mocked 'create_ai' that returns a 'FakeAI' instance. - """ - return FakeAI(args.AI, args.model, error=True) - - -class TestMessageCreate(TestQuestionCmdBase): +class TestMessageCreate(TestWithFakeAI): """ Test if messages created by the 'question' command have the correct format. @@ -275,7 +200,7 @@ It is embedded code """)) -class TestQuestionCmd(TestQuestionCmdBase): +class TestQuestionCmd(TestWithFakeAI): def setUp(self) -> None: # create DB and cache @@ -335,7 +260,7 @@ class TestQuestionCmdAsk(TestQuestionCmd): Path(self.db_dir.name)) cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 1) - self.assert_messages_equal(cached_msg, expected_responses) + self.assert_messages_equal_except_file_path(cached_msg, expected_responses) @mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.create_ai') @@ -393,7 +318,7 @@ class TestQuestionCmdAsk(TestQuestionCmd): Path(self.db_dir.name)) cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 1) - self.assert_messages_equal(cached_msg, [expected_question]) + self.assert_messages_equal_except_file_path(cached_msg, [expected_question]) class TestQuestionCmdRepeat(TestQuestionCmd): @@ -433,7 +358,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): cached_msg = chat.msg_gather(loc='cache') print(self.message_list(self.cache_dir)) self.assertEqual(len(self.message_list(self.cache_dir)), 2) - self.assert_messages_equal(cached_msg, expected_responses) + self.assert_messages_equal_except_file_path(cached_msg, expected_responses) @mock.patch('chatmastermind.commands.question.create_ai') def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None: @@ -468,7 +393,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): question_cmd(self.args, self.config) cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 1) - self.assert_messages_equal(cached_msg, expected_response) + self.assert_messages_equal_except_file_path(cached_msg, expected_response) # also check that the file ID has not been changed assert cached_msg[0].file_path self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) @@ -507,7 +432,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): question_cmd(self.args, self.config) cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 1) - self.assert_messages_equal(cached_msg, expected_response) + self.assert_messages_equal_except_file_path(cached_msg, expected_response) # also check that the file ID has not been changed assert cached_msg[0].file_path self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) @@ -551,7 +476,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd): question_cmd(self.args, self.config) cached_msg = chat.msg_gather(loc='cache') self.assertEqual(len(self.message_list(self.cache_dir)), 2) - self.assert_messages_equal(cached_msg, [message] + new_expected_response) + self.assert_messages_equal_except_file_path(cached_msg, [message] + new_expected_response) @mock.patch('chatmastermind.commands.question.create_ai') def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None: @@ -608,4 +533,4 @@ class TestQuestionCmdRepeat(TestQuestionCmd): chat = ChatDB.from_dir(Path(self.cache_dir.name), Path(self.db_dir.name)) cached_msg = chat.msg_gather(loc='cache') - self.assert_messages_equal(cached_msg, expected_cache_messages) + self.assert_messages_equal_except_file_path(cached_msg, expected_cache_messages)