configuration: minor improvements / fixes

Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'.
This commit is contained in:
juk0de 2023-08-16 23:22:20 +02:00
parent 380b7c1b67
commit a5c91adc41
5 changed files with 80 additions and 95 deletions

View File

@ -30,17 +30,15 @@ def ai(chat: ChatType,
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.')
response = openai.ChatCompletion.create(
model=config['openai']['model'],
model=config.openai.model,
messages=chat,
temperature=config['openai']['temperature'],
max_tokens=config['openai']['max_tokens'],
top_p=config['openai']['top_p'],
temperature=config.openai.temperature,
max_tokens=config.openai.max_tokens,
top_p=config.openai.top_p,
n=number,
frequency_penalty=config['openai']['frequency_penalty'],
presence_penalty=config['openai']['presence_penalty'])
frequency_penalty=config.openai.frequency_penalty,
presence_penalty=config.openai.presence_penalty)
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())

View File

@ -1,8 +1,13 @@
import pathlib
from typing import TypedDict, Any, Union
import yaml
from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict
ConfigInst = TypeVar('ConfigInst', bound='Config')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
class OpenAIConfig(TypedDict):
@dataclass
class OpenAIConfig():
"""
The OpenAI section of the configuration file.
"""
@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
frequency_penalty: float
presence_penalty: float
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
@classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
"""
Checks if the given Open AI configuration dict is complete
and contains valid types and values.
Create OpenAIConfig from a dict.
"""
try:
str(conf['api_key'])
str(conf['model'])
int(conf['max_tokens'])
float(conf['temperature'])
float(conf['top_p'])
float(conf['frequency_penalty'])
float(conf['presence_penalty'])
return True
except Exception as e:
print(f"OpenAI configuration is invalid: {e}")
return False
return cls(
api_key=str(source['api_key']),
model=str(source['model']),
max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']),
top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty'])
)
class Config(TypedDict):
@dataclass
class Config():
"""
The configuration file structure.
"""
@ -42,22 +44,23 @@ class Config(TypedDict):
db: str
openai: OpenAIConfig
@classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
"""
Create OpenAIConfig from a dict.
"""
return cls(
system=str(source['system']),
db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai'])
)
def config_valid(conf: dict[str, Any]) -> bool:
"""
Checks if the given configuration dict is complete
and contains valid types and values.
"""
try:
str(conf['system'])
pathlib.Path(str(conf['db']))
return True
except Exception as e:
print(f"Configuration is invalid: {e}")
return False
if 'openai' in conf:
return openai_config_valid(conf['openai'])
else:
# required as long as we only support OpenAI
print("Section 'openai' is missing in the configuration!")
return False
@classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source)
def to_file(self, path: str) -> None:
with open(path, 'w') as f:
yaml.dump(asdict(self), f)

View File

@ -8,7 +8,7 @@ import argcomplete
import argparse
import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from itertools import zip_longest
@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
if args.list_models:
print_models()
elif args.print_model:
print(config['openai']['model'])
print(config.openai.model)
elif args.model:
config['openai']['model'] = args.model
write_config(args.config, config)
config.openai.model = args.model
config.to_file(args.config)
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
Handler for the 'ask' command.
"""
if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
config.openai.max_tokens = args.max_tokens
if args.temperature:
config['openai']['temperature'] = args.temperature
config.openai.temperature = args.temperature
if args.model:
config['openai']['model'] = args.model
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or []
@ -225,9 +225,9 @@ def main() -> int:
parser = create_parser()
args = parser.parse_args()
command = parser.parse_args()
config = read_config(args.config)
config = Config.from_file(args.config)
openai_api_key(config['openai']['api_key'])
openai_api_key(config.openai.api_key)
command.func(command, config)

View File

@ -1,9 +1,8 @@
import yaml
import sys
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config, config_valid
from .configuration import Config
from typing import Any, Optional
@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
"file": fname.name}
def read_config(path: str) -> Config:
with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if not config_valid(config):
sys.exit(1)
return config
def write_config(path: str, config: Config) -> None:
with open(path, 'w') as f:
yaml.dump(config, f)
def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
@ -60,7 +46,7 @@ def save_answers(question: str,
) -> None:
wtags = otags or tags
num, inum = 0, 0
next_fname = pathlib.Path(str(config['db'])) / '.next'
next_fname = pathlib.Path(str(config.db)) / '.next'
try:
with open(next_fname, 'r') as f:
num = int(f.read())
@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str],
with_file: bool = False
) -> ChatType:
chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
append_message(chat, 'system', str(config.system).strip())
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str],
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)

View File

@ -5,7 +5,7 @@ import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai
from chatmastermind.configuration import Config, OpenAIConfig
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase):
"""
Creates a dummy configuration.
"""
return Config(
system='dummy_system',
db=db,
openai=OpenAIConfig(
api_key='dummy_key',
model='dummy_model',
max_tokens=4000,
temperature=1.0,
top_p=1,
frequency_penalty=0,
presence_penalty=0
)
return Config.from_dict(
{'system': 'dummy_system',
'db': db,
'openai': {'api_key': 'dummy_key',
'model': 'dummy_model',
'max_tokens': 4000,
'temperature': 1.0,
'top_p': 1,
'frequency_penalty': 0,
'presence_penalty': 0}}
)
@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']})
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']})
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question})
@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase):
self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config['system']})
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
@ -209,9 +207,9 @@ class TestAI(CmmTestCase):
chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy')
config['openai']['model'] = "text-davinci-002"
config['openai']['max_tokens'] = 150
config['openai']['temperature'] = 0.5
config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150
config.openai.temperature = 0.5
result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'],