From a5c91adc4138bc5163d6665a0e35a7c42b835da9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 23:22:20 +0200 Subject: [PATCH] configuration: minor improvements / fixes Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'. --- chatmastermind/api_client.py | 14 +++--- chatmastermind/configuration.py | 83 +++++++++++++++++---------------- chatmastermind/main.py | 18 +++---- chatmastermind/storage.py | 24 ++-------- tests/test_main.py | 36 +++++++------- 5 files changed, 80 insertions(+), 95 deletions(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index d8634bd..2c4a094 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -30,17 +30,15 @@ def ai(chat: ChatType, Make AI request with the given chat history and configuration. Return AI response and tokens used. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') response = openai.ChatCompletion.create( - model=config['openai']['model'], + model=config.openai.model, messages=chat, - temperature=config['openai']['temperature'], - max_tokens=config['openai']['max_tokens'], - top_p=config['openai']['top_p'], + temperature=config.openai.temperature, + max_tokens=config.openai.max_tokens, + top_p=config.openai.top_p, n=number, - frequency_penalty=config['openai']['frequency_penalty'], - presence_penalty=config['openai']['presence_penalty']) + frequency_penalty=config.openai.frequency_penalty, + presence_penalty=config.openai.presence_penalty) result = [] for choice in response['choices']: # type: ignore result.append(choice['message']['content'].strip()) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 9cb7885..0037916 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,8 +1,13 @@ -import pathlib -from typing import TypedDict, Any, Union +import yaml +from typing import Type, TypeVar, Any +from dataclasses import dataclass, asdict + +ConfigInst = TypeVar('ConfigInst', bound='Config') +OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') -class OpenAIConfig(TypedDict): +@dataclass +class OpenAIConfig(): """ The OpenAI section of the configuration file. """ @@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict): frequency_penalty: float presence_penalty: float - -def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool: - """ - Checks if the given Open AI configuration dict is complete - and contains valid types and values. - """ - try: - str(conf['api_key']) - str(conf['model']) - int(conf['max_tokens']) - float(conf['temperature']) - float(conf['top_p']) - float(conf['frequency_penalty']) - float(conf['presence_penalty']) - return True - except Exception as e: - print(f"OpenAI configuration is invalid: {e}") - return False + @classmethod + def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: + """ + Create OpenAIConfig from a dict. + """ + return cls( + api_key=str(source['api_key']), + model=str(source['model']), + max_tokens=int(source['max_tokens']), + temperature=float(source['temperature']), + top_p=float(source['top_p']), + frequency_penalty=float(source['frequency_penalty']), + presence_penalty=float(source['presence_penalty']) + ) -class Config(TypedDict): +@dataclass +class Config(): """ The configuration file structure. """ @@ -42,22 +44,23 @@ class Config(TypedDict): db: str openai: OpenAIConfig + @classmethod + def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: + """ + Create OpenAIConfig from a dict. + """ + return cls( + system=str(source['system']), + db=str(source['db']), + openai=OpenAIConfig.from_dict(source['openai']) + ) -def config_valid(conf: dict[str, Any]) -> bool: - """ - Checks if the given configuration dict is complete - and contains valid types and values. - """ - try: - str(conf['system']) - pathlib.Path(str(conf['db'])) - return True - except Exception as e: - print(f"Configuration is invalid: {e}") - return False - if 'openai' in conf: - return openai_config_valid(conf['openai']) - else: - # required as long as we only support OpenAI - print("Section 'openai' is missing in the configuration!") - return False + @classmethod + def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: + with open(path, 'r') as f: + source = yaml.load(f, Loader=yaml.FullLoader) + return cls.from_dict(source) + + def to_file(self, path: str) -> None: + with open(path, 'w') as f: + yaml.dump(asdict(self), f) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 7c6df33..7866179 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -8,7 +8,7 @@ import argcomplete import argparse import pathlib from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data +from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from itertools import zip_longest @@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: if args.list_models: print_models() elif args.print_model: - print(config['openai']['model']) + print(config.openai.model) elif args.model: - config['openai']['model'] = args.model - write_config(args.config, config) + config.openai.model = args.model + config.to_file(args.config) def ask_cmd(args: argparse.Namespace, config: Config) -> None: @@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'ask' command. """ if args.max_tokens: - config['openai']['max_tokens'] = args.max_tokens + config.openai.max_tokens = args.max_tokens if args.temperature: - config['openai']['temperature'] = args.temperature + config.openai.temperature = args.temperature if args.model: - config['openai']['model'] = args.model + config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] @@ -225,9 +225,9 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = read_config(args.config) + config = Config.from_file(args.config) - openai_api_key(config['openai']['api_key']) + openai_api_key(config.openai.api_key) command.func(command, config) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index a4648b0..8b9ed97 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,9 +1,8 @@ import yaml -import sys import io import pathlib from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config, config_valid +from .configuration import Config from typing import Any, Optional @@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: "file": fname.name} -def read_config(path: str) -> Config: - with open(path, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - if not config_valid(config): - sys.exit(1) - return config - - -def write_config(path: str, config: Config) -> None: - with open(path, 'w') as f: - yaml.dump(config, f) - - def dump_data(data: dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') @@ -60,7 +46,7 @@ def save_answers(question: str, ) -> None: wtags = otags or tags num, inum = 0, 0 - next_fname = pathlib.Path(str(config['db'])) / '.next' + next_fname = pathlib.Path(str(config.db)) / '.next' try: with open(next_fname, 'r') as f: num = int(f.read()) @@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str], with_file: bool = False ) -> ChatType: chat: ChatType = [] - append_message(chat, 'system', str(config['system']).strip()) - for file in sorted(pathlib.Path(str(config['db'])).iterdir()): + append_message(chat, 'system', str(config.system).strip()) + for file in sorted(pathlib.Path(str(config.db)).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) @@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str], def get_tags(config: Config, prefix: Optional[str]) -> list[str]: result = [] - for file in sorted(pathlib.Path(str(config['db'])).iterdir()): + for file in sorted(pathlib.Path(str(config.db)).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) diff --git a/tests/test_main.py b/tests/test_main.py index 4a70cbb..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,7 +5,7 @@ import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai -from chatmastermind.configuration import Config, OpenAIConfig +from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase): """ Creates a dummy configuration. """ - return Config( - system='dummy_system', - db=db, - openai=OpenAIConfig( - api_key='dummy_key', - model='dummy_model', - max_tokens=4000, - temperature=1.0, - top_p=1, - frequency_penalty=0, - presence_penalty=0 - ) + return Config.from_dict( + {'system': 'dummy_system', + 'db': db, + 'openai': {'api_key': 'dummy_key', + 'model': 'dummy_model', + 'max_tokens': 4000, + 'temperature': 1.0, + 'top_p': 1, + 'frequency_penalty': 0, + 'presence_penalty': 0}} ) @@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 4) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': 'test_content'}) self.assertEqual(test_chat[2], @@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 2) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': self.question}) @@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 6) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': 'test_content'}) self.assertEqual(test_chat[2], @@ -209,9 +207,9 @@ class TestAI(CmmTestCase): chat = [{"role": "system", "content": "hello ai"}] config = self.dummy_config(db='dummy') - config['openai']['model'] = "text-davinci-002" - config['openai']['max_tokens'] = 150 - config['openai']['temperature'] = 0.5 + config.openai.model = "text-davinci-002" + config.openai.max_tokens = 150 + config.openai.temperature = 0.5 result = ai(chat, config, 2) expected_result = (['response_text_1', 'response_text_2'],