Compare commits
2 Commits
60583a27b2
...
f3dfbc627e
| Author | SHA1 | Date | |
|---|---|---|---|
| f3dfbc627e | |||
| 4f11d78f37 |
64
chatmastermind/ai.py
Normal file
64
chatmastermind/ai.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Protocol, Optional, Union
|
||||||
|
from .configuration import Config
|
||||||
|
from .message import Message
|
||||||
|
from .chat import Chat
|
||||||
|
|
||||||
|
|
||||||
|
class AIError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tokens:
|
||||||
|
prompt: int = 0
|
||||||
|
completion: int = 0
|
||||||
|
total: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AIResponse:
|
||||||
|
"""
|
||||||
|
The response to an AI request. Consists of one or more messages
|
||||||
|
(each containing the question and a single answer) and the nr.
|
||||||
|
of used tokens.
|
||||||
|
"""
|
||||||
|
messages: list[Message]
|
||||||
|
tokens: Optional[Tokens] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AI(Protocol):
|
||||||
|
"""
|
||||||
|
The base class for AI clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
config: Config
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
context: Chat,
|
||||||
|
num_answers: int = 1) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Make an AI request, asking the given question with the given
|
||||||
|
context (i. e. chat history). The nr. of requested answers
|
||||||
|
corresponds to the nr. of messages in the 'AIResponse'.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return all models supported by this AI.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
|
"""
|
||||||
|
Computes the nr. of AI language tokens for the given message
|
||||||
|
or chat. Note that the computation may not be 100% accurate
|
||||||
|
and is not implemented for all AIs.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@ -129,6 +129,13 @@ class Chat:
|
|||||||
tags |= m.filter_tags(prefix, contain)
|
tags |= m.filter_tags(prefix, contain)
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
def tokens(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the nr. of AI language tokens used by all messages in this chat.
|
||||||
|
If unknown, 0 is returned.
|
||||||
|
"""
|
||||||
|
return sum(m.tokens() for m in self.messages)
|
||||||
|
|
||||||
def print(self, dump: bool = False, source_code_only: bool = False,
|
def print(self, dump: bool = False, source_code_only: bool = False,
|
||||||
with_tags: bool = False, with_file: bool = False,
|
with_tags: bool = False, with_file: bool = False,
|
||||||
paged: bool = True) -> None:
|
paged: bool = True) -> None:
|
||||||
|
|||||||
@ -132,6 +132,7 @@ class Question(str):
|
|||||||
"""
|
"""
|
||||||
A single question with a defined header.
|
A single question with a defined header.
|
||||||
"""
|
"""
|
||||||
|
tokens: int = 0 # tokens used by this question
|
||||||
txt_header: ClassVar[str] = '=== QUESTION ==='
|
txt_header: ClassVar[str] = '=== QUESTION ==='
|
||||||
yaml_key: ClassVar[str] = 'question'
|
yaml_key: ClassVar[str] = 'question'
|
||||||
|
|
||||||
@ -165,6 +166,7 @@ class Answer(str):
|
|||||||
"""
|
"""
|
||||||
A single answer with a defined header.
|
A single answer with a defined header.
|
||||||
"""
|
"""
|
||||||
|
tokens: int = 0 # tokens used by this answer
|
||||||
txt_header: ClassVar[str] = '=== ANSWER ==='
|
txt_header: ClassVar[str] = '=== ANSWER ==='
|
||||||
yaml_key: ClassVar[str] = 'answer'
|
yaml_key: ClassVar[str] = 'answer'
|
||||||
|
|
||||||
@ -502,3 +504,13 @@ class Message():
|
|||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
|
def tokens(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the nr. of AI language tokens used by this message.
|
||||||
|
If unknown, 0 is returned.
|
||||||
|
"""
|
||||||
|
if self.answer:
|
||||||
|
return self.question.tokens + self.answer.tokens
|
||||||
|
else:
|
||||||
|
return self.question.tokens
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user