diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index 8eaf695..d8634bd 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -1,11 +1,17 @@ import openai +from .utils import ChatType +from .configuration import Config + def openai_api_key(api_key: str) -> None: openai.api_key = api_key def print_models() -> None: + """ + Print all models supported by the current AI. + """ not_ready = [] for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): if engine['ready']: @@ -16,10 +22,16 @@ def print_models() -> None: print('\nNot ready: ' + ', '.join(not_ready)) -def ai(chat: list[dict[str, str]], - config: dict, +def ai(chat: ChatType, + config: Config, number: int ) -> tuple[list[str], dict[str, int]]: + """ + Make AI request with the given chat history and configuration. + Return AI response and tokens used. + """ + if not isinstance(config['openai'], dict): + raise RuntimeError('Configuration openai is not a dict.') response = openai.ChatCompletion.create( model=config['openai']['model'], messages=chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py new file mode 100644 index 0000000..9cb7885 --- /dev/null +++ b/chatmastermind/configuration.py @@ -0,0 +1,63 @@ +import pathlib +from typing import TypedDict, Any, Union + + +class OpenAIConfig(TypedDict): + """ + The OpenAI section of the configuration file. + """ + api_key: str + model: str + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + + +def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool: + """ + Checks if the given Open AI configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['api_key']) + str(conf['model']) + int(conf['max_tokens']) + float(conf['temperature']) + float(conf['top_p']) + float(conf['frequency_penalty']) + float(conf['presence_penalty']) + return True + except Exception as e: + print(f"OpenAI configuration is invalid: {e}") + return False + + +class Config(TypedDict): + """ + The configuration file structure. + """ + system: str + db: str + openai: OpenAIConfig + + +def config_valid(conf: dict[str, Any]) -> bool: + """ + Checks if the given configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['system']) + pathlib.Path(str(conf['db'])) + return True + except Exception as e: + print(f"Configuration is invalid: {e}") + return False + if 'openai' in conf: + return openai_config_valid(conf['openai']) + else: + # required as long as we only support OpenAI + print("Section 'openai' is missing in the configuration!") + return False diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 15e8208..7c6df33 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,23 +7,25 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models +from .configuration import Config from itertools import zip_longest +from typing import Any default_config = '.config.yaml' -def tags_completer(prefix, parsed_args, **kwargs): +def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: with open(parsed_args.config, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return get_tags_unique(config, prefix) def create_question_with_hist(args: argparse.Namespace, - config: ConfigType, - ) -> tuple[list[dict[str, str]], str, list[str]]: + config: Config, + ) -> tuple[ChatType, str, list[str]]: """ Creates the "AI request", including the question and chat history as determined by the specified tags. @@ -55,7 +57,7 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def tag_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tag' command. """ @@ -63,13 +65,10 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_tags_frequency(get_tags(config, None)) -def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def config_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'config' command. """ - if type(config['openai']) is not dict: - raise RuntimeError('Configuration openai is not a dict.') - if args.list_models: print_models() elif args.print_model: @@ -79,19 +78,16 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: write_config(args.config, config) -def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. """ - if type(config['openai']) is not dict: - raise RuntimeError('Configuration openai is not a dict.') - config_openai = config['openai'] if args.max_tokens: - config_openai['max_tokens'] = args.max_tokens + config['openai']['max_tokens'] = args.max_tokens if args.temperature: - config_openai['temperature'] = args.temperature + config['openai']['temperature'] = args.temperature if args.model: - config_openai['model'] = args.model + config['openai']['model'] = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] @@ -101,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: print(f"Usage: {usage}") -def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. """ @@ -115,7 +111,7 @@ def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_chat_hist(chat, args.dump, args.only_source_code) -def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ @@ -231,10 +227,7 @@ def main() -> int: command = parser.parse_args() config = read_config(args.config) - if type(config['openai']) is dict and type(config['openai']['api_key']) is str: - openai_api_key(config['openai']['api_key']) - else: - raise RuntimeError("Configuration openai.api_key is wrong.") + openai_api_key(config['openai']['api_key']) command.func(command, config) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index d90598b..a4648b0 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,11 +1,13 @@ import yaml +import sys import io import pathlib -from .utils import terminal_width, append_message, message_to_chat, ConfigType -from typing import List, Dict, Any, Optional +from .utils import terminal_width, append_message, message_to_chat, ChatType +from .configuration import Config, config_valid +from typing import Any, Optional -def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: +def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: with open(fname, "r") as fd: tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() # also support tags separated by ',' (old format) @@ -22,18 +24,20 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: "file": fname.name} -def read_config(path: str) -> ConfigType: +def read_config(path: str) -> Config: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) + if not config_valid(config): + sys.exit(1) return config -def write_config(path: str, config: ConfigType) -> None: +def write_config(path: str, config: Config) -> None: with open(path, 'w') as f: yaml.dump(config, f) -def dump_data(data: Dict[str, Any]) -> str: +def dump_data(data: dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -41,7 +45,7 @@ def dump_data(data: Dict[str, Any]) -> str: return fd.getvalue() -def write_file(fname: str, data: Dict[str, Any]) -> None: +def write_file(fname: str, data: dict[str, Any]) -> None: with open(fname, "w") as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -52,7 +56,7 @@ def save_answers(question: str, answers: list[str], tags: list[str], otags: Optional[list[str]], - config: ConfigType + config: Config ) -> None: wtags = otags or tags num, inum = 0, 0 @@ -75,14 +79,14 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], - tags: Optional[List[str]], - extags: Optional[List[str]], - config: ConfigType, + tags: Optional[list[str]], + extags: Optional[list[str]], + config: Config, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False - ) -> List[Dict[str, str]]: - chat: List[Dict[str, str]] = [] + ) -> ChatType: + chat: ChatType = [] append_message(chat, 'system', str(config['system']).strip()) for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -108,7 +112,7 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags(config: Config, prefix: Optional[str]) -> list[str]: result = [] for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -127,5 +131,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: return result -def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index fba8296..bd80e4f 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,14 +1,15 @@ import shutil from pprint import PrettyPrinter +from typing import Any -ConfigType = dict[str, str | dict[str, str | int | float]] +ChatType = list[dict[str, str]] def terminal_width() -> int: return shutil.get_terminal_size().columns -def pp(*args, **kwargs) -> None: +def pp(*args: Any, **kwargs: Any) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) @@ -30,7 +31,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None print() -def append_message(chat: list[dict[str, str]], +def append_message(chat: ChatType, role: str, content: str ) -> None: @@ -38,7 +39,7 @@ def append_message(chat: list[dict[str, str]], def message_to_chat(message: dict[str, str], - chat: list[dict[str, str]], + chat: ChatType, with_tags: bool = False, with_file: bool = False ) -> None: @@ -61,7 +62,7 @@ def display_source_code(content: str) -> None: pass -def print_chat_hist(chat, dump=False, source_code=False) -> None: +def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: if dump: pp(chat) return diff --git a/mypy.ini b/mypy.ini index b99c5a5..aecd40e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,3 +5,4 @@ strict_optional = True warn_unused_ignores = False warn_redundant_casts = True warn_unused_configs = True +disallow_untyped_defs = True diff --git a/tests/test_main.py b/tests/test_main.py index 3634740..4a70cbb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,25 +5,46 @@ import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai +from chatmastermind.configuration import Config, OpenAIConfig from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY -class TestCreateChat(unittest.TestCase): +class CmmTestCase(unittest.TestCase): + """ + Base class for all cmm testcases. + """ + def dummy_config(self, db: str) -> Config: + """ + Creates a dummy configuration. + """ + return Config( + system='dummy_system', + db=db, + openai=OpenAIConfig( + api_key='dummy_key', + model='dummy_model', + max_tokens=4000, + temperature=1.0, + top_p=1, + frequency_penalty=0, + presence_penalty=0 + ) + ) - def setUp(self): - self.config = { - 'system': 'System text', - 'db': 'test_files' - } + +class TestCreateChat(CmmTestCase): + + def setUp(self) -> None: + self.config = self.dummy_config(db='test_files') self.question = "test question" self.tags = ['test_tag'] @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -45,7 +66,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -63,7 +84,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.side_effect = ( @@ -90,9 +111,9 @@ class TestCreateChat(unittest.TestCase): {'role': 'assistant', 'content': 'some answer2'}) -class TestHandleQuestion(unittest.TestCase): +class TestHandleQuestion(CmmTestCase): - def setUp(self): + def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], @@ -109,12 +130,7 @@ class TestHandleQuestion(unittest.TestCase): with_tags=False, with_file=False, ) - self.config = { - 'db': 'test_files', - 'setting1': 'value1', - 'setting2': 'value2', - 'openai': {}, - } + self.config = self.dummy_config(db='test_files') @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.print_tag_args") @@ -122,9 +138,9 @@ class TestHandleQuestion(unittest.TestCase): @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") - def test_ask_cmd(self, mock_print, mock_pp, mock_ai, - mock_print_chat_hist, mock_print_tag_args, - mock_create_chat_hist): + def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, + mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, + mock_create_chat_hist: MagicMock) -> None: open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) @@ -155,15 +171,15 @@ class TestHandleQuestion(unittest.TestCase): open_mock.assert_has_calls(open_expected_calls, any_order=True) -class TestSaveAnswers(unittest.TestCase): +class TestSaveAnswers(CmmTestCase): @mock.patch('builtins.open') @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock, open_mock): + def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: question = "Test question?" answers = ["Answer 1", "Answer 2"] tags = ["tag1", "tag2"] otags = ["otag1", "otag2"] - config = {'db': 'test_db'} + config = self.dummy_config(db='test_db') with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ mock.patch('chatmastermind.storage.yaml.dump'), \ @@ -179,10 +195,10 @@ class TestSaveAnswers(unittest.TestCase): open_mock.assert_has_calls(open_calls, any_order=True) -class TestAI(unittest.TestCase): +class TestAI(CmmTestCase): @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock): + def test_ai(self, mock_create: MagicMock) -> None: mock_create.return_value = { 'choices': [ {'message': {'content': 'response_text_1'}}, @@ -191,28 +207,20 @@ class TestAI(unittest.TestCase): 'usage': {'tokens': 10} } - number = 2 chat = [{"role": "system", "content": "hello ai"}] - config = { - "openai": { - "model": "text-davinci-002", - "temperature": 0.5, - "max_tokens": 150, - "top_p": 1, - "n": number, - "frequency_penalty": 0, - "presence_penalty": 0 - } - } + config = self.dummy_config(db='dummy') + config['openai']['model'] = "text-davinci-002" + config['openai']['max_tokens'] = 150 + config['openai']['temperature'] = 0.5 - result = ai(chat, config, number) + result = ai(chat, config, 2) expected_result = (['response_text_1', 'response_text_2'], {'tokens': 10}) self.assertEqual(result, expected_result) -class TestCreateParser(unittest.TestCase): - def test_create_parser(self): +class TestCreateParser(CmmTestCase): + def test_create_parser(self) -> None: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: mock_cmdparser = Mock() mock_add_subparsers.return_value = mock_cmdparser