Compare commits
2 Commits
f3dfbc627e
...
05ae13c147
| Author | SHA1 | Date | |
|---|---|---|---|
| 05ae13c147 | |||
| 86663e072e |
64
chatmastermind/ai.py
Normal file
64
chatmastermind/ai.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Protocol, Optional, Union
|
||||||
|
from .configuration import AIConfig
|
||||||
|
from .message import Message
|
||||||
|
from .chat import Chat
|
||||||
|
|
||||||
|
|
||||||
|
class AIError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tokens:
|
||||||
|
prompt: int = 0
|
||||||
|
completion: int = 0
|
||||||
|
total: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AIResponse:
|
||||||
|
"""
|
||||||
|
The response to an AI request. Consists of one or more messages
|
||||||
|
(each containing the question and a single answer) and the nr.
|
||||||
|
of used tokens.
|
||||||
|
"""
|
||||||
|
messages: list[Message]
|
||||||
|
tokens: Optional[Tokens] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AI(Protocol):
|
||||||
|
"""
|
||||||
|
The base class for AI clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
config: AIConfig
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
context: Chat,
|
||||||
|
num_answers: int = 1) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Make an AI request, asking the given question with the given
|
||||||
|
context (i. e. chat history). The nr. of requested answers
|
||||||
|
corresponds to the nr. of messages in the 'AIResponse'.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return all models supported by this AI.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
|
"""
|
||||||
|
Computes the nr. of AI language tokens for the given message
|
||||||
|
or chat. Note that the computation may not be 100% accurate
|
||||||
|
and is not implemented for all AIs.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenAIConfig():
|
class AIConfig:
|
||||||
|
"""
|
||||||
|
The base class of all AI configurations.
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenAIConfig(AIConfig):
|
||||||
"""
|
"""
|
||||||
The OpenAI section of the configuration file.
|
The OpenAI section of the configuration file.
|
||||||
"""
|
"""
|
||||||
@ -25,6 +33,7 @@ class OpenAIConfig():
|
|||||||
Create OpenAIConfig from a dict.
|
Create OpenAIConfig from a dict.
|
||||||
"""
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
|
name='OpenAI',
|
||||||
api_key=str(source['api_key']),
|
api_key=str(source['api_key']),
|
||||||
model=str(source['model']),
|
model=str(source['model']),
|
||||||
max_tokens=int(source['max_tokens']),
|
max_tokens=int(source['max_tokens']),
|
||||||
@ -36,7 +45,7 @@ class OpenAIConfig():
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config():
|
class Config:
|
||||||
"""
|
"""
|
||||||
The configuration file structure.
|
The configuration file structure.
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user