diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. - """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 67eafae..58ce9ed 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,61 +6,19 @@ import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType -from .storage import save_answers, create_chat_hist -from .api_client import ai, openai_api_key, print_models -from .configuration import Config +from .configuration import Config, default_config_path from .chat import ChatDB from .message import Message, MessageFilter, MessageError, Question from .ai_factory import create_ai from .ai import AI, AIResponse -from itertools import zip_longest from typing import Any -default_config = '.config.yaml' - def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) return get_tags_unique(config, prefix) -def create_question_with_hist(args: argparse.Namespace, - config: Config, - ) -> tuple[ChatType, str, list[str]]: - """ - Creates the "AI request", including the question and chat history as determined - by the specified tags. - """ - tags = args.or_tags or [] - xtags = args.exclude_tags or [] - otags = args.output_tags or [] - - if not args.source_code_only: - print_tag_args(tags, xtags, otags) - - question_parts = [] - question_list = args.question if args.question is not None else [] - source_list = args.source if args.source is not None else [] - - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: - question_parts.append(question) - elif source is not None: - with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") - - full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, xtags, config, - match_all_tags=True if args.and_tags else False, # FIXME - with_tags=False, - with_file=False) - return chat, full_question, tags - - def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. @@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: # TODO: add renaming -def config_cmd(args: argparse.Namespace, config: Config) -> None: +def config_cmd(args: argparse.Namespace) -> None: """ Handler for the 'config' command. """ - if args.list_models: - print_models() - elif args.print_model: - print(config.openai.model) - elif args.model: - config.openai.model = args.model - config.to_file(args.config) + if args.create: + Config.create_default(Path(args.create)) def question_cmd(args: argparse.Namespace, config: Config) -> None: @@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: db_path=Path(config.db)) # if it's a new question, create and store it immediately if args.ask or args.create: + # FIXME: add sources to the question message = Message(question=Question(args.question), tags=args.ouput_tags, # FIXME ai=args.ai, @@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: pass -def ask_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'ask' command. - """ - if args.max_tokens: - config.openai.max_tokens = args.max_tokens - if args.temperature: - config.openai.temperature = args.temperature - if args.model: - config.openai.model = args.model - chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.source_code_only) - otags = args.output_tags or [] - answers, usage = ai(chat, config, args.num_answers) - save_answers(question, answers, tags, otags, config) - print("-" * terminal_width()) - print(f"Usage: {usage}") - - def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. @@ -190,7 +125,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-C', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -235,22 +170,6 @@ def create_parser() -> argparse.ArgumentParser: question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') - # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", @@ -286,7 +205,7 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config file") + config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', @@ -315,11 +234,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py deleted file mode 100644 index 8b9ed97..0000000 --- a/chatmastermind/storage.py +++ /dev/null @@ -1,121 +0,0 @@ -import yaml -import io -import pathlib -from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config -from typing import Any, Optional - - -def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: - with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() - # also support tags separated by ',' (old format) - separator = ',' if ',' in tagline else ' ' - tags = [t.strip() for t in tagline.split(separator)] - if tags_only: - return {"tags": tags} - text = fd.read().strip().split('\n') - question_idx = text.index("=== QUESTION ===") + 1 - answer_idx = text.index("==== ANSWER ====") - question = "\n".join(text[question_idx:answer_idx]).strip() - answer = "\n".join(text[answer_idx + 1:]).strip() - return {"question": question, "answer": answer, "tags": tags, - "file": fname.name} - - -def dump_data(data: dict[str, Any]) -> str: - with io.StringIO() as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index 4135ae3..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. - """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_start = content.index('\n', content_start) + 1 - content_end = content.rindex('```') - if content_start < content_end: - print(content[content_start:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 91e6462..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,236 +0,0 @@ -# import unittest -# import io -# import pathlib -# import argparse -# from chatmastermind.utils import terminal_width -# from chatmastermind.main import create_parser, ask_cmd -# from chatmastermind.api_client import ai -# from chatmastermind.configuration import Config -# from chatmastermind.storage import create_chat_hist, save_answers, dump_data -# from unittest import mock -# from unittest.mock import patch, MagicMock, Mock, ANY - - -# class CmmTestCase(unittest.TestCase): -# """ -# Base class for all cmm testcases. -# """ -# def dummy_config(self, db: str) -> Config: -# """ -# Creates a dummy configuration. -# """ -# return Config.from_dict( -# {'system': 'dummy_system', -# 'db': db, -# 'openai': {'api_key': 'dummy_key', -# 'model': 'dummy_model', -# 'max_tokens': 4000, -# 'temperature': 1.0, -# 'top_p': 1, -# 'frequency_penalty': 0, -# 'presence_penalty': 0}} -# ) -# -# -# class TestCreateChat(CmmTestCase): -# -# def setUp(self) -> None: -# self.config = self.dummy_config(db='test_files') -# self.question = "test question" -# self.tags = ['test_tag'] -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['test_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 4) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['other_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 2) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.side_effect = ( -# io.StringIO(dump_data({'question': 'test_content', -# 'answer': 'some answer', -# 'tags': ['test_tag']})), -# io.StringIO(dump_data({'question': 'test_content2', -# 'answer': 'some answer2', -# 'tags': ['test_tag2']})), -# ) -# -# test_chat = create_chat_hist(self.question, [], None, self.config) -# -# self.assertEqual(len(test_chat), 6) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': 'test_content2'}) -# self.assertEqual(test_chat[4], -# {'role': 'assistant', 'content': 'some answer2'}) -# -# -# class TestHandleQuestion(CmmTestCase): -# -# def setUp(self) -> None: -# self.question = "test question" -# self.args = argparse.Namespace( -# or_tags=['tag1'], -# and_tags=None, -# exclude_tags=['xtag1'], -# output_tags=None, -# question=[self.question], -# source=None, -# source_code_only=False, -# num_answers=3, -# max_tokens=None, -# temperature=None, -# model=None, -# match_all_tags=False, -# with_tags=False, -# with_file=False, -# ) -# self.config = self.dummy_config(db='test_files') -# -# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") -# @patch("chatmastermind.main.print_tag_args") -# @patch("chatmastermind.main.print_chat_hist") -# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) -# @patch("chatmastermind.utils.pp") -# @patch("builtins.print") -# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, -# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, -# mock_create_chat_hist: MagicMock) -> None: -# open_mock = MagicMock() -# with patch("chatmastermind.storage.open", open_mock): -# ask_cmd(self.args, self.config) -# mock_print_tag_args.assert_called_once_with(self.args.or_tags, -# self.args.exclude_tags, -# []) -# mock_create_chat_hist.assert_called_once_with(self.question, -# self.args.or_tags, -# self.args.exclude_tags, -# self.config, -# match_all_tags=False, -# with_tags=False, -# with_file=False) -# mock_print_chat_hist.assert_called_once_with('test_chat', -# False, -# self.args.source_code_only) -# mock_ai.assert_called_with("test_chat", -# self.config, -# self.args.num_answers) -# expected_calls = [] -# for num, answer in enumerate(mock_ai.return_value[0], start=1): -# title = f'-- ANSWER {num} ' -# title_end = '-' * (terminal_width() - len(title)) -# expected_calls.append(((f'{title}{title_end}',),)) -# expected_calls.append(((answer,),)) -# expected_calls.append((("-" * terminal_width(),),)) -# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) -# self.assertEqual(mock_print.call_args_list, expected_calls) -# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) -# open_mock.assert_has_calls(open_expected_calls, any_order=True) -# -# -# class TestSaveAnswers(CmmTestCase): -# @mock.patch('builtins.open') -# @mock.patch('chatmastermind.storage.print') -# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: -# question = "Test question?" -# answers = ["Answer 1", "Answer 2"] -# tags = ["tag1", "tag2"] -# otags = ["otag1", "otag2"] -# config = self.dummy_config(db='test_db') -# -# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ -# mock.patch('chatmastermind.storage.yaml.dump'), \ -# mock.patch('io.StringIO') as stringio_mock: -# stringio_instance = stringio_mock.return_value -# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] -# save_answers(question, answers, tags, otags, config) -# -# open_calls = [ -# mock.call(pathlib.Path('test_db/.next'), 'r'), -# mock.call(pathlib.Path('test_db/.next'), 'w'), -# ] -# open_mock.assert_has_calls(open_calls, any_order=True) -# -# -# class TestAI(CmmTestCase): -# -# @patch("openai.ChatCompletion.create") -# def test_ai(self, mock_create: MagicMock) -> None: -# mock_create.return_value = { -# 'choices': [ -# {'message': {'content': 'response_text_1'}}, -# {'message': {'content': 'response_text_2'}} -# ], -# 'usage': {'tokens': 10} -# } -# -# chat = [{"role": "system", "content": "hello ai"}] -# config = self.dummy_config(db='dummy') -# config.openai.model = "text-davinci-002" -# config.openai.max_tokens = 150 -# config.openai.temperature = 0.5 -# -# result = ai(chat, config, 2) -# expected_result = (['response_text_1', 'response_text_2'], -# {'tokens': 10}) -# self.assertEqual(result, expected_result) -# -# -# class TestCreateParser(CmmTestCase): -# def test_create_parser(self) -> None: -# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: -# mock_cmdparser = Mock() -# mock_add_subparsers.return_value = mock_cmdparser -# parser = create_parser() -# self.assertIsInstance(parser, argparse.ArgumentParser) -# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) -# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) -# self.assertTrue('.config.yaml' in parser.get_default('config'))