Compare commits

...

3 Commits

10 changed files with 210 additions and 608 deletions

View File

@ -33,18 +33,23 @@ class AI(Protocol):
The base class for AI clients. The base class for AI clients.
""" """
ID: str
name: str name: str
config: AIConfig config: AIConfig
def request(self, def request(self,
question: Message, question: Message,
context: Chat, chat: Chat,
num_answers: int = 1, num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse: otags: Optional[set[Tag]] = None) -> AIResponse:
""" """
Make an AI request, asking the given question with the given Make an AI request. Parameters:
context (i. e. chat history). The nr. of requested answers * question: the question to ask
corresponds to the nr. of messages in the 'AIResponse'. * chat: the chat history to be added as context
* num_answers: nr. of requested answers (corresponds
to the nr. of messages in the 'AIResponse')
* otags: the output tags, i. e. the tags that all
returned messages should contain
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -3,18 +3,24 @@ Creates different AI instances, based on the given configuration.
""" """
import argparse import argparse
from .configuration import Config from typing import cast
from .configuration import Config, OpenAIConfig, default_ai_ID
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai_cmm import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: def create_ai(args: argparse.Namespace, config: Config) -> AI:
""" """
Creates an AI subclass instance from the given args and configuration. Creates an AI subclass instance from the given args and configuration.
""" """
if args.ai == 'openai': if args.ai:
# FIXME: create actual 'OpenAIConfig' and set values from 'args' ai_conf = config.ais[args.ai]
# FIXME: use actual name from config elif default_ai_ID in config.ais:
return OpenAI("openai", config.openai) ai_conf = config.ais[default_ai_ID]
else:
raise AIError("No AI name given and no default exists")
if ai_conf.name == 'openai':
return OpenAI(cast(OpenAIConfig, ai_conf))
else: else:
raise AIError(f"AI '{args.ai}' is not supported") raise AIError(f"AI '{args.ai}' is not supported")

View File

@ -17,9 +17,11 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
def __init__(self, name: str, config: OpenAIConfig) -> None: def __init__(self, config: OpenAIConfig) -> None:
self.name = name self.ID = config.ID
self.name = config.name
self.config = config self.config = config
openai.api_key = config.api_key
def request(self, def request(self,
question: Message, question: Message,
@ -31,8 +33,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
# FIXME: use real 'system' message (store in OpenAIConfig) oai_chat = self.openai_chat(chat, self.config.system, question)
oai_chat = self.openai_chat(chat, "system", question)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,

View File

@ -1,45 +0,0 @@
import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None:
    """Set the process-wide OpenAI API key used by all later requests."""
    # openai.api_key is module-level state shared by every call in this process
    openai.api_key = api_key
def print_models() -> None:
    """
    Print all models supported by the current AI.
    """
    pending = []
    engines = sorted(openai.Engine.list()['data'], key=lambda e: e['id'])
    for engine in engines:
        # ready engines are printed immediately, the rest are collected
        if not engine['ready']:
            pending.append(engine['id'])
        else:
            print(engine['id'])
    if pending:
        print('\nNot ready: ' + ', '.join(pending))
def ai(chat: ChatType,
       config: Config,
       number: int
       ) -> tuple[list[str], dict[str, int]]:
    """
    Make AI request with the given chat history and configuration.
    Return AI response and tokens used.
    """
    oai = config.openai
    response = openai.ChatCompletion.create(
        model=oai.model,
        messages=chat,
        temperature=oai.temperature,
        max_tokens=oai.max_tokens,
        top_p=oai.top_p,
        n=number,
        frequency_penalty=oai.frequency_penalty,
        presence_penalty=oai.presence_penalty)
    # one stripped answer per returned choice
    answers = [choice['message']['content'].strip()
               for choice in response['choices']]  # type: ignore
    return answers, dict(response['usage'])  # type: ignore

View File

@ -1,16 +1,28 @@
import yaml import yaml
from typing import Type, TypeVar, Any from pathlib import Path
from dataclasses import dataclass, asdict from typing import Type, TypeVar, Any, Optional, Final
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_ai_ID: str = 'default'
default_config_path = '.config.yaml'
class ConfigError(Exception):
pass
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
ID: str
name: str name: str
@ -19,21 +31,27 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
api_key: str # the name must not be changed
model: str name: Final[str] = 'openai'
temperature: float # all members have default values, so we can easily create
max_tokens: int # a default configuration
top_p: float ID: str = 'default'
frequency_penalty: float api_key: str = '0123456789'
presence_penalty: float system: str = 'You are an assistant'
model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
""" """
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
return cls( res = cls(
name='OpenAI', system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
@ -42,6 +60,33 @@ class OpenAIConfig(AIConfig):
frequency_penalty=float(source['frequency_penalty']), frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty'])
) )
# overwrite default ID if provided
if 'ID' in source:
res.ID = source['ID']
return res
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given name.
"""
if name.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"AI '{name}' is not supported")
def create_default_ai_configs() -> dict[str, AIConfig]:
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais}
@dataclass @dataclass
@ -49,30 +94,49 @@ class Config:
""" """
The configuration file structure. The configuration file structure.
""" """
system: str # all members have default values, so we can easily create
db: str # a default configuration
openai: OpenAIConfig db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
""" """
Create Config from a dict. Create Config from a dict (with the same format as the config file).
""" """
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for ID, conf in source['ais'].items():
# add the AI ID to the config (for easy internal access)
conf['ID'] = ID
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls( return cls(
system=str(source['system']),
db=str(source['db']), db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai']) ais=ais
) )
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, path: str) -> None: def to_file(self, file_path: Path) -> None:
with open(path, 'w') as f: # remove the AI name from the config (for a cleaner format)
yaml.dump(asdict(self), f, sort_keys=False) data = self.as_dict()
for conf in data['ais'].values():
del (conf['ID'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -6,61 +6,19 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
from pathlib import Path from pathlib import Path
from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType from .configuration import Config, default_config_path
from .storage import save_answers, create_chat_hist
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB from .chat import ChatDB
from .message import Message, MessageFilter, MessageError, Question from .message import Message, MessageFilter, MessageError, Question
from .ai_factory import create_ai from .ai_factory import create_ai
from .ai import AI, AIResponse from .ai import AI, AIResponse
from itertools import zip_longest
from typing import Any from typing import Any
default_config = '.config.yaml'
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
config = Config.from_file(parsed_args.config) config = Config.from_file(parsed_args.config)
return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
def create_question_with_hist(args: argparse.Namespace,
config: Config,
) -> tuple[ChatType, str, list[str]]:
"""
Creates the "AI request", including the question and chat history as determined
by the specified tags.
"""
tags = args.or_tags or []
xtags = args.exclude_tags or []
otags = args.output_tags or []
if not args.source_code_only:
print_tag_args(tags, xtags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question)
elif source is not None:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, xtags, config,
match_all_tags=True if args.and_tags else False, # FIXME
with_tags=False,
with_file=False)
return chat, full_question, tags
def tags_cmd(args: argparse.Namespace, config: Config) -> None: def tags_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tags' command. Handler for the 'tags' command.
@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
# TODO: add renaming # TODO: add renaming
def config_cmd(args: argparse.Namespace, config: Config) -> None: def config_cmd(args: argparse.Namespace) -> None:
""" """
Handler for the 'config' command. Handler for the 'config' command.
""" """
if args.list_models: if args.create:
print_models() Config.create_default(Path(args.create))
elif args.print_model:
print(config.openai.model)
elif args.model:
config.openai.model = args.model
config.to_file(args.config)
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
db_path=Path(config.db)) db_path=Path(config.db))
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately
if args.ask or args.create: if args.ask or args.create:
# FIXME: add sources to the question
message = Message(question=Question(args.question), message = Message(question=Question(args.question),
tags=args.ouput_tags, # FIXME tags=args.ouput_tags, # FIXME
ai=args.ai, ai=args.ai,
@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
pass pass
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'ask' command.
"""
if args.max_tokens:
config.openai.max_tokens = args.max_tokens
if args.temperature:
config.openai.temperature = args.temperature
if args.model:
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.num_answers)
save_answers(question, answers, tags, otags, config)
print("-" * terminal_width())
print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: Config) -> None: def hist_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
@ -182,7 +117,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config) parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
# subcommand-parser # subcommand-parser
cmdparser = parser.add_subparsers(dest='command', cmdparser = parser.add_subparsers(dest='command',
@ -227,22 +162,6 @@ def create_parser() -> argparse.ArgumentParser:
question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true') action='store_true')
# 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.",
aliases=['a'])
ask_cmd_parser.set_defaults(func=ask_cmd)
ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask',
required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int,
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser # 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.", help="Print chat history.",
@ -278,7 +197,7 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true') action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model", config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true') action='store_true')
config_group.add_argument('-M', '--model', help="Set model in the config file") config_group.add_argument('-c', '--create', help="Create config with default settings in the given file")
# 'print' command parser # 'print' command parser
print_cmd_parser = cmdparser.add_parser('print', print_cmd_parser = cmdparser.add_parser('print',
@ -297,10 +216,11 @@ def main() -> int:
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
command = parser.parse_args() command = parser.parse_args()
if command.func == config_cmd:
command.func(command)
else:
config = Config.from_file(args.config) config = Config.from_file(args.config)
openai_api_key(config.openai.api_key)
command.func(command, config) command.func(command, config)
return 0 return 0

View File

@ -1,121 +0,0 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
    """
    Parse a stored Q&A .txt file into a dict with 'question', 'answer',
    'tags' and 'file' keys. With tags_only=True only the tag line is read
    and a dict containing just 'tags' is returned.
    """
    with open(fname, "r") as fh:
        header = fh.readline().strip().split(':', maxsplit=1)[1].strip()
        # the old format separated tags with ',', the new one with spaces
        sep = ',' if ',' in header else ' '
        tag_list = [tag.strip() for tag in header.split(sep)]
        if tags_only:
            return {"tags": tag_list}
        lines = fh.read().strip().split('\n')
        q_start = lines.index("=== QUESTION ===") + 1
        a_start = lines.index("==== ANSWER ====")
        return {
            "question": "\n".join(lines[q_start:a_start]).strip(),
            "answer": "\n".join(lines[a_start + 1:]).strip(),
            "tags": tag_list,
            "file": fname.name,
        }
def dump_data(data: dict[str, Any]) -> str:
    """
    Serialize a Q&A dict (keys 'tags', 'question', 'answer') into the
    textual .txt storage format and return it as a string.
    """
    # plain f-string concatenation; no io.StringIO buffer needed for this
    return (f'TAGS: {" ".join(data["tags"])}\n'
            f'=== QUESTION ===\n{data["question"]}\n'
            f'==== ANSWER ====\n{data["answer"]}\n')
def write_file(fname: str, data: dict[str, Any]) -> None:
    """
    Write a Q&A dict to *fname* in the .txt storage format.

    Delegates to dump_data() so the file format is defined in exactly one
    place instead of being duplicated here line by line.
    """
    with open(fname, "w") as fd:
        fd.write(dump_data(data))
def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: Config
                 ) -> None:
    """
    Print each answer and store it as a numbered NNNN.txt file.

    Output tags (otags) take precedence over the input tags when both are
    given. The next free file number is tracked in a '.next' file inside
    the db directory and updated after all answers are written.
    """
    wtags = otags or tags
    # num: persistent file counter, inum: per-call answer counter (for titles)
    num, inum = 0, 0
    next_fname = pathlib.Path(str(config.db)) / '.next'
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except Exception:
        # missing or corrupt '.next' -> start numbering from 0 (best-effort)
        pass
    for answer in answers:
        num += 1
        inum += 1
        # banner line padded with dashes to the full terminal width
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # NOTE(review): answers are written to the CWD while '.next' lives in
        # config.db — looks inconsistent, confirm this is intended
        write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
    with open(next_fname, 'w') as f:
        f.write(f'{num}')
def create_chat_hist(question: Optional[str],
                     tags: Optional[list[str]],
                     extags: Optional[list[str]],
                     config: Config,
                     match_all_tags: bool = False,
                     with_tags: bool = False,
                     with_file: bool = False
                     ) -> ChatType:
    """
    Build a chat history from the stored Q&A files in config.db.

    Files are filtered by tags: with match_all_tags every given tag must be
    present, otherwise one matching tag suffices; files carrying any tag in
    extags are excluded. The system prompt is the first message and the
    given question (if any) is appended last as a 'user' message.
    """
    chat: ChatType = []
    append_message(chat, 'system', str(config.system).strip())
    for file in sorted(pathlib.Path(str(config.db)).iterdir()):
        # two storage formats: YAML files and the legacy .txt format
        if file.suffix == '.yaml':
            with open(file, 'r') as f:
                data = yaml.load(f, Loader=yaml.FullLoader)
                data['file'] = file.name
        elif file.suffix == '.txt':
            data = read_file(file)
        else:
            continue
        data_tags = set(data.get('tags', []))
        tags_match: bool
        if match_all_tags:
            # AND semantics: the file must carry every requested tag
            tags_match = not tags or set(tags).issubset(data_tags)
        else:
            # OR semantics: any single requested tag is enough
            tags_match = not tags or bool(data_tags.intersection(tags))
        extags_do_not_match = \
            not extags or not data_tags.intersection(extags)
        if tags_match and extags_do_not_match:
            message_to_chat(data, chat, with_tags, with_file)
    if question:
        append_message(chat, 'user', question)
    return chat
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
    """
    Collect all tags found in the db directory, optionally restricted to
    tags starting with *prefix*. Duplicates are preserved.
    """
    result = []
    for file in sorted(pathlib.Path(str(config.db)).iterdir()):
        if file.suffix == '.yaml':
            with open(file, 'r') as f:
                data = yaml.load(f, Loader=yaml.FullLoader)
        elif file.suffix == '.txt':
            # the legacy .txt format only needs its tag line parsed
            data = read_file(file, tags_only=True)
        else:
            continue
        for tag in data.get('tags', []):
            # collapsed the redundant 'prefix and len(prefix) > 0' check and
            # the duplicated append branches into a single guard
            if not prefix or tag.startswith(prefix):
                result.append(tag)
    return result
def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
    """Like get_tags(), but with duplicate tags removed (order unspecified)."""
    unique_tags = set(get_tags(config, prefix))
    return list(unique_tags)

View File

@ -1,80 +0,0 @@
import shutil
from pprint import PrettyPrinter
from typing import Any
# A chat history: a list of {'role': ..., 'content': ...} message dicts.
ChatType = list[dict[str, str]]
def terminal_width() -> int:
    """Return the current terminal width in columns."""
    size = shutil.get_terminal_size()
    return size.columns
def pp(*args: Any, **kwargs: Any) -> None:
    """Pretty-print the arguments using the full terminal width."""
    printer = PrettyPrinter(width=terminal_width())
    return printer.pprint(*args, **kwargs)
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
    """
    Prints the tags specified in the given args.
    """
    labelled = (("Tags", tags),
                ("Excluding tags", extags),
                ("Output tags", otags))
    # only non-empty tag lists produce a line
    lines = [f"{label}: {' '.join(values)}" for label, values in labelled if values]
    if lines:
        print("\n".join(lines))
    # trailing blank line in every case
    print()
def append_message(chat: ChatType,
                   role: str,
                   content: str
                   ) -> None:
    """Append a role/content entry to *chat*, collapsing doubled quotes."""
    # '' -> ' : undo the quote doubling used in the stored files
    cleaned = content.replace("''", "'")
    chat.append({'role': role, 'content': cleaned})
def message_to_chat(message: dict[str, str],
                    chat: ChatType,
                    with_tags: bool = False,
                    with_file: bool = False
                    ) -> None:
    """
    Append a stored Q&A message to *chat* as a user/assistant pair,
    optionally followed by its tags and its source file name.
    """
    append_message(chat, 'user', message['question'])
    append_message(chat, 'assistant', message['answer'])
    if with_tags:
        append_message(chat, 'tags', " ".join(message['tags']))
    if with_file:
        append_message(chat, 'file', message['file'])
def display_source_code(content: str) -> None:
    """
    Print the text between the first and last ``` fence of *content*;
    silently do nothing when no fenced region is present.
    """
    try:
        start = content.index('```') + 3
        end = content.rindex('```')
    except ValueError:
        # no fence at all -> nothing to display
        return
    if start < end:
        print(content[start:end].strip())
def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
    """
    Pretty-print a chat history. With dump=True the raw structure is
    pretty-printed instead; with source_code=True only the fenced code
    blocks of each message are shown.
    """
    if dump:
        pp(chat)
        return
    for message in chat:
        role, content = message['role'], message['content']
        too_long = len(content) > terminal_width() - len(role) - 2
        if source_code:
            display_source_code(content)
            continue
        # separator line before every user message
        if role == 'user':
            print('-' * terminal_width())
        if too_long:
            # long contents go on their own line below the role header
            print(f"{role.upper()}:")
            print(content)
        else:
            print(f"{role.upper()}: {content}")

View File

@ -0,0 +1,88 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
    """Tests for the ai_config_instance() configuration factory."""

    def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
        # a bare 'openai' request must mirror a default-constructed OpenAIConfig
        ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
        ai_reference = OpenAIConfig()
        self.assertEqual(ai_config.ID, ai_reference.ID)
        self.assertEqual(ai_config.name, ai_reference.name)
        self.assertEqual(ai_config.api_key, ai_reference.api_key)
        self.assertEqual(ai_config.system, ai_reference.system)
        self.assertEqual(ai_config.model, ai_reference.model)
        self.assertEqual(ai_config.temperature, ai_reference.temperature)
        self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
        self.assertEqual(ai_config.top_p, ai_reference.top_p)
        self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
        self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)

    def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
        # every overridable field is supplied; the instance must echo them back
        conf_dict = {
            'system': 'Custom system',
            'api_key': '9876543210',
            'model': 'custom_model',
            'max_tokens': 5000,
            'temperature': 0.5,
            'top_p': 0.8,
            'frequency_penalty': 0.7,
            'presence_penalty': 0.2
        }
        ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
        self.assertEqual(ai_config.system, 'Custom system')
        self.assertEqual(ai_config.api_key, '9876543210')
        self.assertEqual(ai_config.model, 'custom_model')
        self.assertEqual(ai_config.max_tokens, 5000)
        # floats compared with assertAlmostEqual to avoid representation noise
        self.assertAlmostEqual(ai_config.temperature, 0.5)
        self.assertAlmostEqual(ai_config.top_p, 0.8)
        self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
        self.assertAlmostEqual(ai_config.presence_penalty, 0.2)

    def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
        # unsupported AI names must raise ConfigError rather than fail silently
        with self.assertRaises(ConfigError):
            ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
    """Tests for Config.from_dict() and default-config file creation."""

    def setUp(self) -> None:
        # temp file target for the default-config test; removed in tearDown
        self.test_file = NamedTemporaryFile(delete=False)

    def tearDown(self) -> None:
        os.remove(self.test_file.name)

    def test_from_dict_should_create_config_from_dict(self) -> None:
        # a minimal config-file-shaped dict with a single 'default' AI entry
        source_dict = {
            'db': './test_db/',
            'ais': {
                'default': {
                    'name': 'openai',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            }
        }
        config = Config.from_dict(source_dict)
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(len(config.ais), 1)
        self.assertEqual(config.ais['default'].name, 'openai')
        self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
        # check that 'ID' has been added
        self.assertEqual(config.ais['default'].ID, 'default')

    def test_create_default_should_create_default_config(self) -> None:
        # writing the default config and re-reading it must round-trip 'db'
        Config.create_default(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            default_config = yaml.load(f, Loader=yaml.FullLoader)
        config_reference = Config()
        self.assertEqual(default_config['db'], config_reference.db)

View File

@ -1,236 +0,0 @@
# import unittest
# import io
# import pathlib
# import argparse
# from chatmastermind.utils import terminal_width
# from chatmastermind.main import create_parser, ask_cmd
# from chatmastermind.api_client import ai
# from chatmastermind.configuration import Config
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
# from unittest import mock
# from unittest.mock import patch, MagicMock, Mock, ANY
# class CmmTestCase(unittest.TestCase):
# """
# Base class for all cmm testcases.
# """
# def dummy_config(self, db: str) -> Config:
# """
# Creates a dummy configuration.
# """
# return Config.from_dict(
# {'system': 'dummy_system',
# 'db': db,
# 'openai': {'api_key': 'dummy_key',
# 'model': 'dummy_model',
# 'max_tokens': 4000,
# 'temperature': 1.0,
# 'top_p': 1,
# 'frequency_penalty': 0,
# 'presence_penalty': 0}}
# )
#
#
# class TestCreateChat(CmmTestCase):
#
# def setUp(self) -> None:
# self.config = self.dummy_config(db='test_files')
# self.question = "test question"
# self.tags = ['test_tag']
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['test_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 4)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3],
# {'role': 'user', 'content': self.question})
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['other_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 2)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': self.question})
#
# @patch('os.listdir')
# @patch('pathlib.Path.iterdir')
# @patch('builtins.open')
# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.side_effect = (
# io.StringIO(dump_data({'question': 'test_content',
# 'answer': 'some answer',
# 'tags': ['test_tag']})),
# io.StringIO(dump_data({'question': 'test_content2',
# 'answer': 'some answer2',
# 'tags': ['test_tag2']})),
# )
#
# test_chat = create_chat_hist(self.question, [], None, self.config)
#
# self.assertEqual(len(test_chat), 6)
# self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3],
# {'role': 'user', 'content': 'test_content2'})
# self.assertEqual(test_chat[4],
# {'role': 'assistant', 'content': 'some answer2'})
#
#
# class TestHandleQuestion(CmmTestCase):
#
# def setUp(self) -> None:
# self.question = "test question"
# self.args = argparse.Namespace(
# or_tags=['tag1'],
# and_tags=None,
# exclude_tags=['xtag1'],
# output_tags=None,
# question=[self.question],
# source=None,
# source_code_only=False,
# num_answers=3,
# max_tokens=None,
# temperature=None,
# model=None,
# match_all_tags=False,
# with_tags=False,
# with_file=False,
# )
# self.config = self.dummy_config(db='test_files')
#
# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
# @patch("chatmastermind.main.print_tag_args")
# @patch("chatmastermind.main.print_chat_hist")
# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
# @patch("chatmastermind.utils.pp")
# @patch("builtins.print")
# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
# mock_create_chat_hist: MagicMock) -> None:
# open_mock = MagicMock()
# with patch("chatmastermind.storage.open", open_mock):
# ask_cmd(self.args, self.config)
# mock_print_tag_args.assert_called_once_with(self.args.or_tags,
# self.args.exclude_tags,
# [])
# mock_create_chat_hist.assert_called_once_with(self.question,
# self.args.or_tags,
# self.args.exclude_tags,
# self.config,
# match_all_tags=False,
# with_tags=False,
# with_file=False)
# mock_print_chat_hist.assert_called_once_with('test_chat',
# False,
# self.args.source_code_only)
# mock_ai.assert_called_with("test_chat",
# self.config,
# self.args.num_answers)
# expected_calls = []
# for num, answer in enumerate(mock_ai.return_value[0], start=1):
# title = f'-- ANSWER {num} '
# title_end = '-' * (terminal_width() - len(title))
# expected_calls.append(((f'{title}{title_end}',),))
# expected_calls.append(((answer,),))
# expected_calls.append((("-" * terminal_width(),),))
# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
# self.assertEqual(mock_print.call_args_list, expected_calls)
# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
# open_mock.assert_has_calls(open_expected_calls, any_order=True)
#
#
# class TestSaveAnswers(CmmTestCase):
# @mock.patch('builtins.open')
# @mock.patch('chatmastermind.storage.print')
# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
# question = "Test question?"
# answers = ["Answer 1", "Answer 2"]
# tags = ["tag1", "tag2"]
# otags = ["otag1", "otag2"]
# config = self.dummy_config(db='test_db')
#
# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
# mock.patch('chatmastermind.storage.yaml.dump'), \
# mock.patch('io.StringIO') as stringio_mock:
# stringio_instance = stringio_mock.return_value
# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
# save_answers(question, answers, tags, otags, config)
#
# open_calls = [
# mock.call(pathlib.Path('test_db/.next'), 'r'),
# mock.call(pathlib.Path('test_db/.next'), 'w'),
# ]
# open_mock.assert_has_calls(open_calls, any_order=True)
#
#
# class TestAI(CmmTestCase):
#
# @patch("openai.ChatCompletion.create")
# def test_ai(self, mock_create: MagicMock) -> None:
# mock_create.return_value = {
# 'choices': [
# {'message': {'content': 'response_text_1'}},
# {'message': {'content': 'response_text_2'}}
# ],
# 'usage': {'tokens': 10}
# }
#
# chat = [{"role": "system", "content": "hello ai"}]
# config = self.dummy_config(db='dummy')
# config.openai.model = "text-davinci-002"
# config.openai.max_tokens = 150
# config.openai.temperature = 0.5
#
# result = ai(chat, config, 2)
# expected_result = (['response_text_1', 'response_text_2'],
# {'tokens': 10})
# self.assertEqual(result, expected_result)
#
#
# class TestCreateParser(CmmTestCase):
# def test_create_parser(self) -> None:
# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
# mock_cmdparser = Mock()
# mock_add_subparsers.return_value = mock_cmdparser
# parser = create_parser()
# self.assertIsInstance(parser, argparse.ArgumentParser)
# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
# self.assertTrue('.config.yaml' in parser.get_default('config'))