Compare commits

..

3 Commits

Author SHA1 Message Date
7453f5d8d7 added new module 'openai.py' 2023-09-05 08:02:54 +02:00
e3f1776d0d added new module 'ai.py' 2023-09-05 08:02:54 +02:00
c73b84f568 cmm: added 'question' command 2023-09-05 08:02:54 +02:00
5 changed files with 17 additions and 47 deletions

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .tags import Tag
from .message import Message
from .chat import Chat
@ -36,11 +36,11 @@ class AI(Protocol):
name: str
config: AIConfig
@abstractmethod
def request(self,
question: Message,
context: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
num_answers: int = 1) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
@ -48,6 +48,7 @@ class AI(Protocol):
"""
raise NotImplementedError
@abstractmethod
def models(self) -> list[str]:
"""
Return all models supported by this AI.

View File

@ -1,20 +0,0 @@
"""
Creates different AI instances, based on the given configuration.
"""
import argparse
from .configuration import Config
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI:
"""
Creates an AI subclass instance from the given args and configuration.
"""
if args.ai == 'openai':
# FIXME: create actual 'OpenAIConfig' and set values from 'args'
# FIXME: use actual name from config
return OpenAI("openai", config.openai)
else:
raise AIError(f"AI '{args.ai}' is not supported")

View File

@ -2,12 +2,12 @@
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional, Union
from typing import Optional
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig
from ..config import OpenAIConfig
ChatType = list[dict[str, str]]
@ -17,9 +17,7 @@ class OpenAI(AI):
The OpenAI AI client.
"""
def __init__(self, name: str, config: OpenAIConfig) -> None:
self.name = name
self.config = config
config: OpenAIConfig
def request(self,
question: Message,
@ -31,8 +29,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
# FIXME: use real 'system' message (store in OpenAIConfig)
oai_chat = self.openai_chat(chat, "system", question)
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
@ -91,6 +88,3 @@ class OpenAI(AI):
if question:
append('user', question.question)
return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError

View File

@ -12,7 +12,6 @@ from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter, MessageError, Question
from .ai_factory import create_ai
from itertools import zip_longest
from typing import Any
@ -101,15 +100,11 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
chat.add_to_cache([message])
if args.create:
return
# create the correct AI instance
ai = create_ai(args, config)
if args.ask:
ai.request(message,
chat,
args.num_answers, # FIXME
args.otags) # FIXME
elif args.ask:
# TODO:
# * select the correct AIConfig
# * modify it according to the given arguments
# * create AI instance and make AI request
# * add answer to the message above (and create
# more messages for any additional answers)
pass
@ -137,7 +132,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.num_answers)
answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config)
print("-" * terminal_width())
print(f"Usage: {usage}")
@ -217,7 +212,7 @@ def create_parser() -> argparse.ArgumentParser:
question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
question_cmd_parser.add_argument('-A', '--AI', help='AI to use')
question_cmd_parser.add_argument('-M', '--model', help='Model to use')
question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
question_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
@ -233,7 +228,7 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',

View File

@ -121,7 +121,7 @@ class TestHandleQuestion(CmmTestCase):
question=[self.question],
source=None,
source_code_only=False,
num_answers=3,
number=3,
max_tokens=None,
temperature=None,
model=None,
@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase):
self.args.source_code_only)
mock_ai.assert_called_with("test_chat",
self.config,
self.args.num_answers)
self.args.number)
expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} '