Splain main.py to several files.

This commit is contained in:
Oleksandr Kozachuk 2023-04-07 17:40:24 +02:00
parent 0470109434
commit b23a9f663f
6 changed files with 228 additions and 175 deletions

View File

@ -1,8 +1,8 @@
# ChatMastermind # ChatMastermind
ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes a relevant chat history for the next question. ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes relevant chat history for the next question.
The project uses the OpenAI API to generate responses, and stores the data in YAML files. It also allows you to filter the chat history based on tags, and supports autocompletion for tags. The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
## Requirements ## Requirements
@ -13,7 +13,7 @@ The project uses the OpenAI API to generate responses, and stores the data in YA
You can install these requirements using `pip`: You can install these requirements using `pip`:
``` ```bash
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@ -21,13 +21,13 @@ pip install -r requirements.txt
You can install the package with the requirements using `pip`: You can install the package with the requirements using `pip`:
``` ```bash
pip install . pip install .
``` ```
## Usage ## Usage
``` ```bash
cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]] cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]]
``` ```
@ -50,37 +50,37 @@ cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMP
1. Print the contents of a YAML file: 1. Print the contents of a YAML file:
``` ```bash
cmm -p example.yaml cmm -p example.yaml
``` ```
2. Ask a question: 2. Ask a question:
``` ```bash
cmm -q "What is the meaning of life?" -t philosophy -e religion cmm -q "What is the meaning of life?" -t philosophy -e religion
``` ```
3. Display the chat history as a Python structure: 3. Display the chat history as a Python structure:
``` ```bash
cmm -D cmm -D
``` ```
4. Display the chat history as readable text: 4. Display the chat history as readable text:
``` ```bash
cmm -d cmm -d
``` ```
5. Filter chat history by tags: 5. Filter chat history by tags:
``` ```bash
cmm -d -t tag1 tag2 cmm -d -t tag1 tag2
``` ```
6. Exclude chat history by tags: 6. Exclude chat history by tags:
``` ```bash
cmm -d -e tag3 tag4 cmm -d -e tag3 tag4
``` ```
@ -103,13 +103,12 @@ The configuration file (`.config.yaml`) should contain the following fields:
To activate autocompletion for tags, add the following line to your shell's configuration file (e.g., `.bashrc`, `.zshrc`, or `.profile`): To activate autocompletion for tags, add the following line to your shell's configuration file (e.g., `.bashrc`, `.zshrc`, or `.profile`):
``` ```bash
eval "$(register-python-argcomplete cmm)" eval "$(register-python-argcomplete cmm)"
``` ```
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `chatmastermind` script. After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
## License ## License
This project is licensed under the terms of the WTFPL License. This project is licensed under the terms of the WTFPL License.

View File

@ -0,0 +1,24 @@
import openai
def openai_api_key(api_key: str) -> None:
openai.api_key = api_key
def ai(chat: list[dict[str, str]],
config: dict,
number: int
) -> tuple[list[str], dict[str, int]]:
response = openai.ChatCompletion.create(
model=config['openai']['model'],
messages=chat,
temperature=config['openai']['temperature'],
max_tokens=config['openai']['max_tokens'],
top_p=config['openai']['top_p'],
n=number,
frequency_penalty=config['openai']['frequency_penalty'],
presence_penalty=config['openai']['presence_penalty'])
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore

View File

@ -3,19 +3,12 @@
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
import yaml import yaml
import io
import sys import sys
import shutil
import openai
import pathlib
import argcomplete import argcomplete
import argparse import argparse
from pprint import PrettyPrinter from .utils import terminal_width, pp, tags_completer, process_tags, display_chat
from typing import List, Dict, Any, Optional from .storage import save_answers, create_chat
from .api_client import ai, openai_api_key
terminal_size = shutil.get_terminal_size()
terminal_width = terminal_size.columns
pp = PrettyPrinter(width=terminal_width).pprint
def run_print_command(args: argparse.Namespace, config: dict) -> None: def run_print_command(args: argparse.Namespace, config: dict) -> None:
@ -24,143 +17,56 @@ def run_print_command(args: argparse.Namespace, config: dict) -> None:
pp(data) pp(data)
def process_tags(config: dict, tags: list, extags: list) -> None:
print(f"Tags: {', '.join(tags)}")
if len(extags) > 0:
print(f"Excluding tags: {', '.join(extags)}")
print()
def append_message(chat: List[Dict[str, str]],
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: Dict[str, str],
chat: List[Dict[str, str]]
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
def create_chat(question: Optional[str],
tags: Optional[List[str]],
extags: Optional[List[str]],
config: Dict[str, Any]
) -> List[Dict[str, str]]:
chat = []
append_message(chat, 'system', config['system'].strip())
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat
def ai(chat: list[dict[str, str]],
config: dict,
number: int
) -> tuple[list[str], dict[str, int]]:
response = openai.ChatCompletion.create(
model=config['openai']['model'],
messages=chat,
temperature=config['openai']['temperature'],
max_tokens=config['openai']['max_tokens'],
top_p=config['openai']['top_p'],
n=number,
frequency_penalty=config['openai']['frequency_penalty'],
presence_penalty=config['openai']['presence_penalty'])
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore
def process_and_display_chat(args: argparse.Namespace, def process_and_display_chat(args: argparse.Namespace,
config: dict, config: dict,
dump: bool = False dump: bool = False
) -> tuple[list[dict[str, str]], list[str]]: ) -> tuple[list[dict[str, str]], str, list[str]]:
tags = args.tags or [] tags = args.tags or []
extags = args.extags or [] extags = args.extags or []
process_tags(config, tags, extags) process_tags(config, tags, extags)
chat = create_chat(args.question, tags, extags, config)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip(question_list, source_list):
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
if len(question_list) > len(source_list):
for question in question_list[len(source_list):]:
question_parts.append(question)
else:
for source in source_list[len(question_list):]:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
question = '\n\n'.join(question_parts)
chat = create_chat(question, tags, extags, config)
display_chat(chat, dump) display_chat(chat, dump)
return chat, tags return chat, question, tags
def display_chat(chat, dump=False) -> None:
if dump:
pp(chat)
return
for message in chat:
if message['role'] == 'user':
print('-' * terminal_width)
if len(message['content']) > terminal_width-len(message['role'])-2:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def handle_question(args: argparse.Namespace, def handle_question(args: argparse.Namespace,
config: dict, config: dict,
dump: bool = False dump: bool = False
) -> None: ) -> None:
chat, tags = process_and_display_chat(args, config, dump) chat, question, tags = process_and_display_chat(args, config, dump)
otags = args.output_tags or [] otags = args.output_tags or []
answers, usage = ai(chat, config, args.number) answers, usage = ai(chat, config, args.number)
save_answers(args.question, answers, tags, otags) save_answers(question, answers, tags, otags)
print("-" * terminal_width) print("-" * terminal_width())
print(f"Usage: {usage}") print(f"Usage: {usage}")
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]]
) -> None:
wtags = otags or tags
for num, answer in enumerate(answers, start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:02d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
default_config = '.config.yaml' default_config = '.config.yaml'
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI") description="ChatMastermind is a Python application that automates conversation with AI")
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-p', '--print', help='YAML file to print') group.add_argument('-p', '--print', help='YAML file to print')
group.add_argument('-q', '--question', help='Question to ask') group.add_argument('-q', '--question', nargs='*', help='Question to ask')
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true') group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
parser.add_argument('-c', '--config', help='Config file name.', default=default_config) parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
@ -168,6 +74,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
parser.add_argument('-M', '--model', help='Model to use') parser.add_argument('-M', '--model', help='Model to use')
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3) parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3)
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS') tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
tags_arg.completer = tags_completer # type: ignore tags_arg.completer = tags_completer # type: ignore
extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS') extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS')
@ -185,7 +92,7 @@ def main() -> int:
with open(args.config, 'r') as f: with open(args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
openai.api_key = config['openai']['api_key'] openai_api_key(config['openai']['api_key'])
if args.max_tokens: if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens config['openai']['max_tokens'] = args.max_tokens
@ -208,22 +115,5 @@ def main() -> int:
return 0 return 0
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main()) sys.exit(main())

57
chatmastermind/storage.py Normal file
View File

@ -0,0 +1,57 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat
from typing import List, Dict, Any, Optional
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]]
) -> None:
wtags = otags or tags
for num, answer in enumerate(answers, start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:02d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
def create_chat(question: Optional[str],
tags: Optional[List[str]],
extags: Optional[List[str]],
config: Dict[str, Any]
) -> List[Dict[str, str]]:
chat = []
append_message(chat, 'system', config['system'].strip())
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat

65
chatmastermind/utils.py Normal file
View File

@ -0,0 +1,65 @@
import shutil
import yaml
import pathlib
from pprint import PrettyPrinter
from typing import List, Dict
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def process_tags(config: dict, tags: list, extags: list) -> None:
print(f"Tags: {', '.join(tags)}")
if len(extags) > 0:
print(f"Excluding tags: {', '.join(extags)}")
print()
def append_message(chat: List[Dict[str, str]],
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: Dict[str, str],
chat: List[Dict[str, str]]
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
def display_chat(chat, dump=False) -> None:
if dump:
pp(chat)
return
for message in chat:
if message['role'] == 'user':
print('-' * (terminal_width()))
if len(message['content']) > terminal_width() - len(message['role']) - 2:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))

View File

@ -3,10 +3,12 @@ import io
import os import os
import yaml import yaml
import argparse import argparse
import chatmastermind.main from chatmastermind.utils import terminal_width
from chatmastermind.main import create_chat, ai, handle_question, save_answers from chatmastermind.main import create_parser, handle_question
from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat, save_answers
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock, Mock
class TestCreateChat(unittest.TestCase): class TestCreateChat(unittest.TestCase):
@ -86,11 +88,13 @@ class TestCreateChat(unittest.TestCase):
class TestHandleQuestion(unittest.TestCase): class TestHandleQuestion(unittest.TestCase):
def setUp(self): def setUp(self):
self.question = "test question"
self.args = argparse.Namespace( self.args = argparse.Namespace(
tags=['tag1'], tags=['tag1'],
extags=['extag1'], extags=['extag1'],
output_tags=None, output_tags=None,
question='test question', question=[self.question],
source=None,
number=3 number=3
) )
self.config = { self.config = {
@ -100,20 +104,19 @@ class TestHandleQuestion(unittest.TestCase):
@patch("chatmastermind.main.create_chat", return_value="test_chat") @patch("chatmastermind.main.create_chat", return_value="test_chat")
@patch("chatmastermind.main.process_tags") @patch("chatmastermind.main.process_tags")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
"test_usage")) @patch("chatmastermind.utils.pp")
@patch("chatmastermind.main.pp") @patch("builtins.print")
@patch("chatmastermind.main.print") @patch("chatmastermind.storage.yaml.dump")
@patch("chatmastermind.main.yaml.dump")
def test_handle_question(self, _, mock_print, mock_pp, mock_ai, def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
mock_process_tags, mock_create_chat): mock_process_tags, mock_create_chat):
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.main.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
handle_question(self.args, self.config, True) handle_question(self.args, self.config, True)
mock_process_tags.assert_called_once_with(self.config, mock_process_tags.assert_called_once_with(self.config,
self.args.tags, self.args.tags,
self.args.extags) self.args.extags)
mock_create_chat.assert_called_once_with(self.args.question, mock_create_chat.assert_called_once_with(self.question,
self.args.tags, self.args.tags,
self.args.extags, self.args.extags,
self.config) self.config)
@ -124,15 +127,14 @@ class TestHandleQuestion(unittest.TestCase):
expected_calls = [] expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1): for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} ' title = f'-- ANSWER {num} '
title_end = '-' * (chatmastermind.main.terminal_width - len(title)) title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),)) expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),)) expected_calls.append(((answer,),))
expected_calls.append((("-" * chatmastermind.main.terminal_width,),)) expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
open_mock.assert_has_calls([ open_mock.assert_has_calls(
mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4) [mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [
] + [mock.call().__enter__(), mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3,
mock.call().__exit__(None, None, None)] * 3,
any_order=True) any_order=True)
self.assertEqual(mock_print.call_args_list, expected_calls) self.assertEqual(mock_print.call_args_list, expected_calls)
@ -152,9 +154,9 @@ class TestSaveAnswers(unittest.TestCase):
def test_save_answers(self): def test_save_answers(self):
try: try:
self.assert_stdout(f"-- ANSWER 1 {'-'*(chatmastermind.main.terminal_width-12)}\n" self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n"
"AI is Artificial Intelligence\n" "AI is Artificial Intelligence\n"
f"-- ANSWER 2 {'-'*(chatmastermind.main.terminal_width-12)}\n" f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n"
"AI is a simulation of human intelligence\n") "AI is a simulation of human intelligence\n")
for idx, answer in enumerate(self.answers, start=1): for idx, answer in enumerate(self.answers, start=1):
with open(f"{idx:02d}.yaml", "r") as file: with open(f"{idx:02d}.yaml", "r") as file:
@ -198,3 +200,19 @@ class TestAI(unittest.TestCase):
expected_result = (['response_text_1', 'response_text_2'], expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10}) {'tokens': 10})
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class TestCreateParser(unittest.TestCase):
def test_create_parser(self):
with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
mock_group = Mock()
mock_add_mutually_exclusive_group.return_value = mock_group
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print')
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
self.assertTrue('.config.yaml' in parser.get_default('config'))
self.assertEqual(parser.get_default('number'), 3)