Splain main.py to several files.

This commit is contained in:
Oleksandr Kozachuk 2023-04-07 17:40:24 +02:00
parent 0470109434
commit b23a9f663f
6 changed files with 228 additions and 175 deletions

View File

@ -1,8 +1,8 @@
# ChatMastermind
ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes a relevant chat history for the next question.
ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes relevant chat history for the next question.
The project uses the OpenAI API to generate responses, and stores the data in YAML files. It also allows you to filter the chat history based on tags, and supports autocompletion for tags.
The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
## Requirements
@ -13,7 +13,7 @@ The project uses the OpenAI API to generate responses, and stores the data in YA
You can install these requirements using `pip`:
```
```bash
pip install -r requirements.txt
```
@ -21,13 +21,13 @@ pip install -r requirements.txt
You can install the package with the requirements using `pip`:
```
```bash
pip install .
```
## Usage
```
```bash
cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]]
```
@ -50,37 +50,37 @@ cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMP
1. Print the contents of a YAML file:
```
```bash
cmm -p example.yaml
```
2. Ask a question:
```
```bash
cmm -q "What is the meaning of life?" -t philosophy -e religion
```
3. Display the chat history as a Python structure:
```
```bash
cmm -D
```
4. Display the chat history as readable text:
```
```bash
cmm -d
```
5. Filter chat history by tags:
```
```bash
cmm -d -t tag1 tag2
```
6. Exclude chat history by tags:
```
```bash
cmm -d -e tag3 tag4
```
@ -103,13 +103,12 @@ The configuration file (`.config.yaml`) should contain the following fields:
To activate autocompletion for tags, add the following line to your shell's configuration file (e.g., `.bashrc`, `.zshrc`, or `.profile`):
```
```bash
eval "$(register-python-argcomplete cmm)"
```
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `chatmastermind` script.
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
## License
This project is licensed under the terms of the WTFPL License.
This project is licensed under the terms of the WTFPL License.

View File

@ -0,0 +1,24 @@
import openai
def openai_api_key(api_key: str) -> None:
openai.api_key = api_key
def ai(chat: list[dict[str, str]],
config: dict,
number: int
) -> tuple[list[str], dict[str, int]]:
response = openai.ChatCompletion.create(
model=config['openai']['model'],
messages=chat,
temperature=config['openai']['temperature'],
max_tokens=config['openai']['max_tokens'],
top_p=config['openai']['top_p'],
n=number,
frequency_penalty=config['openai']['frequency_penalty'],
presence_penalty=config['openai']['presence_penalty'])
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore

View File

@ -3,19 +3,12 @@
# vim: set fileencoding=utf-8 :
import yaml
import io
import sys
import shutil
import openai
import pathlib
import argcomplete
import argparse
from pprint import PrettyPrinter
from typing import List, Dict, Any, Optional
terminal_size = shutil.get_terminal_size()
terminal_width = terminal_size.columns
pp = PrettyPrinter(width=terminal_width).pprint
from .utils import terminal_width, pp, tags_completer, process_tags, display_chat
from .storage import save_answers, create_chat
from .api_client import ai, openai_api_key
def run_print_command(args: argparse.Namespace, config: dict) -> None:
@ -24,143 +17,56 @@ def run_print_command(args: argparse.Namespace, config: dict) -> None:
pp(data)
def process_tags(config: dict, tags: list, extags: list) -> None:
print(f"Tags: {', '.join(tags)}")
if len(extags) > 0:
print(f"Excluding tags: {', '.join(extags)}")
print()
def append_message(chat: List[Dict[str, str]],
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: Dict[str, str],
chat: List[Dict[str, str]]
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
def create_chat(question: Optional[str],
tags: Optional[List[str]],
extags: Optional[List[str]],
config: Dict[str, Any]
) -> List[Dict[str, str]]:
chat = []
append_message(chat, 'system', config['system'].strip())
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat
def ai(chat: list[dict[str, str]],
config: dict,
number: int
) -> tuple[list[str], dict[str, int]]:
response = openai.ChatCompletion.create(
model=config['openai']['model'],
messages=chat,
temperature=config['openai']['temperature'],
max_tokens=config['openai']['max_tokens'],
top_p=config['openai']['top_p'],
n=number,
frequency_penalty=config['openai']['frequency_penalty'],
presence_penalty=config['openai']['presence_penalty'])
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore
def process_and_display_chat(args: argparse.Namespace,
config: dict,
dump: bool = False
) -> tuple[list[dict[str, str]], list[str]]:
) -> tuple[list[dict[str, str]], str, list[str]]:
tags = args.tags or []
extags = args.extags or []
process_tags(config, tags, extags)
chat = create_chat(args.question, tags, extags, config)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip(question_list, source_list):
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
if len(question_list) > len(source_list):
for question in question_list[len(source_list):]:
question_parts.append(question)
else:
for source in source_list[len(question_list):]:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
question = '\n\n'.join(question_parts)
chat = create_chat(question, tags, extags, config)
display_chat(chat, dump)
return chat, tags
def display_chat(chat, dump=False) -> None:
if dump:
pp(chat)
return
for message in chat:
if message['role'] == 'user':
print('-' * terminal_width)
if len(message['content']) > terminal_width-len(message['role'])-2:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
return chat, question, tags
def handle_question(args: argparse.Namespace,
config: dict,
dump: bool = False
) -> None:
chat, tags = process_and_display_chat(args, config, dump)
chat, question, tags = process_and_display_chat(args, config, dump)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.number)
save_answers(args.question, answers, tags, otags)
print("-" * terminal_width)
save_answers(question, answers, tags, otags)
print("-" * terminal_width())
print(f"Usage: {usage}")
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]]
) -> None:
wtags = otags or tags
for num, answer in enumerate(answers, start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:02d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
def create_parser() -> argparse.ArgumentParser:
default_config = '.config.yaml'
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-p', '--print', help='YAML file to print')
group.add_argument('-q', '--question', help='Question to ask')
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
@ -168,6 +74,7 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
parser.add_argument('-M', '--model', help='Model to use')
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3)
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
tags_arg.completer = tags_completer # type: ignore
extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS')
@ -185,7 +92,7 @@ def main() -> int:
with open(args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
openai.api_key = config['openai']['api_key']
openai_api_key(config['openai']['api_key'])
if args.max_tokens:
config['openai']['max_tokens'] = args.max_tokens
@ -208,22 +115,5 @@ def main() -> int:
return 0
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))
if __name__ == '__main__':
sys.exit(main())

57
chatmastermind/storage.py Normal file
View File

@ -0,0 +1,57 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat
from typing import List, Dict, Any, Optional
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]]
) -> None:
wtags = otags or tags
for num, answer in enumerate(answers, start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
with open(f"{num:02d}.yaml", "w") as fd:
with io.StringIO() as f:
yaml.dump({'question': question},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"question":', "question:", 1))
with io.StringIO() as f:
yaml.dump({'answer': answer},
f,
default_style="|",
default_flow_style=False)
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
yaml.dump({'tags': wtags},
fd,
default_flow_style=False)
def create_chat(question: Optional[str],
tags: Optional[List[str]],
extags: Optional[List[str]],
config: Dict[str, Any]
) -> List[Dict[str, str]]:
chat = []
append_message(chat, 'system', config['system'].strip())
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data_tags = set(data.get('tags', []))
tags_match = \
not tags or data_tags.intersection(tags)
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat)
if question:
append_message(chat, 'user', question)
return chat

65
chatmastermind/utils.py Normal file
View File

@ -0,0 +1,65 @@
import shutil
import yaml
import pathlib
from pprint import PrettyPrinter
from typing import List, Dict
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def process_tags(config: dict, tags: list, extags: list) -> None:
print(f"Tags: {', '.join(tags)}")
if len(extags) > 0:
print(f"Excluding tags: {', '.join(extags)}")
print()
def append_message(chat: List[Dict[str, str]],
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: Dict[str, str],
chat: List[Dict[str, str]]
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
def display_chat(chat, dump=False) -> None:
if dump:
pp(chat)
return
for message in chat:
if message['role'] == 'user':
print('-' * (terminal_width()))
if len(message['content']) > terminal_width() - len(message['role']) - 2:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))

View File

@ -3,10 +3,12 @@ import io
import os
import yaml
import argparse
import chatmastermind.main
from chatmastermind.main import create_chat, ai, handle_question, save_answers
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, handle_question
from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat, save_answers
from unittest import mock
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, Mock
class TestCreateChat(unittest.TestCase):
@ -86,11 +88,13 @@ class TestCreateChat(unittest.TestCase):
class TestHandleQuestion(unittest.TestCase):
def setUp(self):
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
extags=['extag1'],
output_tags=None,
question='test question',
question=[self.question],
source=None,
number=3
)
self.config = {
@ -100,20 +104,19 @@ class TestHandleQuestion(unittest.TestCase):
@patch("chatmastermind.main.create_chat", return_value="test_chat")
@patch("chatmastermind.main.process_tags")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"],
"test_usage"))
@patch("chatmastermind.main.pp")
@patch("chatmastermind.main.print")
@patch("chatmastermind.main.yaml.dump")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp")
@patch("builtins.print")
@patch("chatmastermind.storage.yaml.dump")
def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
mock_process_tags, mock_create_chat):
open_mock = MagicMock()
with patch("chatmastermind.main.open", open_mock):
with patch("chatmastermind.storage.open", open_mock):
handle_question(self.args, self.config, True)
mock_process_tags.assert_called_once_with(self.config,
self.args.tags,
self.args.extags)
mock_create_chat.assert_called_once_with(self.args.question,
mock_create_chat.assert_called_once_with(self.question,
self.args.tags,
self.args.extags,
self.config)
@ -124,15 +127,14 @@ class TestHandleQuestion(unittest.TestCase):
expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (chatmastermind.main.terminal_width - len(title))
title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),))
expected_calls.append((("-" * chatmastermind.main.terminal_width,),))
expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
open_mock.assert_has_calls([
mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)
] + [mock.call().__enter__(),
mock.call().__exit__(None, None, None)] * 3,
open_mock.assert_has_calls(
[mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [
mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3,
any_order=True)
self.assertEqual(mock_print.call_args_list, expected_calls)
@ -152,9 +154,9 @@ class TestSaveAnswers(unittest.TestCase):
def test_save_answers(self):
try:
self.assert_stdout(f"-- ANSWER 1 {'-'*(chatmastermind.main.terminal_width-12)}\n"
self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n"
"AI is Artificial Intelligence\n"
f"-- ANSWER 2 {'-'*(chatmastermind.main.terminal_width-12)}\n"
f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n"
"AI is a simulation of human intelligence\n")
for idx, answer in enumerate(self.answers, start=1):
with open(f"{idx:02d}.yaml", "r") as file:
@ -198,3 +200,19 @@ class TestAI(unittest.TestCase):
expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10})
self.assertEqual(result, expected_result)
class TestCreateParser(unittest.TestCase):
def test_create_parser(self):
with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
mock_group = Mock()
mock_add_mutually_exclusive_group.return_value = mock_group
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print')
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
self.assertTrue('.config.yaml' in parser.get_default('config'))
self.assertEqual(parser.get_default('number'), 3)