configuration: minor improvements / fixes

Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'.
This commit is contained in:
juk0de 2023-08-16 23:22:20 +02:00
parent 380b7c1b67
commit a5c91adc41
5 changed files with 80 additions and 95 deletions

View File

@ -30,17 +30,15 @@ def ai(chat: ChatType,
Make AI request with the given chat history and configuration. Make AI request with the given chat history and configuration.
Return AI response and tokens used. Return AI response and tokens used.
""" """
if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.')
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=config['openai']['model'], model=config.openai.model,
messages=chat, messages=chat,
temperature=config['openai']['temperature'], temperature=config.openai.temperature,
max_tokens=config['openai']['max_tokens'], max_tokens=config.openai.max_tokens,
top_p=config['openai']['top_p'], top_p=config.openai.top_p,
n=number, n=number,
frequency_penalty=config['openai']['frequency_penalty'], frequency_penalty=config.openai.frequency_penalty,
presence_penalty=config['openai']['presence_penalty']) presence_penalty=config.openai.presence_penalty)
result = [] result = []
for choice in response['choices']: # type: ignore for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip()) result.append(choice['message']['content'].strip())

View File

@ -1,8 +1,13 @@
import pathlib import yaml
from typing import TypedDict, Any, Union from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict
ConfigInst = TypeVar('ConfigInst', bound='Config')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
class OpenAIConfig(TypedDict): @dataclass
class OpenAIConfig():
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
frequency_penalty: float frequency_penalty: float
presence_penalty: float presence_penalty: float
@classmethod
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
""" """
Checks if the given Open AI configuration dict is complete Create OpenAIConfig from a dict.
and contains valid types and values.
""" """
try: return cls(
str(conf['api_key']) api_key=str(source['api_key']),
str(conf['model']) model=str(source['model']),
int(conf['max_tokens']) max_tokens=int(source['max_tokens']),
float(conf['temperature']) temperature=float(source['temperature']),
float(conf['top_p']) top_p=float(source['top_p']),
float(conf['frequency_penalty']) frequency_penalty=float(source['frequency_penalty']),
float(conf['presence_penalty']) presence_penalty=float(source['presence_penalty'])
return True )
except Exception as e:
print(f"OpenAI configuration is invalid: {e}")
return False
class Config(TypedDict): @dataclass
class Config():
""" """
The configuration file structure. The configuration file structure.
""" """
@ -42,22 +44,23 @@ class Config(TypedDict):
db: str db: str
openai: OpenAIConfig openai: OpenAIConfig
@classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
"""
Create OpenAIConfig from a dict.
"""
return cls(
system=str(source['system']),
db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai'])
)
def config_valid(conf: dict[str, Any]) -> bool: @classmethod
""" def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
Checks if the given configuration dict is complete with open(path, 'r') as f:
and contains valid types and values. source = yaml.load(f, Loader=yaml.FullLoader)
""" return cls.from_dict(source)
try:
str(conf['system']) def to_file(self, path: str) -> None:
pathlib.Path(str(conf['db'])) with open(path, 'w') as f:
return True yaml.dump(asdict(self), f)
except Exception as e:
print(f"Configuration is invalid: {e}")
return False
if 'openai' in conf:
return openai_config_valid(conf['openai'])
else:
# required as long as we only support OpenAI
print("Section 'openai' is missing in the configuration!")
return False

View File

@ -8,7 +8,7 @@ import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config from .configuration import Config
from itertools import zip_longest from itertools import zip_longest
@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
if args.list_models: if args.list_models:
print_models() print_models()
elif args.print_model: elif args.print_model:
print(config['openai']['model']) print(config.openai.model)
elif args.model: elif args.model:
config['openai']['model'] = args.model config.openai.model = args.model
write_config(args.config, config) config.to_file(args.config)
def ask_cmd(args: argparse.Namespace, config: Config) -> None: def ask_cmd(args: argparse.Namespace, config: Config) -> None:
@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
if args.max_tokens: if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens config.openai.max_tokens = args.max_tokens
if args.temperature: if args.temperature:
config['openai']['temperature'] = args.temperature config.openai.temperature = args.temperature
if args.model: if args.model:
config['openai']['model'] = args.model config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code) print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
@ -225,9 +225,9 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
config = read_config(args.config) config = Config.from_file(args.config)
openai_api_key(config['openai']['api_key']) openai_api_key(config.openai.api_key)
command.func(command, config) command.func(command, config)

View File

@ -1,9 +1,8 @@
import yaml import yaml
import sys
import io import io
import pathlib import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config, config_valid from .configuration import Config
from typing import Any, Optional from typing import Any, Optional
@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
"file": fname.name} "file": fname.name}
def read_config(path: str) -> Config:
with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if not config_valid(config):
sys.exit(1)
return config
def write_config(path: str, config: Config) -> None:
with open(path, 'w') as f:
yaml.dump(config, f)
def dump_data(data: dict[str, Any]) -> str: def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd: with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
@ -60,7 +46,7 @@ def save_answers(question: str,
) -> None: ) -> None:
wtags = otags or tags wtags = otags or tags
num, inum = 0, 0 num, inum = 0, 0
next_fname = pathlib.Path(str(config['db'])) / '.next' next_fname = pathlib.Path(str(config.db)) / '.next'
try: try:
with open(next_fname, 'r') as f: with open(next_fname, 'r') as f:
num = int(f.read()) num = int(f.read())
@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str],
with_file: bool = False with_file: bool = False
) -> ChatType: ) -> ChatType:
chat: ChatType = [] chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip()) append_message(chat, 'system', str(config.system).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
with open(file, 'r') as f: with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader) data = yaml.load(f, Loader=yaml.FullLoader)
@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str],
def get_tags(config: Config, prefix: Optional[str]) -> list[str]: def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = [] result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
with open(file, 'r') as f: with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader) data = yaml.load(f, Loader=yaml.FullLoader)

View File

@ -5,7 +5,7 @@ import argparse
from chatmastermind.utils import terminal_width from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.configuration import Config, OpenAIConfig from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock, ANY
@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase):
""" """
Creates a dummy configuration. Creates a dummy configuration.
""" """
return Config( return Config.from_dict(
system='dummy_system', {'system': 'dummy_system',
db=db, 'db': db,
openai=OpenAIConfig( 'openai': {'api_key': 'dummy_key',
api_key='dummy_key', 'model': 'dummy_model',
model='dummy_model', 'max_tokens': 4000,
max_tokens=4000, 'temperature': 1.0,
temperature=1.0, 'top_p': 1,
top_p=1, 'frequency_penalty': 0,
frequency_penalty=0, 'presence_penalty': 0}}
presence_penalty=0
)
) )
@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 4) self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']}) {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], self.assertEqual(test_chat[2],
@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 2) self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']}) {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question}) {'role': 'user', 'content': self.question})
@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 6) self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']}) {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], self.assertEqual(test_chat[2],
@ -209,9 +207,9 @@ class TestAI(CmmTestCase):
chat = [{"role": "system", "content": "hello ai"}] chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy') config = self.dummy_config(db='dummy')
config['openai']['model'] = "text-davinci-002" config.openai.model = "text-davinci-002"
config['openai']['max_tokens'] = 150 config.openai.max_tokens = 150
config['openai']['temperature'] = 0.5 config.openai.temperature = 0.5
result = ai(chat, config, 2) result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'], expected_result = (['response_text_1', 'response_text_2'],