Compare commits

...

4 Commits

Author SHA1 Message Date
841233a522 question fix 2023-09-05 08:02:15 +02:00
f5e9bed9bf added new module 'openai.py' 2023-09-04 23:03:29 +02:00
d199e2bc26 added new module 'ai.py' 2023-09-04 23:03:29 +02:00
8884283c05 cmm: added 'question' command 2023-09-04 23:03:29 +02:00
4 changed files with 234 additions and 24 deletions

64
chatmastermind/ai.py Normal file
View File

@ -0,0 +1,64 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .message import Message
from .chat import Chat
class AIError(Exception):
pass
@dataclass
class Tokens:
prompt: int = 0
completion: int = 0
total: int = 0
@dataclass
class AIResponse:
"""
The response to an AI request. Consists of one or more messages
(each containing the question and a single answer) and the nr.
of used tokens.
"""
messages: list[Message]
tokens: Optional[Tokens] = None
class AI(Protocol):
"""
The base class for AI clients.
"""
name: str
config: AIConfig
@abstractmethod
def request(self,
question: Message,
context: Chat,
num_answers: int = 1) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
corresponds to the nr. of messages in the 'AIResponse'.
"""
raise NotImplementedError
@abstractmethod
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
or chat. Note that the computation may not be 100% accurate
and is not implemented for all AIs.
"""
raise NotImplementedError

View File

@ -0,0 +1,90 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..config import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI):
"""
The OpenAI AI client.
"""
config: OpenAIConfig
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request, asking the given question with the given
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
answers: list[Message] = []
for choice in response['choices']: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion'],
response['usage']['total']))
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat

View File

@ -11,7 +11,7 @@ from .storage import save_answers, create_chat_hist
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter, MessageError
from .message import Message, MessageFilter, MessageError, Question
from itertools import zip_longest
from typing import Any
@ -30,12 +30,12 @@ def create_question_with_hist(args: argparse.Namespace,
Creates the "AI request", including the question and chat history as determined
by the specified tags.
"""
tags = args.tags or []
etags = args.etags or []
tags = args.or_tags or []
xtags = args.exclude_tags or []
otags = args.output_tags or []
if not args.source_code_only:
print_tag_args(tags, etags, otags)
print_tag_args(tags, xtags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
@ -52,8 +52,8 @@ def create_question_with_hist(args: argparse.Namespace,
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, etags, config,
match_all_tags=True if args.atags else False, # FIXME
chat = create_chat_hist(full_question, tags, xtags, config,
match_all_tags=True if args.and_tags else False, # FIXME
with_tags=False,
with_file=False)
return chat, full_question, tags
@ -85,6 +85,40 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
config.to_file(args.config)
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = Message(question=Question(args.question),
tags=args.ouput_tags, # FIXME
ai=args.ai,
model=args.model)
chat.add_to_cache([message])
if args.create:
return
elif args.ask:
# TODO:
# * select the correct AIConfig
# * modify it according to the given arguments
# * create AI instance and make AI request
# * add answer to the message above (and create
# more messages for any additional answers)
pass
elif args.repeat:
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
elif args.process:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'
pass
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'ask' command.
@ -109,9 +143,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
Handler for the 'hist' command.
"""
mfilter = MessageFilter(tags_or=args.tags,
tags_and=args.atags,
tags_not=args.etags,
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
@ -139,7 +173,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
parser.add_argument('-C', '--config', help='Config file name.', default=default_config)
# subcommand-parser
cmdparser = parser.add_subparsers(dest='command',
@ -149,19 +183,41 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tag names (one must match)', metavar='TAGS')
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tag names (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+',
help='List of tag names (all must match)', metavar='TAGS')
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tag names (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+',
help='List of tag names to exclude', metavar='ETAGS')
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tag names to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OTAGS')
help='List of output tag names, default is input', metavar='OUTTAGS')
otag_arg.completer = tags_completer # type: ignore
# 'question' command parser
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser],
help="ask, create and process questions.",
aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
question_cmd_parser.add_argument('-A', '--AI', help='AI to use')
question_cmd_parser.add_argument('-M', '--model', help='Model to use')
question_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true')
# 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.",

View File

@ -114,9 +114,9 @@ class TestHandleQuestion(CmmTestCase):
def setUp(self) -> None:
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
atags=None,
etags=['etag1'],
or_tags=['tag1'],
and_tags=None,
exclude_tags=['xtag1'],
output_tags=None,
question=[self.question],
source=None,
@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase):
open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.etags,
mock_print_tag_args.assert_called_once_with(self.args.or_tags,
self.args.exclude_tags,
[])
mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags,
self.args.etags,
self.args.or_tags,
self.args.exclude_tags,
self.config,
match_all_tags=False,
with_tags=False,