Compare commits

...

2 Commits

Author SHA1 Message Date
f3dfbc627e added new module 'ai.py' 2023-09-01 09:01:14 +02:00
4f11d78f37 added tokens() function to Message and Chat 2023-09-01 09:01:09 +02:00
3 changed files with 83 additions and 0 deletions

64
chatmastermind/ai.py Normal file
View File

@ -0,0 +1,64 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union
from .configuration import Config
from .message import Message
from .chat import Chat
class AIError(Exception):
pass
@dataclass
class Tokens:
prompt: int = 0
completion: int = 0
total: int = 0
@dataclass
class AIResponse:
"""
The response to an AI request. Consists of one or more messages
(each containing the question and a single answer) and the nr.
of used tokens.
"""
messages: list[Message]
tokens: Optional[Tokens] = None
class AI(Protocol):
"""
The base class for AI clients.
"""
name: str
config: Config
@abstractmethod
def request(self,
question: Message,
context: Chat,
num_answers: int = 1) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
corresponds to the nr. of messages in the 'AIResponse'.
"""
raise NotImplementedError
@abstractmethod
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
or chat. Note that the computation may not be 100% accurate
and is not implemented for all AIs.
"""
raise NotImplementedError

View File

@ -129,6 +129,13 @@ class Chat:
tags |= m.filter_tags(prefix, contain)
return tags
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by all messages in this chat.
If unknown, 0 is returned.
"""
return sum(m.tokens() for m in self.messages)
def print(self, dump: bool = False, source_code_only: bool = False,
with_tags: bool = False, with_file: bool = False,
paged: bool = True) -> None:

View File

@ -132,6 +132,7 @@ class Question(str):
"""
A single question with a defined header.
"""
tokens: int = 0 # tokens used by this question
txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'question'
@ -165,6 +166,7 @@ class Answer(str):
"""
A single answer with a defined header.
"""
tokens: int = 0 # tokens used by this answer
txt_header: ClassVar[str] = '=== ANSWER ==='
yaml_key: ClassVar[str] = 'answer'
@ -502,3 +504,13 @@ class Message():
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by this message.
If unknown, 0 is returned.
"""
if self.answer:
return self.question.tokens + self.answer.tokens
else:
return self.question.tokens