diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 74438b8..ffbfa7a 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -17,8 +17,9 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ai_type = config.ai_type + self.name = config.name self.config = config def request(self, @@ -31,8 +32,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..39d9a46 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,16 +1,26 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ + ai_type: str name: str @@ -19,13 +29,18 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + # all members have default values, so we can easily create + # a default configuration + ai_type: str = 'openai' + name: str = 'openai_1' + system: str = 'You are an assistant' + api_key: str = '0123456789' + model: str = 'gpt-3.5' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -33,7 +48,9 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. """ return cls( - name='OpenAI', + ai_type='openai', + name=str(source['name']), + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -43,36 +60,79 @@ class OpenAIConfig(AIConfig): presence_penalty=float(source['presence_penalty']) ) + def as_dict(self) -> dict[str, Any]: + return asdict(self) + + +def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given type. + """ + if ai_type.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI type '{ai_type}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais} + @dataclass class Config: """ The configuration file structure. """ - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ Create Config from a dict. """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for name, conf in source['ais'].items(): + ai_conf = ai_type_instance(conf['type'], conf) + ais[name] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) + # add the AI name to the config (for easy internal access) + for name, conf in source['ais'].items(): + conf['name'] = name return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for ai_name, ai_conf in data['ais'].items(): + del (ai_conf['name']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: return asdict(self)