configuration: minor improvements / fixes
Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'.
This commit is contained in:
parent
380b7c1b67
commit
a5c91adc41
@ -30,17 +30,15 @@ def ai(chat: ChatType,
|
|||||||
Make AI request with the given chat history and configuration.
|
Make AI request with the given chat history and configuration.
|
||||||
Return AI response and tokens used.
|
Return AI response and tokens used.
|
||||||
"""
|
"""
|
||||||
if not isinstance(config['openai'], dict):
|
|
||||||
raise RuntimeError('Configuration openai is not a dict.')
|
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model=config['openai']['model'],
|
model=config.openai.model,
|
||||||
messages=chat,
|
messages=chat,
|
||||||
temperature=config['openai']['temperature'],
|
temperature=config.openai.temperature,
|
||||||
max_tokens=config['openai']['max_tokens'],
|
max_tokens=config.openai.max_tokens,
|
||||||
top_p=config['openai']['top_p'],
|
top_p=config.openai.top_p,
|
||||||
n=number,
|
n=number,
|
||||||
frequency_penalty=config['openai']['frequency_penalty'],
|
frequency_penalty=config.openai.frequency_penalty,
|
||||||
presence_penalty=config['openai']['presence_penalty'])
|
presence_penalty=config.openai.presence_penalty)
|
||||||
result = []
|
result = []
|
||||||
for choice in response['choices']: # type: ignore
|
for choice in response['choices']: # type: ignore
|
||||||
result.append(choice['message']['content'].strip())
|
result.append(choice['message']['content'].strip())
|
||||||
|
|||||||
@ -1,8 +1,13 @@
|
|||||||
import pathlib
|
import yaml
|
||||||
from typing import TypedDict, Any, Union
|
from typing import Type, TypeVar, Any
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||||
|
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConfig(TypedDict):
|
@dataclass
|
||||||
|
class OpenAIConfig():
|
||||||
"""
|
"""
|
||||||
The OpenAI section of the configuration file.
|
The OpenAI section of the configuration file.
|
||||||
"""
|
"""
|
||||||
@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict):
|
|||||||
frequency_penalty: float
|
frequency_penalty: float
|
||||||
presence_penalty: float
|
presence_penalty: float
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
|
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
||||||
"""
|
"""
|
||||||
Checks if the given Open AI configuration dict is complete
|
Create OpenAIConfig from a dict.
|
||||||
and contains valid types and values.
|
|
||||||
"""
|
"""
|
||||||
try:
|
return cls(
|
||||||
str(conf['api_key'])
|
api_key=str(source['api_key']),
|
||||||
str(conf['model'])
|
model=str(source['model']),
|
||||||
int(conf['max_tokens'])
|
max_tokens=int(source['max_tokens']),
|
||||||
float(conf['temperature'])
|
temperature=float(source['temperature']),
|
||||||
float(conf['top_p'])
|
top_p=float(source['top_p']),
|
||||||
float(conf['frequency_penalty'])
|
frequency_penalty=float(source['frequency_penalty']),
|
||||||
float(conf['presence_penalty'])
|
presence_penalty=float(source['presence_penalty'])
|
||||||
return True
|
)
|
||||||
except Exception as e:
|
|
||||||
print(f"OpenAI configuration is invalid: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Config(TypedDict):
|
@dataclass
|
||||||
|
class Config():
|
||||||
"""
|
"""
|
||||||
The configuration file structure.
|
The configuration file structure.
|
||||||
"""
|
"""
|
||||||
@ -42,22 +44,23 @@ class Config(TypedDict):
|
|||||||
db: str
|
db: str
|
||||||
openai: OpenAIConfig
|
openai: OpenAIConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
||||||
|
"""
|
||||||
|
Create OpenAIConfig from a dict.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
system=str(source['system']),
|
||||||
|
db=str(source['db']),
|
||||||
|
openai=OpenAIConfig.from_dict(source['openai'])
|
||||||
|
)
|
||||||
|
|
||||||
def config_valid(conf: dict[str, Any]) -> bool:
|
@classmethod
|
||||||
"""
|
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
|
||||||
Checks if the given configuration dict is complete
|
with open(path, 'r') as f:
|
||||||
and contains valid types and values.
|
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
"""
|
return cls.from_dict(source)
|
||||||
try:
|
|
||||||
str(conf['system'])
|
def to_file(self, path: str) -> None:
|
||||||
pathlib.Path(str(conf['db']))
|
with open(path, 'w') as f:
|
||||||
return True
|
yaml.dump(asdict(self), f)
|
||||||
except Exception as e:
|
|
||||||
print(f"Configuration is invalid: {e}")
|
|
||||||
return False
|
|
||||||
if 'openai' in conf:
|
|
||||||
return openai_config_valid(conf['openai'])
|
|
||||||
else:
|
|
||||||
# required as long as we only support OpenAI
|
|
||||||
print("Section 'openai' is missing in the configuration!")
|
|
||||||
return False
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import argcomplete
|
|||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
||||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
|
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
|
||||||
from .api_client import ai, openai_api_key, print_models
|
from .api_client import ai, openai_api_key, print_models
|
||||||
from .configuration import Config
|
from .configuration import Config
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
if args.list_models:
|
if args.list_models:
|
||||||
print_models()
|
print_models()
|
||||||
elif args.print_model:
|
elif args.print_model:
|
||||||
print(config['openai']['model'])
|
print(config.openai.model)
|
||||||
elif args.model:
|
elif args.model:
|
||||||
config['openai']['model'] = args.model
|
config.openai.model = args.model
|
||||||
write_config(args.config, config)
|
config.to_file(args.config)
|
||||||
|
|
||||||
|
|
||||||
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
Handler for the 'ask' command.
|
Handler for the 'ask' command.
|
||||||
"""
|
"""
|
||||||
if args.max_tokens:
|
if args.max_tokens:
|
||||||
config['openai']['max_tokens'] = args.max_tokens
|
config.openai.max_tokens = args.max_tokens
|
||||||
if args.temperature:
|
if args.temperature:
|
||||||
config['openai']['temperature'] = args.temperature
|
config.openai.temperature = args.temperature
|
||||||
if args.model:
|
if args.model:
|
||||||
config['openai']['model'] = args.model
|
config.openai.model = args.model
|
||||||
chat, question, tags = create_question_with_hist(args, config)
|
chat, question, tags = create_question_with_hist(args, config)
|
||||||
print_chat_hist(chat, False, args.only_source_code)
|
print_chat_hist(chat, False, args.only_source_code)
|
||||||
otags = args.output_tags or []
|
otags = args.output_tags or []
|
||||||
@ -225,9 +225,9 @@ def main() -> int:
|
|||||||
parser = create_parser()
|
parser = create_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
command = parser.parse_args()
|
command = parser.parse_args()
|
||||||
config = read_config(args.config)
|
config = Config.from_file(args.config)
|
||||||
|
|
||||||
openai_api_key(config['openai']['api_key'])
|
openai_api_key(config.openai.api_key)
|
||||||
|
|
||||||
command.func(command, config)
|
command.func(command, config)
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
import yaml
|
import yaml
|
||||||
import sys
|
|
||||||
import io
|
import io
|
||||||
import pathlib
|
import pathlib
|
||||||
from .utils import terminal_width, append_message, message_to_chat, ChatType
|
from .utils import terminal_width, append_message, message_to_chat, ChatType
|
||||||
from .configuration import Config, config_valid
|
from .configuration import Config
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
|
|||||||
"file": fname.name}
|
"file": fname.name}
|
||||||
|
|
||||||
|
|
||||||
def read_config(path: str) -> Config:
|
|
||||||
with open(path, 'r') as f:
|
|
||||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
|
||||||
if not config_valid(config):
|
|
||||||
sys.exit(1)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def write_config(path: str, config: Config) -> None:
|
|
||||||
with open(path, 'w') as f:
|
|
||||||
yaml.dump(config, f)
|
|
||||||
|
|
||||||
|
|
||||||
def dump_data(data: dict[str, Any]) -> str:
|
def dump_data(data: dict[str, Any]) -> str:
|
||||||
with io.StringIO() as fd:
|
with io.StringIO() as fd:
|
||||||
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
|
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
|
||||||
@ -60,7 +46,7 @@ def save_answers(question: str,
|
|||||||
) -> None:
|
) -> None:
|
||||||
wtags = otags or tags
|
wtags = otags or tags
|
||||||
num, inum = 0, 0
|
num, inum = 0, 0
|
||||||
next_fname = pathlib.Path(str(config['db'])) / '.next'
|
next_fname = pathlib.Path(str(config.db)) / '.next'
|
||||||
try:
|
try:
|
||||||
with open(next_fname, 'r') as f:
|
with open(next_fname, 'r') as f:
|
||||||
num = int(f.read())
|
num = int(f.read())
|
||||||
@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str],
|
|||||||
with_file: bool = False
|
with_file: bool = False
|
||||||
) -> ChatType:
|
) -> ChatType:
|
||||||
chat: ChatType = []
|
chat: ChatType = []
|
||||||
append_message(chat, 'system', str(config['system']).strip())
|
append_message(chat, 'system', str(config.system).strip())
|
||||||
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
|
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
|
||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str],
|
|||||||
|
|
||||||
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
|
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
|
||||||
result = []
|
result = []
|
||||||
for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
|
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
|
||||||
if file.suffix == '.yaml':
|
if file.suffix == '.yaml':
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import argparse
|
|||||||
from chatmastermind.utils import terminal_width
|
from chatmastermind.utils import terminal_width
|
||||||
from chatmastermind.main import create_parser, ask_cmd
|
from chatmastermind.main import create_parser, ask_cmd
|
||||||
from chatmastermind.api_client import ai
|
from chatmastermind.api_client import ai
|
||||||
from chatmastermind.configuration import Config, OpenAIConfig
|
from chatmastermind.configuration import Config
|
||||||
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch, MagicMock, Mock, ANY
|
from unittest.mock import patch, MagicMock, Mock, ANY
|
||||||
@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Creates a dummy configuration.
|
Creates a dummy configuration.
|
||||||
"""
|
"""
|
||||||
return Config(
|
return Config.from_dict(
|
||||||
system='dummy_system',
|
{'system': 'dummy_system',
|
||||||
db=db,
|
'db': db,
|
||||||
openai=OpenAIConfig(
|
'openai': {'api_key': 'dummy_key',
|
||||||
api_key='dummy_key',
|
'model': 'dummy_model',
|
||||||
model='dummy_model',
|
'max_tokens': 4000,
|
||||||
max_tokens=4000,
|
'temperature': 1.0,
|
||||||
temperature=1.0,
|
'top_p': 1,
|
||||||
top_p=1,
|
'frequency_penalty': 0,
|
||||||
frequency_penalty=0,
|
'presence_penalty': 0}}
|
||||||
presence_penalty=0
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(test_chat), 4)
|
self.assertEqual(len(test_chat), 4)
|
||||||
self.assertEqual(test_chat[0],
|
self.assertEqual(test_chat[0],
|
||||||
{'role': 'system', 'content': self.config['system']})
|
{'role': 'system', 'content': self.config.system})
|
||||||
self.assertEqual(test_chat[1],
|
self.assertEqual(test_chat[1],
|
||||||
{'role': 'user', 'content': 'test_content'})
|
{'role': 'user', 'content': 'test_content'})
|
||||||
self.assertEqual(test_chat[2],
|
self.assertEqual(test_chat[2],
|
||||||
@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(test_chat), 2)
|
self.assertEqual(len(test_chat), 2)
|
||||||
self.assertEqual(test_chat[0],
|
self.assertEqual(test_chat[0],
|
||||||
{'role': 'system', 'content': self.config['system']})
|
{'role': 'system', 'content': self.config.system})
|
||||||
self.assertEqual(test_chat[1],
|
self.assertEqual(test_chat[1],
|
||||||
{'role': 'user', 'content': self.question})
|
{'role': 'user', 'content': self.question})
|
||||||
|
|
||||||
@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(test_chat), 6)
|
self.assertEqual(len(test_chat), 6)
|
||||||
self.assertEqual(test_chat[0],
|
self.assertEqual(test_chat[0],
|
||||||
{'role': 'system', 'content': self.config['system']})
|
{'role': 'system', 'content': self.config.system})
|
||||||
self.assertEqual(test_chat[1],
|
self.assertEqual(test_chat[1],
|
||||||
{'role': 'user', 'content': 'test_content'})
|
{'role': 'user', 'content': 'test_content'})
|
||||||
self.assertEqual(test_chat[2],
|
self.assertEqual(test_chat[2],
|
||||||
@ -209,9 +207,9 @@ class TestAI(CmmTestCase):
|
|||||||
|
|
||||||
chat = [{"role": "system", "content": "hello ai"}]
|
chat = [{"role": "system", "content": "hello ai"}]
|
||||||
config = self.dummy_config(db='dummy')
|
config = self.dummy_config(db='dummy')
|
||||||
config['openai']['model'] = "text-davinci-002"
|
config.openai.model = "text-davinci-002"
|
||||||
config['openai']['max_tokens'] = 150
|
config.openai.max_tokens = 150
|
||||||
config['openai']['temperature'] = 0.5
|
config.openai.temperature = 0.5
|
||||||
|
|
||||||
result = ai(chat, config, 2)
|
result = ai(chat, config, 2)
|
||||||
expected_result = (['response_text_1', 'response_text_2'],
|
expected_result = (['response_text_1', 'response_text_2'],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user