configuration: minor improvements / fixes
I could not extend the 'TypedDict' subclass the way I wanted, so I switched to a 'dataclass'.
This commit is contained in:
parent
380b7c1b67
commit
a5c91adc41
@ -30,17 +30,15 @@ def ai(chat: ChatType,
|
||||
Make AI request with the given chat history and configuration.
|
||||
Return AI response and tokens used.
|
||||
"""
|
||||
if not isinstance(config['openai'], dict):
|
||||
raise RuntimeError('Configuration openai is not a dict.')
|
||||
response = openai.ChatCompletion.create(
|
||||
model=config['openai']['model'],
|
||||
model=config.openai.model,
|
||||
messages=chat,
|
||||
temperature=config['openai']['temperature'],
|
||||
max_tokens=config['openai']['max_tokens'],
|
||||
top_p=config['openai']['top_p'],
|
||||
temperature=config.openai.temperature,
|
||||
max_tokens=config.openai.max_tokens,
|
||||
top_p=config.openai.top_p,
|
||||
n=number,
|
||||
frequency_penalty=config['openai']['frequency_penalty'],
|
||||
presence_penalty=config['openai']['presence_penalty'])
|
||||
frequency_penalty=config.openai.frequency_penalty,
|
||||
presence_penalty=config.openai.presence_penalty)
|
||||
result = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
result.append(choice['message']['content'].strip())
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
import pathlib
|
||||
from typing import TypedDict, Any, Union
|
||||
import yaml
|
||||
from typing import Type, TypeVar, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||
|
||||
|
||||
class OpenAIConfig(TypedDict):
|
||||
@dataclass
|
||||
class OpenAIConfig():
|
||||
"""
|
||||
The OpenAI section of the configuration file.
|
||||
"""
|
||||
@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
|
||||
frequency_penalty: float
|
||||
presence_penalty: float
|
||||
|
||||
|
||||
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
|
||||
"""
|
||||
Checks if the given Open AI configuration dict is complete
|
||||
and contains valid types and values.
|
||||
"""
|
||||
try:
|
||||
str(conf['api_key'])
|
||||
str(conf['model'])
|
||||
int(conf['max_tokens'])
|
||||
float(conf['temperature'])
|
||||
float(conf['top_p'])
|
||||
float(conf['frequency_penalty'])
|
||||
float(conf['presence_penalty'])
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"OpenAI configuration is invalid: {e}")
|
||||
return False
|
||||
@classmethod
|
||||
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
||||
"""
|
||||
Create OpenAIConfig from a dict.
|
||||
"""
|
||||
return cls(
|
||||
api_key=str(source['api_key']),
|
||||
model=str(source['model']),
|
||||
max_tokens=int(source['max_tokens']),
|
||||
temperature=float(source['temperature']),
|
||||
top_p=float(source['top_p']),
|
||||
frequency_penalty=float(source['frequency_penalty']),
|
||||
presence_penalty=float(source['presence_penalty'])
|
||||
)
|
||||
|
||||
|
||||
class Config(TypedDict):
|
||||
@dataclass
|
||||
class Config():
|
||||
"""
|
||||
The configuration file structure.
|
||||
"""
|
||||
@ -42,22 +44,23 @@ class Config(TypedDict):
|
||||
db: str
|
||||
openai: OpenAIConfig
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
||||
"""
|
||||
Create Config from a dict.
|
||||
"""
|
||||
return cls(
|
||||
system=str(source['system']),
|
||||
db=str(source['db']),
|
||||
openai=OpenAIConfig.from_dict(source['openai'])
|
||||
)
|
||||
|
||||
def config_valid(conf: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Checks if the given configuration dict is complete
|
||||
and contains valid types and values.
|
||||
"""
|
||||
try:
|
||||
str(conf['system'])
|
||||
pathlib.Path(str(conf['db']))
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Configuration is invalid: {e}")
|
||||
return False
|
||||
if 'openai' in conf:
|
||||
return openai_config_valid(conf['openai'])
|
||||
else:
|
||||
# required as long as we only support OpenAI
|
||||
print("Section 'openai' is missing in the configuration!")
|
||||
return False
|
||||
@classmethod
|
||||
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
|
||||
with open(path, 'r') as f:
|
||||
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return cls.from_dict(source)
|
||||
|
||||
def to_file(self, path: str) -> None:
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(asdict(self), f)
|
||||
|
||||
@ -8,7 +8,7 @@ import argcomplete
|
||||
import argparse
|
||||
import pathlib
|
||||
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
|
||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
|
||||
from .api_client import ai, openai_api_key, print_models
|
||||
from .configuration import Config
|
||||
from itertools import zip_longest
|
||||
@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
if args.list_models:
|
||||
print_models()
|
||||
elif args.print_model:
|
||||
print(config['openai']['model'])
|
||||
print(config.openai.model)
|
||||
elif args.model:
|
||||
config['openai']['model'] = args.model
|
||||
write_config(args.config, config)
|
||||
config.openai.model = args.model
|
||||
config.to_file(args.config)
|
||||
|
||||
|
||||
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
Handler for the 'ask' command.
|
||||
"""
|
||||
if args.max_tokens:
|
||||
config['openai']['max_tokens'] = args.max_tokens
|
||||
config.openai.max_tokens = args.max_tokens
|
||||
if args.temperature:
|
||||
config['openai']['temperature'] = args.temperature
|
||||
config.openai.temperature = args.temperature
|
||||
if args.model:
|
||||
config['openai']['model'] = args.model
|
||||
config.openai.model = args.model
|
||||
chat, question, tags = create_question_with_hist(args, config)
|
||||
print_chat_hist(chat, False, args.only_source_code)
|
||||
otags = args.output_tags or []
|
||||
@ -225,9 +225,9 @@ def main() -> int:
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
command = parser.parse_args()
|
||||
config = read_config(args.config)
|
||||
config = Config.from_file(args.config)
|
||||
|
||||
openai_api_key(config['openai']['api_key'])
|
||||
openai_api_key(config.openai.api_key)
|
||||
|
||||
command.func(command, config)
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import yaml
|
||||
import sys
|
||||
import io
|
||||
import pathlib
|
||||
from .utils import terminal_width, append_message, message_to_chat, ChatType
|
||||
from .configuration import Config, config_valid
|
||||
from .configuration import Config
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
|
||||
"file": fname.name}
|
||||
|
||||
|
||||
def read_config(path: str) -> Config:
|
||||
with open(path, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if not config_valid(config):
|
||||
sys.exit(1)
|
||||
return config
|
||||
|
||||
|
||||
def write_config(path: str, config: Config) -> None:
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
|
||||
def dump_data(data: dict[str, Any]) -> str:
|
||||
with io.StringIO() as fd:
|
||||
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
|
||||
@ -60,7 +46,7 @@ def save_answers(question: str,
|
||||
) -> None:
|
||||
wtags = otags or tags
|
||||
num, inum = 0, 0
|
||||
next_fname = pathlib.Path(str(config['db'])) / '.next'
|
||||
next_fname = pathlib.Path(str(config.db)) / '.next'
|
||||
try:
|
||||
with open(next_fname, 'r') as f:
|
||||
num = int(f.read())
|
||||
@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str],
|
||||
with_file: bool = False
|
||||
) -> ChatType:
|
||||
chat: ChatType = []
|
||||
append_message(chat, 'system', str(config['system']).strip())
|
||||
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
|
||||
append_message(chat, 'system', str(config.system).strip())
|
||||
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str],
|
||||
|
||||
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
|
||||
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
@ -5,7 +5,7 @@ import argparse
|
||||
from chatmastermind.utils import terminal_width
|
||||
from chatmastermind.main import create_parser, ask_cmd
|
||||
from chatmastermind.api_client import ai
|
||||
from chatmastermind.configuration import Config, OpenAIConfig
|
||||
from chatmastermind.configuration import Config
|
||||
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
||||
from unittest import mock
|
||||
from unittest.mock import patch, MagicMock, Mock, ANY
|
||||
@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase):
|
||||
"""
|
||||
Creates a dummy configuration.
|
||||
"""
|
||||
return Config(
|
||||
system='dummy_system',
|
||||
db=db,
|
||||
openai=OpenAIConfig(
|
||||
api_key='dummy_key',
|
||||
model='dummy_model',
|
||||
max_tokens=4000,
|
||||
temperature=1.0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0
|
||||
)
|
||||
return Config.from_dict(
|
||||
{'system': 'dummy_system',
|
||||
'db': db,
|
||||
'openai': {'api_key': 'dummy_key',
|
||||
'model': 'dummy_model',
|
||||
'max_tokens': 4000,
|
||||
'temperature': 1.0,
|
||||
'top_p': 1,
|
||||
'frequency_penalty': 0,
|
||||
'presence_penalty': 0}}
|
||||
)
|
||||
|
||||
|
||||
@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase):
|
||||
|
||||
self.assertEqual(len(test_chat), 4)
|
||||
self.assertEqual(test_chat[0],
|
||||
{'role': 'system', 'content': self.config['system']})
|
||||
{'role': 'system', 'content': self.config.system})
|
||||
self.assertEqual(test_chat[1],
|
||||
{'role': 'user', 'content': 'test_content'})
|
||||
self.assertEqual(test_chat[2],
|
||||
@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase):
|
||||
|
||||
self.assertEqual(len(test_chat), 2)
|
||||
self.assertEqual(test_chat[0],
|
||||
{'role': 'system', 'content': self.config['system']})
|
||||
{'role': 'system', 'content': self.config.system})
|
||||
self.assertEqual(test_chat[1],
|
||||
{'role': 'user', 'content': self.question})
|
||||
|
||||
@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase):
|
||||
|
||||
self.assertEqual(len(test_chat), 6)
|
||||
self.assertEqual(test_chat[0],
|
||||
{'role': 'system', 'content': self.config['system']})
|
||||
{'role': 'system', 'content': self.config.system})
|
||||
self.assertEqual(test_chat[1],
|
||||
{'role': 'user', 'content': 'test_content'})
|
||||
self.assertEqual(test_chat[2],
|
||||
@ -209,9 +207,9 @@ class TestAI(CmmTestCase):
|
||||
|
||||
chat = [{"role": "system", "content": "hello ai"}]
|
||||
config = self.dummy_config(db='dummy')
|
||||
config['openai']['model'] = "text-davinci-002"
|
||||
config['openai']['max_tokens'] = 150
|
||||
config['openai']['temperature'] = 0.5
|
||||
config.openai.model = "text-davinci-002"
|
||||
config.openai.max_tokens = 150
|
||||
config.openai.temperature = 0.5
|
||||
|
||||
result = ai(chat, config, 2)
|
||||
expected_result = (['response_text_1', 'response_text_2'],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user