From e8343fde013303da6b84026da39cc0032a7385c7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 11:15:14 +0200 Subject: [PATCH] test_main: added type annotations and a helper class / function --- tests/test_main.py | 88 +++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 3634740..4a70cbb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,25 +5,46 @@ import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai +from chatmastermind.configuration import Config, OpenAIConfig from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY -class TestCreateChat(unittest.TestCase): +class CmmTestCase(unittest.TestCase): + """ + Base class for all cmm testcases. + """ + def dummy_config(self, db: str) -> Config: + """ + Creates a dummy configuration. + """ + return Config( + system='dummy_system', + db=db, + openai=OpenAIConfig( + api_key='dummy_key', + model='dummy_model', + max_tokens=4000, + temperature=1.0, + top_p=1, + frequency_penalty=0, + presence_penalty=0 + ) + ) - def setUp(self): - self.config = { - 'system': 'System text', - 'db': 'test_files' - } + +class TestCreateChat(CmmTestCase): + + def setUp(self) -> None: + self.config = self.dummy_config(db='test_files') self.question = "test question" self.tags = ['test_tag'] @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -45,7 +66,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -63,7 +84,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.side_effect = ( @@ -90,9 +111,9 @@ class TestCreateChat(unittest.TestCase): {'role': 'assistant', 'content': 'some answer2'}) -class TestHandleQuestion(unittest.TestCase): +class TestHandleQuestion(CmmTestCase): - def setUp(self): + def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], @@ -109,12 +130,7 @@ class TestHandleQuestion(unittest.TestCase): with_tags=False, with_file=False, ) - self.config = { - 'db': 'test_files', - 'setting1': 'value1', - 'setting2': 'value2', - 'openai': {}, - } + self.config = self.dummy_config(db='test_files') @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.print_tag_args") @@ -122,9 +138,9 @@ class TestHandleQuestion(unittest.TestCase): @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") - def test_ask_cmd(self, mock_print, mock_pp, mock_ai, - mock_print_chat_hist, mock_print_tag_args, - mock_create_chat_hist): + def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, + mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, + mock_create_chat_hist: MagicMock) -> None: open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) @@ -155,15 +171,15 @@ class TestHandleQuestion(unittest.TestCase): open_mock.assert_has_calls(open_expected_calls, any_order=True) -class TestSaveAnswers(unittest.TestCase): +class TestSaveAnswers(CmmTestCase): @mock.patch('builtins.open') @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock, open_mock): + def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: question = "Test question?" answers = ["Answer 1", "Answer 2"] tags = ["tag1", "tag2"] otags = ["otag1", "otag2"] - config = {'db': 'test_db'} + config = self.dummy_config(db='test_db') with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ mock.patch('chatmastermind.storage.yaml.dump'), \ @@ -179,10 +195,10 @@ class TestSaveAnswers(unittest.TestCase): open_mock.assert_has_calls(open_calls, any_order=True) -class TestAI(unittest.TestCase): +class TestAI(CmmTestCase): @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock): + def test_ai(self, mock_create: MagicMock) -> None: mock_create.return_value = { 'choices': [ {'message': {'content': 'response_text_1'}}, @@ -191,28 +207,20 @@ class TestAI(unittest.TestCase): 'usage': {'tokens': 10} } - number = 2 chat = [{"role": "system", "content": "hello ai"}] - config = { - "openai": { - "model": "text-davinci-002", - "temperature": 0.5, - "max_tokens": 150, - "top_p": 1, - "n": number, - "frequency_penalty": 0, - "presence_penalty": 0 - } - } + config = self.dummy_config(db='dummy') + config['openai']['model'] = "text-davinci-002" + config['openai']['max_tokens'] = 150 + config['openai']['temperature'] = 0.5 - result = ai(chat, config, number) + result = ai(chat, config, 2) expected_result = (['response_text_1', 'response_text_2'], {'tokens': 10}) self.assertEqual(result, expected_result) -class TestCreateParser(unittest.TestCase): - def test_create_parser(self): +class TestCreateParser(CmmTestCase): + def test_create_parser(self) -> None: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: mock_cmdparser = Mock() mock_add_subparsers.return_value = mock_cmdparser