Compare commits
9 Commits
1e68617a46
...
ba5aa1fbc7
| Author | SHA1 | Date | |
|---|---|---|---|
| ba5aa1fbc7 | |||
| eb2fcba99d | |||
| b7e3ca7ca7 | |||
| aa322de718 | |||
| bf1cbff6a2 | |||
| f93a57c00d | |||
| b0504aedbe | |||
| eb0d97ddc8 | |||
| 7e25a08d6e |
63
chatmastermind/ai.py
Normal file
63
chatmastermind/ai.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol, Optional, Union
|
||||||
|
from .configuration import AIConfig
|
||||||
|
from .tags import Tag
|
||||||
|
from .message import Message
|
||||||
|
from .chat import Chat
|
||||||
|
|
||||||
|
|
||||||
|
class AIError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tokens:
|
||||||
|
prompt: int = 0
|
||||||
|
completion: int = 0
|
||||||
|
total: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AIResponse:
|
||||||
|
"""
|
||||||
|
The response to an AI request. Consists of one or more messages
|
||||||
|
(each containing the question and a single answer) and the nr.
|
||||||
|
of used tokens.
|
||||||
|
"""
|
||||||
|
messages: list[Message]
|
||||||
|
tokens: Optional[Tokens] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AI(Protocol):
|
||||||
|
"""
|
||||||
|
The base class for AI clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
config: AIConfig
|
||||||
|
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
context: Chat,
|
||||||
|
num_answers: int = 1,
|
||||||
|
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Make an AI request, asking the given question with the given
|
||||||
|
context (i. e. chat history). The nr. of requested answers
|
||||||
|
corresponds to the nr. of messages in the 'AIResponse'.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return all models supported by this AI.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
|
"""
|
||||||
|
Computes the nr. of AI language tokens for the given message
|
||||||
|
or chat. Note that the computation may not be 100% accurate
|
||||||
|
and is not implemented for all AIs.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
20
chatmastermind/ai_factory.py
Normal file
20
chatmastermind/ai_factory.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Creates different AI instances, based on the given configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from .configuration import Config
|
||||||
|
from .ai import AI, AIError
|
||||||
|
from .ais.openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def create_ai(args: argparse.Namespace, config: Config) -> AI:
|
||||||
|
"""
|
||||||
|
Creates an AI subclass instance from the given args and configuration.
|
||||||
|
"""
|
||||||
|
if args.ai == 'openai':
|
||||||
|
# FIXME: create actual 'OpenAIConfig' and set values from 'args'
|
||||||
|
# FIXME: use actual name from config
|
||||||
|
return OpenAI("openai", config.openai)
|
||||||
|
else:
|
||||||
|
raise AIError(f"AI '{args.ai}' is not supported")
|
||||||
96
chatmastermind/ais/openai.py
Normal file
96
chatmastermind/ais/openai.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Implements the OpenAI client classes and functions.
|
||||||
|
"""
|
||||||
|
import openai
|
||||||
|
from typing import Optional, Union
|
||||||
|
from ..tags import Tag
|
||||||
|
from ..message import Message, Answer
|
||||||
|
from ..chat import Chat
|
||||||
|
from ..ai import AI, AIResponse, Tokens
|
||||||
|
from ..configuration import OpenAIConfig
|
||||||
|
|
||||||
|
ChatType = list[dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI(AI):
|
||||||
|
"""
|
||||||
|
The OpenAI AI client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, config: OpenAIConfig) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
chat: Chat,
|
||||||
|
num_answers: int = 1,
|
||||||
|
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Make an AI request, asking the given question with the given
|
||||||
|
chat history. The nr. of requested answers corresponds to the
|
||||||
|
nr. of messages in the 'AIResponse'.
|
||||||
|
"""
|
||||||
|
# FIXME: use real 'system' message (store in OpenAIConfig)
|
||||||
|
oai_chat = self.openai_chat(chat, "system", question)
|
||||||
|
response = openai.ChatCompletion.create(
|
||||||
|
model=self.config.model,
|
||||||
|
messages=oai_chat,
|
||||||
|
temperature=self.config.temperature,
|
||||||
|
max_tokens=self.config.max_tokens,
|
||||||
|
top_p=self.config.top_p,
|
||||||
|
n=num_answers,
|
||||||
|
frequency_penalty=self.config.frequency_penalty,
|
||||||
|
presence_penalty=self.config.presence_penalty)
|
||||||
|
answers: list[Message] = []
|
||||||
|
for choice in response['choices']: # type: ignore
|
||||||
|
answers.append(Message(question=question.question,
|
||||||
|
answer=Answer(choice['message']['content']),
|
||||||
|
tags=otags,
|
||||||
|
ai=self.name,
|
||||||
|
model=self.config.model))
|
||||||
|
return AIResponse(answers, Tokens(response['usage']['prompt'],
|
||||||
|
response['usage']['completion'],
|
||||||
|
response['usage']['total']))
|
||||||
|
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return all models supported by this AI.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def print_models(self) -> None:
|
||||||
|
"""
|
||||||
|
Print all models supported by the current AI.
|
||||||
|
"""
|
||||||
|
not_ready = []
|
||||||
|
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
|
||||||
|
if engine['ready']:
|
||||||
|
print(engine['id'])
|
||||||
|
else:
|
||||||
|
not_ready.append(engine['id'])
|
||||||
|
if len(not_ready) > 0:
|
||||||
|
print('\nNot ready: ' + ', '.join(not_ready))
|
||||||
|
|
||||||
|
def openai_chat(self, chat: Chat, system: str,
|
||||||
|
question: Optional[Message] = None) -> ChatType:
|
||||||
|
"""
|
||||||
|
Create a chat history with system message in OpenAI format.
|
||||||
|
Optionally append a new question.
|
||||||
|
"""
|
||||||
|
oai_chat: ChatType = []
|
||||||
|
|
||||||
|
def append(role: str, content: str) -> None:
|
||||||
|
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||||
|
|
||||||
|
append('system', system)
|
||||||
|
for message in chat.messages:
|
||||||
|
if message.answer:
|
||||||
|
append('user', message.question)
|
||||||
|
append('assistant', message.answer)
|
||||||
|
if question:
|
||||||
|
append('user', question.question)
|
||||||
|
return oai_chat
|
||||||
|
|
||||||
|
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
@ -2,7 +2,7 @@
|
|||||||
Module implementing various chat classes and functions for managing a chat history.
|
Module implementing various chat classes and functions for managing a chat history.
|
||||||
"""
|
"""
|
||||||
import shutil
|
import shutil
|
||||||
import pathlib
|
from pathlib import Path
|
||||||
from pprint import PrettyPrinter
|
from pprint import PrettyPrinter
|
||||||
from pydoc import pager
|
from pydoc import pager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -30,7 +30,7 @@ def print_paged(text: str) -> None:
|
|||||||
pager(text)
|
pager(text)
|
||||||
|
|
||||||
|
|
||||||
def read_dir(dir_path: pathlib.Path,
|
def read_dir(dir_path: Path,
|
||||||
glob: Optional[str] = None,
|
glob: Optional[str] = None,
|
||||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||||
"""
|
"""
|
||||||
@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path,
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def make_file_path(dir_path: pathlib.Path,
|
def make_file_path(dir_path: Path,
|
||||||
file_suffix: str,
|
file_suffix: str,
|
||||||
next_fid: Callable[[], int]) -> pathlib.Path:
|
next_fid: Callable[[], int]) -> Path:
|
||||||
"""
|
"""
|
||||||
Create a file_path for the given directory using the
|
Create a file_path for the given directory using the
|
||||||
given file_suffix and ID generator function.
|
given file_suffix and ID generator function.
|
||||||
@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path,
|
|||||||
return dir_path / f"{next_fid():04d}{file_suffix}"
|
return dir_path / f"{next_fid():04d}{file_suffix}"
|
||||||
|
|
||||||
|
|
||||||
def write_dir(dir_path: pathlib.Path,
|
def write_dir(dir_path: Path,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
file_suffix: str,
|
file_suffix: str,
|
||||||
next_fid: Callable[[], int]) -> None:
|
next_fid: Callable[[], int]) -> None:
|
||||||
@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path,
|
|||||||
message.to_file(file_path)
|
message.to_file(file_path)
|
||||||
|
|
||||||
|
|
||||||
def clear_dir(dir_path: pathlib.Path,
|
def clear_dir(dir_path: Path,
|
||||||
glob: Optional[str] = None) -> None:
|
glob: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Deletes all Message files in the given directory.
|
Deletes all Message files in the given directory.
|
||||||
@ -139,6 +139,34 @@ class Chat:
|
|||||||
self.messages += messages
|
self.messages += messages
|
||||||
self.sort()
|
self.sort()
|
||||||
|
|
||||||
|
def latest_message(self) -> Optional[Message]:
|
||||||
|
"""
|
||||||
|
Returns the last added message (according to the file ID).
|
||||||
|
"""
|
||||||
|
if len(self.messages) > 0:
|
||||||
|
self.sort()
|
||||||
|
return self.messages[-1]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_messages(self, msg_names: list[str]) -> list[Message]:
|
||||||
|
"""
|
||||||
|
Search and return the messages with the given names. Names can either be filenames
|
||||||
|
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
|
||||||
|
caller should check the result if he requires all messages).
|
||||||
|
"""
|
||||||
|
return [m for m in self.messages
|
||||||
|
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||||
|
|
||||||
|
def remove_messages(self, msg_names: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
Remove the messages with the given names. Names can either be filenames
|
||||||
|
(incl. the suffix) or full paths.
|
||||||
|
"""
|
||||||
|
self.messages = [m for m in self.messages
|
||||||
|
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||||
|
self.sort()
|
||||||
|
|
||||||
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||||
"""
|
"""
|
||||||
Get the tags of all messages, optionally filtered by prefix or substring.
|
Get the tags of all messages, optionally filtered by prefix or substring.
|
||||||
@ -192,8 +220,8 @@ class ChatDB(Chat):
|
|||||||
|
|
||||||
default_file_suffix: ClassVar[str] = '.txt'
|
default_file_suffix: ClassVar[str] = '.txt'
|
||||||
|
|
||||||
cache_path: pathlib.Path
|
cache_path: Path
|
||||||
db_path: pathlib.Path
|
db_path: Path
|
||||||
# a MessageFilter that all messages must match (if given)
|
# a MessageFilter that all messages must match (if given)
|
||||||
mfilter: Optional[MessageFilter] = None
|
mfilter: Optional[MessageFilter] = None
|
||||||
file_suffix: str = default_file_suffix
|
file_suffix: str = default_file_suffix
|
||||||
@ -209,8 +237,8 @@ class ChatDB(Chat):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dir(cls: Type[ChatDBInst],
|
def from_dir(cls: Type[ChatDBInst],
|
||||||
cache_path: pathlib.Path,
|
cache_path: Path,
|
||||||
db_path: pathlib.Path,
|
db_path: Path,
|
||||||
glob: Optional[str] = None,
|
glob: Optional[str] = None,
|
||||||
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
||||||
"""
|
"""
|
||||||
@ -230,8 +258,8 @@ class ChatDB(Chat):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(cls: Type[ChatDBInst],
|
def from_messages(cls: Type[ChatDBInst],
|
||||||
cache_path: pathlib.Path,
|
cache_path: Path,
|
||||||
db_path: pathlib.Path,
|
db_path: Path,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,15 +2,18 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# vim: set fileencoding=utf-8 :
|
# vim: set fileencoding=utf-8 :
|
||||||
|
|
||||||
import yaml
|
|
||||||
import sys
|
import sys
|
||||||
import argcomplete
|
import argcomplete
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
from pathlib import Path
|
||||||
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType
|
||||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
|
from .storage import save_answers, create_chat_hist
|
||||||
from .api_client import ai, openai_api_key, print_models
|
from .api_client import ai, openai_api_key, print_models
|
||||||
from .configuration import Config
|
from .configuration import Config
|
||||||
|
from .chat import ChatDB
|
||||||
|
from .message import Message, MessageFilter, MessageError, Question
|
||||||
|
from .ai_factory import create_ai
|
||||||
|
from .ai import AI, AIResponse
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -18,9 +21,8 @@ default_config = '.config.yaml'
|
|||||||
|
|
||||||
|
|
||||||
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
||||||
with open(parsed_args.config, 'r') as f:
|
config = Config.from_file(parsed_args.config)
|
||||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
|
||||||
return get_tags_unique(config, prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def create_question_with_hist(args: argparse.Namespace,
|
def create_question_with_hist(args: argparse.Namespace,
|
||||||
@ -30,12 +32,12 @@ def create_question_with_hist(args: argparse.Namespace,
|
|||||||
Creates the "AI request", including the question and chat history as determined
|
Creates the "AI request", including the question and chat history as determined
|
||||||
by the specified tags.
|
by the specified tags.
|
||||||
"""
|
"""
|
||||||
tags = args.tags or []
|
tags = args.or_tags or []
|
||||||
extags = args.extags or []
|
xtags = args.exclude_tags or []
|
||||||
otags = args.output_tags or []
|
otags = args.output_tags or []
|
||||||
|
|
||||||
if not args.only_source_code:
|
if not args.source_code_only:
|
||||||
print_tag_args(tags, extags, otags)
|
print_tag_args(tags, xtags, otags)
|
||||||
|
|
||||||
question_parts = []
|
question_parts = []
|
||||||
question_list = args.question if args.question is not None else []
|
question_list = args.question if args.question is not None else []
|
||||||
@ -52,17 +54,24 @@ def create_question_with_hist(args: argparse.Namespace,
|
|||||||
question_parts.append(f"```\n{r.read().strip()}\n```")
|
question_parts.append(f"```\n{r.read().strip()}\n```")
|
||||||
|
|
||||||
full_question = '\n\n'.join(question_parts)
|
full_question = '\n\n'.join(question_parts)
|
||||||
chat = create_chat_hist(full_question, tags, extags, config,
|
chat = create_chat_hist(full_question, tags, xtags, config,
|
||||||
args.match_all_tags, False, False)
|
match_all_tags=True if args.and_tags else False, # FIXME
|
||||||
|
with_tags=False,
|
||||||
|
with_file=False)
|
||||||
return chat, full_question, tags
|
return chat, full_question, tags
|
||||||
|
|
||||||
|
|
||||||
def tag_cmd(args: argparse.Namespace, config: Config) -> None:
|
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
"""
|
"""
|
||||||
Handler for the 'tag' command.
|
Handler for the 'tags' command.
|
||||||
"""
|
"""
|
||||||
|
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||||
|
db_path=Path(config.db))
|
||||||
if args.list:
|
if args.list:
|
||||||
print_tags_frequency(get_tags(config, None))
|
tags_freq = chat.tags_frequency(args.prefix, args.contain)
|
||||||
|
for tag, freq in tags_freq.items():
|
||||||
|
print(f"- {tag}: {freq}")
|
||||||
|
# TODO: add renaming
|
||||||
|
|
||||||
|
|
||||||
def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
@ -78,6 +87,47 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
config.to_file(args.config)
|
config.to_file(args.config)
|
||||||
|
|
||||||
|
|
||||||
|
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
|
"""
|
||||||
|
Handler for the 'question' command.
|
||||||
|
"""
|
||||||
|
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||||
|
db_path=Path(config.db))
|
||||||
|
# if it's a new question, create and store it immediately
|
||||||
|
if args.ask or args.create:
|
||||||
|
message = Message(question=Question(args.question),
|
||||||
|
tags=args.ouput_tags, # FIXME
|
||||||
|
ai=args.ai,
|
||||||
|
model=args.model)
|
||||||
|
chat.add_to_cache([message])
|
||||||
|
if args.create:
|
||||||
|
return
|
||||||
|
|
||||||
|
# create the correct AI instance
|
||||||
|
ai: AI = create_ai(args, config)
|
||||||
|
if args.ask:
|
||||||
|
response: AIResponse = ai.request(message,
|
||||||
|
chat,
|
||||||
|
args.num_answers, # FIXME
|
||||||
|
args.otags) # FIXME
|
||||||
|
assert response
|
||||||
|
# TODO:
|
||||||
|
# * add answer to the message above (and create
|
||||||
|
# more messages for any additional answers)
|
||||||
|
pass
|
||||||
|
elif args.repeat:
|
||||||
|
lmessage = chat.latest_message()
|
||||||
|
assert lmessage
|
||||||
|
# TODO: repeat either the last question or the
|
||||||
|
# one(s) given in 'args.repeat' (overwrite
|
||||||
|
# existing ones if 'args.overwrite' is True)
|
||||||
|
pass
|
||||||
|
elif args.process:
|
||||||
|
# TODO: process either all questions without an
|
||||||
|
# answer or the one(s) given in 'args.process'
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
"""
|
"""
|
||||||
Handler for the 'ask' command.
|
Handler for the 'ask' command.
|
||||||
@ -89,9 +139,9 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
if args.model:
|
if args.model:
|
||||||
config.openai.model = args.model
|
config.openai.model = args.model
|
||||||
chat, question, tags = create_question_with_hist(args, config)
|
chat, question, tags = create_question_with_hist(args, config)
|
||||||
print_chat_hist(chat, False, args.only_source_code)
|
print_chat_hist(chat, False, args.source_code_only)
|
||||||
otags = args.output_tags or []
|
otags = args.output_tags or []
|
||||||
answers, usage = ai(chat, config, args.number)
|
answers, usage = ai(chat, config, args.num_answers)
|
||||||
save_answers(question, answers, tags, otags, config)
|
save_answers(question, answers, tags, otags, config)
|
||||||
print("-" * terminal_width())
|
print("-" * terminal_width())
|
||||||
print(f"Usage: {usage}")
|
print(f"Usage: {usage}")
|
||||||
@ -101,39 +151,38 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
"""
|
"""
|
||||||
Handler for the 'hist' command.
|
Handler for the 'hist' command.
|
||||||
"""
|
"""
|
||||||
tags = args.tags or []
|
|
||||||
extags = args.extags or []
|
|
||||||
|
|
||||||
chat = create_chat_hist(None, tags, extags, config,
|
mfilter = MessageFilter(tags_or=args.or_tags,
|
||||||
args.match_all_tags,
|
tags_and=args.and_tags,
|
||||||
args.with_tags,
|
tags_not=args.exclude_tags,
|
||||||
args.with_files)
|
question_contains=args.question,
|
||||||
print_chat_hist(chat, args.dump, args.only_source_code)
|
answer_contains=args.answer)
|
||||||
|
chat = ChatDB.from_dir(Path('.'),
|
||||||
|
Path(config.db),
|
||||||
|
mfilter=mfilter)
|
||||||
|
chat.print(args.source_code_only,
|
||||||
|
args.with_tags,
|
||||||
|
args.with_files)
|
||||||
|
|
||||||
|
|
||||||
def print_cmd(args: argparse.Namespace, config: Config) -> None:
|
def print_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
"""
|
"""
|
||||||
Handler for the 'print' command.
|
Handler for the 'print' command.
|
||||||
"""
|
"""
|
||||||
fname = pathlib.Path(args.file)
|
fname = Path(args.file)
|
||||||
if fname.suffix == '.yaml':
|
try:
|
||||||
with open(args.file, 'r') as f:
|
message = Message.from_file(fname)
|
||||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
if message:
|
||||||
elif fname.suffix == '.txt':
|
print(message.to_str(source_code_only=args.source_code_only))
|
||||||
data = read_file(fname)
|
except MessageError:
|
||||||
else:
|
print(f"File is not a valid message: {args.file}")
|
||||||
print(f"Unknown file type: {args.file}")
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if args.only_source_code:
|
|
||||||
display_source_code(data['answer'])
|
|
||||||
else:
|
|
||||||
print(dump_data(data).strip())
|
|
||||||
|
|
||||||
|
|
||||||
def create_parser() -> argparse.ArgumentParser:
|
def create_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||||
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
|
parser.add_argument('-C', '--config', help='Config file name.', default=default_config)
|
||||||
|
|
||||||
# subcommand-parser
|
# subcommand-parser
|
||||||
cmdparser = parser.add_subparsers(dest='command',
|
cmdparser = parser.add_subparsers(dest='command',
|
||||||
@ -143,19 +192,40 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# a parent parser for all commands that support tag selection
|
# a parent parser for all commands that support tag selection
|
||||||
tag_parser = argparse.ArgumentParser(add_help=False)
|
tag_parser = argparse.ArgumentParser(add_help=False)
|
||||||
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
|
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
|
||||||
help='List of tag names', metavar='TAGS')
|
help='List of tag names (one must match)', metavar='OTAGS')
|
||||||
tag_arg.completer = tags_completer # type: ignore
|
tag_arg.completer = tags_completer # type: ignore
|
||||||
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
|
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
|
||||||
help='List of tag names to exclude', metavar='EXTAGS')
|
help='List of tag names (all must match)', metavar='ATAGS')
|
||||||
extag_arg.completer = tags_completer # type: ignore
|
atag_arg.completer = tags_completer # type: ignore
|
||||||
|
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
|
||||||
|
help='List of tag names to exclude', metavar='XTAGS')
|
||||||
|
etag_arg.completer = tags_completer # type: ignore
|
||||||
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
||||||
help='List of output tag names, default is input', metavar='OTAGS')
|
help='List of output tag names, default is input', metavar='OUTTAGS')
|
||||||
otag_arg.completer = tags_completer # type: ignore
|
otag_arg.completer = tags_completer # type: ignore
|
||||||
tag_parser.add_argument('-a', '--match-all-tags',
|
|
||||||
help="All given tags must match when selecting chat history entries",
|
# 'question' command parser
|
||||||
action='store_true')
|
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser],
|
||||||
# enable autocompletion for tags
|
help="ask, create and process questions.",
|
||||||
|
aliases=['q'])
|
||||||
|
question_cmd_parser.set_defaults(func=question_cmd)
|
||||||
|
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
|
||||||
|
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
|
||||||
|
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
|
||||||
|
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
|
||||||
|
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
|
||||||
|
action='store_true')
|
||||||
|
question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
||||||
|
question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
||||||
|
question_cmd_parser.add_argument('-A', '--AI', help='AI to use')
|
||||||
|
question_cmd_parser.add_argument('-M', '--model', help='Model to use')
|
||||||
|
question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
|
||||||
|
default=1)
|
||||||
|
question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
|
||||||
|
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
|
||||||
|
action='store_true')
|
||||||
|
|
||||||
# 'ask' command parser
|
# 'ask' command parser
|
||||||
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
|
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
|
||||||
@ -167,10 +237,10 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
||||||
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
||||||
ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
|
ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
|
||||||
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
|
ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
|
||||||
default=1)
|
default=1)
|
||||||
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
|
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
|
||||||
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
|
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
|
||||||
# 'hist' command parser
|
# 'hist' command parser
|
||||||
@ -178,23 +248,25 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
help="Print chat history.",
|
help="Print chat history.",
|
||||||
aliases=['h'])
|
aliases=['h'])
|
||||||
hist_cmd_parser.set_defaults(func=hist_cmd)
|
hist_cmd_parser.set_defaults(func=hist_cmd)
|
||||||
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
|
|
||||||
action='store_true')
|
|
||||||
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
|
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
|
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
|
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
|
||||||
|
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
|
||||||
|
|
||||||
# 'tag' command parser
|
# 'tags' command parser
|
||||||
tag_cmd_parser = cmdparser.add_parser('tag',
|
tags_cmd_parser = cmdparser.add_parser('tags',
|
||||||
help="Manage tags.",
|
help="Manage tags.",
|
||||||
aliases=['t'])
|
aliases=['t'])
|
||||||
tag_cmd_parser.set_defaults(func=tag_cmd)
|
tags_cmd_parser.set_defaults(func=tags_cmd)
|
||||||
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
|
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
|
||||||
|
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
|
||||||
|
|
||||||
# 'config' command parser
|
# 'config' command parser
|
||||||
config_cmd_parser = cmdparser.add_parser('config',
|
config_cmd_parser = cmdparser.add_parser('config',
|
||||||
@ -210,11 +282,11 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# 'print' command parser
|
# 'print' command parser
|
||||||
print_cmd_parser = cmdparser.add_parser('print',
|
print_cmd_parser = cmdparser.add_parser('print',
|
||||||
help="Print files.",
|
help="Print message files.",
|
||||||
aliases=['p'])
|
aliases=['p'])
|
||||||
print_cmd_parser.set_defaults(func=print_cmd)
|
print_cmd_parser.set_defaults(func=print_cmd)
|
||||||
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
|
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
|
||||||
print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
|
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)',
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
|
||||||
argcomplete.autocomplete(parser)
|
argcomplete.autocomplete(parser)
|
||||||
|
|||||||
@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
|
|||||||
print(message['content'])
|
print(message['content'])
|
||||||
else:
|
else:
|
||||||
print(f"{message['role'].upper()}: {message['content']}")
|
print(f"{message['role'].upper()}: {message['content']}")
|
||||||
|
|
||||||
|
|
||||||
def print_tags_frequency(tags: list[str]) -> None:
|
|
||||||
for tag in sorted(set(tags)):
|
|
||||||
print(f"- {tag}: {tags.count(tag)}")
|
|
||||||
|
|||||||
@ -62,6 +62,28 @@ class TestChat(CmmTestCase):
|
|||||||
tags_freq = self.chat.tags_frequency()
|
tags_freq = self.chat.tags_frequency()
|
||||||
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
||||||
|
|
||||||
|
def test_find_remove_messages(self) -> None:
|
||||||
|
self.chat.add_messages([self.message1, self.message2])
|
||||||
|
msgs = self.chat.find_messages(['0001.txt'])
|
||||||
|
self.assertListEqual(msgs, [self.message1])
|
||||||
|
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
|
||||||
|
self.assertListEqual(msgs, [self.message1, self.message2])
|
||||||
|
# add new Message with full path
|
||||||
|
message3 = Message(Question('Question 2'),
|
||||||
|
Answer('Answer 2'),
|
||||||
|
{Tag('btag2')},
|
||||||
|
file_path=pathlib.Path('/foo/bla/0003.txt'))
|
||||||
|
self.chat.add_messages([message3])
|
||||||
|
# find new Message by full path
|
||||||
|
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
|
||||||
|
self.assertListEqual(msgs, [message3])
|
||||||
|
# find Message with full path only by filename
|
||||||
|
msgs = self.chat.find_messages(['0003.txt'])
|
||||||
|
self.assertListEqual(msgs, [message3])
|
||||||
|
# remove last message
|
||||||
|
self.chat.remove_messages(['0003.txt'])
|
||||||
|
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
||||||
|
|
||||||
@patch('sys.stdout', new_callable=StringIO)
|
@patch('sys.stdout', new_callable=StringIO)
|
||||||
def test_print(self, mock_stdout: StringIO) -> None:
|
def test_print(self, mock_stdout: StringIO) -> None:
|
||||||
self.chat.add_messages([self.message1, self.message2])
|
self.chat.add_messages([self.message1, self.message2])
|
||||||
|
|||||||
@ -114,13 +114,14 @@ class TestHandleQuestion(CmmTestCase):
|
|||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.question = "test question"
|
self.question = "test question"
|
||||||
self.args = argparse.Namespace(
|
self.args = argparse.Namespace(
|
||||||
tags=['tag1'],
|
or_tags=['tag1'],
|
||||||
extags=['extag1'],
|
and_tags=None,
|
||||||
|
exclude_tags=['xtag1'],
|
||||||
output_tags=None,
|
output_tags=None,
|
||||||
question=[self.question],
|
question=[self.question],
|
||||||
source=None,
|
source=None,
|
||||||
only_source_code=False,
|
source_code_only=False,
|
||||||
number=3,
|
num_answers=3,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
temperature=None,
|
temperature=None,
|
||||||
model=None,
|
model=None,
|
||||||
@ -142,20 +143,22 @@ class TestHandleQuestion(CmmTestCase):
|
|||||||
open_mock = MagicMock()
|
open_mock = MagicMock()
|
||||||
with patch("chatmastermind.storage.open", open_mock):
|
with patch("chatmastermind.storage.open", open_mock):
|
||||||
ask_cmd(self.args, self.config)
|
ask_cmd(self.args, self.config)
|
||||||
mock_print_tag_args.assert_called_once_with(self.args.tags,
|
mock_print_tag_args.assert_called_once_with(self.args.or_tags,
|
||||||
self.args.extags,
|
self.args.exclude_tags,
|
||||||
[])
|
[])
|
||||||
mock_create_chat_hist.assert_called_once_with(self.question,
|
mock_create_chat_hist.assert_called_once_with(self.question,
|
||||||
self.args.tags,
|
self.args.or_tags,
|
||||||
self.args.extags,
|
self.args.exclude_tags,
|
||||||
self.config,
|
self.config,
|
||||||
False, False, False)
|
match_all_tags=False,
|
||||||
|
with_tags=False,
|
||||||
|
with_file=False)
|
||||||
mock_print_chat_hist.assert_called_once_with('test_chat',
|
mock_print_chat_hist.assert_called_once_with('test_chat',
|
||||||
False,
|
False,
|
||||||
self.args.only_source_code)
|
self.args.source_code_only)
|
||||||
mock_ai.assert_called_with("test_chat",
|
mock_ai.assert_called_with("test_chat",
|
||||||
self.config,
|
self.config,
|
||||||
self.args.number)
|
self.args.num_answers)
|
||||||
expected_calls = []
|
expected_calls = []
|
||||||
for num, answer in enumerate(mock_ai.return_value[0], start=1):
|
for num, answer in enumerate(mock_ai.return_value[0], start=1):
|
||||||
title = f'-- ANSWER {num} '
|
title = f'-- ANSWER {num} '
|
||||||
@ -227,7 +230,7 @@ class TestCreateParser(CmmTestCase):
|
|||||||
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
|
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
|
||||||
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
|
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
|
||||||
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
|
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
|
||||||
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY)
|
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
|
||||||
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
|
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
|
||||||
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
|
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
|
||||||
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user