diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index 4a8b914..e94de8e 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -33,18 +33,23 @@ class AI(Protocol): The base class for AI clients. """ + ID: str name: str config: AIConfig def request(self, question: Message, - context: Chat, + chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ - Make an AI request, asking the given question with the given - context (i. e. chat history). The nr. of requested answers - corresponds to the nr. of messages in the 'AIResponse'. + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain """ raise NotImplementedError diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c90366b..cd688b4 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,24 @@ Creates different AI instances, based on the given configuration. """ import argparse -from .configuration import Config +from typing import cast +from .configuration import Config, OpenAIConfig, default_ai_ID from .ai import AI, AIError -from .ais.openai import OpenAI +from .ais.openai_cmm import OpenAI def create_ai(args: argparse.Namespace, config: Config) -> AI: """ Creates an AI subclass instance from the given args and configuration. """ - if args.ai == 'openai': - # FIXME: create actual 'OpenAIConfig' and set values from 'args' - # FIXME: use actual name from config - return OpenAI("openai", config.openai) + if args.ai: + ai_conf = config.ais[args.ai] + elif default_ai_ID in config.ais: + ai_conf = config.ais[default_ai_ID] + else: + raise AIError("No AI name given and no default exists") + + if ai_conf.name == 'openai': + return OpenAI(cast(OpenAIConfig, ai_conf)) else: raise AIError(f"AI '{args.ai}' is not supported") diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai_cmm.py similarity index 93% rename from chatmastermind/ais/openai.py rename to chatmastermind/ais/openai_cmm.py index 74438b8..14ce33f 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai_cmm.py @@ -17,9 +17,11 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name self.config = config + openai.api_key = config.api_key def request(self, question: Message, @@ -31,8 +33,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..d82f913 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,17 +1,40 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_ai_ID: str = 'default' +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ - name: str + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) @dataclass @@ -19,21 +42,27 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'default' + api_key: str = '0123456789' + system: str = 'You are an assistant' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( - name='OpenAI', + res = cls( + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -42,6 +71,30 @@ class OpenAIConfig(AIConfig): frequency_penalty=float(source['frequency_penalty']), presence_penalty=float(source['presence_penalty']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI '{name}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass @@ -49,30 +102,52 @@ class Config: """ The configuration file structure. """ - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create Config from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: - return asdict(self) + res = asdict(self) + for ID, conf in res['ais'].items(): + conf.update({'name': self.ais[ID].name}) + return res