Splain main.py to several files.
This commit is contained in:
parent
0470109434
commit
b23a9f663f
29
README.md
29
README.md
@ -1,8 +1,8 @@
|
||||
# ChatMastermind
|
||||
|
||||
ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes a relevant chat history for the next question.
|
||||
ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes relevant chat history for the next question.
|
||||
|
||||
The project uses the OpenAI API to generate responses, and stores the data in YAML files. It also allows you to filter the chat history based on tags, and supports autocompletion for tags.
|
||||
The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
|
||||
|
||||
## Requirements
|
||||
|
||||
@ -13,7 +13,7 @@ The project uses the OpenAI API to generate responses, and stores the data in YA
|
||||
|
||||
You can install these requirements using `pip`:
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@ -21,13 +21,13 @@ pip install -r requirements.txt
|
||||
|
||||
You can install the package with the requirements using `pip`:
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install .
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]]
|
||||
```
|
||||
|
||||
@ -50,37 +50,37 @@ cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMP
|
||||
|
||||
1. Print the contents of a YAML file:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -p example.yaml
|
||||
```
|
||||
|
||||
2. Ask a question:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -q "What is the meaning of life?" -t philosophy -e religion
|
||||
```
|
||||
|
||||
3. Display the chat history as a Python structure:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -D
|
||||
```
|
||||
|
||||
4. Display the chat history as readable text:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -d
|
||||
```
|
||||
|
||||
5. Filter chat history by tags:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -d -t tag1 tag2
|
||||
```
|
||||
|
||||
6. Exclude chat history by tags:
|
||||
|
||||
```
|
||||
```bash
|
||||
cmm -d -e tag3 tag4
|
||||
```
|
||||
|
||||
@ -103,13 +103,12 @@ The configuration file (`.config.yaml`) should contain the following fields:
|
||||
|
||||
To activate autocompletion for tags, add the following line to your shell's configuration file (e.g., `.bashrc`, `.zshrc`, or `.profile`):
|
||||
|
||||
```
|
||||
```bash
|
||||
eval "$(register-python-argcomplete cmm)"
|
||||
```
|
||||
|
||||
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `chatmastermind` script.
|
||||
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the terms of the WTFPL License.
|
||||
|
||||
This project is licensed under the terms of the WTFPL License.
|
||||
24
chatmastermind/api_client.py
Normal file
24
chatmastermind/api_client.py
Normal file
@ -0,0 +1,24 @@
|
||||
import openai
|
||||
|
||||
|
||||
def openai_api_key(api_key: str) -> None:
|
||||
openai.api_key = api_key
|
||||
|
||||
|
||||
def ai(chat: list[dict[str, str]],
|
||||
config: dict,
|
||||
number: int
|
||||
) -> tuple[list[str], dict[str, int]]:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=config['openai']['model'],
|
||||
messages=chat,
|
||||
temperature=config['openai']['temperature'],
|
||||
max_tokens=config['openai']['max_tokens'],
|
||||
top_p=config['openai']['top_p'],
|
||||
n=number,
|
||||
frequency_penalty=config['openai']['frequency_penalty'],
|
||||
presence_penalty=config['openai']['presence_penalty'])
|
||||
result = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
result.append(choice['message']['content'].strip())
|
||||
return result, dict(response['usage']) # type: ignore
|
||||
@ -3,19 +3,12 @@
|
||||
# vim: set fileencoding=utf-8 :
|
||||
|
||||
import yaml
|
||||
import io
|
||||
import sys
|
||||
import shutil
|
||||
import openai
|
||||
import pathlib
|
||||
import argcomplete
|
||||
import argparse
|
||||
from pprint import PrettyPrinter
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
terminal_size = shutil.get_terminal_size()
|
||||
terminal_width = terminal_size.columns
|
||||
pp = PrettyPrinter(width=terminal_width).pprint
|
||||
from .utils import terminal_width, pp, tags_completer, process_tags, display_chat
|
||||
from .storage import save_answers, create_chat
|
||||
from .api_client import ai, openai_api_key
|
||||
|
||||
|
||||
def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
||||
@ -24,143 +17,56 @@ def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
||||
pp(data)
|
||||
|
||||
|
||||
def process_tags(config: dict, tags: list, extags: list) -> None:
|
||||
print(f"Tags: {', '.join(tags)}")
|
||||
if len(extags) > 0:
|
||||
print(f"Excluding tags: {', '.join(extags)}")
|
||||
print()
|
||||
|
||||
|
||||
def append_message(chat: List[Dict[str, str]],
|
||||
role: str,
|
||||
content: str
|
||||
) -> None:
|
||||
chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||
|
||||
|
||||
def message_to_chat(message: Dict[str, str],
|
||||
chat: List[Dict[str, str]]
|
||||
) -> None:
|
||||
append_message(chat, 'user', message['question'])
|
||||
append_message(chat, 'assistant', message['answer'])
|
||||
|
||||
|
||||
def create_chat(question: Optional[str],
|
||||
tags: Optional[List[str]],
|
||||
extags: Optional[List[str]],
|
||||
config: Dict[str, Any]
|
||||
) -> List[Dict[str, str]]:
|
||||
chat = []
|
||||
append_message(chat, 'system', config['system'].strip())
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data_tags = set(data.get('tags', []))
|
||||
tags_match = \
|
||||
not tags or data_tags.intersection(tags)
|
||||
extags_do_not_match = \
|
||||
not extags or not data_tags.intersection(extags)
|
||||
if tags_match and extags_do_not_match:
|
||||
message_to_chat(data, chat)
|
||||
if question:
|
||||
append_message(chat, 'user', question)
|
||||
return chat
|
||||
|
||||
|
||||
def ai(chat: list[dict[str, str]],
|
||||
config: dict,
|
||||
number: int
|
||||
) -> tuple[list[str], dict[str, int]]:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=config['openai']['model'],
|
||||
messages=chat,
|
||||
temperature=config['openai']['temperature'],
|
||||
max_tokens=config['openai']['max_tokens'],
|
||||
top_p=config['openai']['top_p'],
|
||||
n=number,
|
||||
frequency_penalty=config['openai']['frequency_penalty'],
|
||||
presence_penalty=config['openai']['presence_penalty'])
|
||||
result = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
result.append(choice['message']['content'].strip())
|
||||
return result, dict(response['usage']) # type: ignore
|
||||
|
||||
|
||||
def process_and_display_chat(args: argparse.Namespace,
|
||||
config: dict,
|
||||
dump: bool = False
|
||||
) -> tuple[list[dict[str, str]], list[str]]:
|
||||
) -> tuple[list[dict[str, str]], str, list[str]]:
|
||||
tags = args.tags or []
|
||||
extags = args.extags or []
|
||||
process_tags(config, tags, extags)
|
||||
chat = create_chat(args.question, tags, extags, config)
|
||||
|
||||
question_parts = []
|
||||
question_list = args.question if args.question is not None else []
|
||||
source_list = args.source if args.source is not None else []
|
||||
|
||||
for question, source in zip(question_list, source_list):
|
||||
with open(source) as r:
|
||||
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
|
||||
|
||||
if len(question_list) > len(source_list):
|
||||
for question in question_list[len(source_list):]:
|
||||
question_parts.append(question)
|
||||
else:
|
||||
for source in source_list[len(question_list):]:
|
||||
with open(source) as r:
|
||||
question_parts.append(f"```\n{r.read().strip()}\n```")
|
||||
|
||||
question = '\n\n'.join(question_parts)
|
||||
|
||||
chat = create_chat(question, tags, extags, config)
|
||||
display_chat(chat, dump)
|
||||
return chat, tags
|
||||
|
||||
|
||||
def display_chat(chat, dump=False) -> None:
|
||||
if dump:
|
||||
pp(chat)
|
||||
return
|
||||
for message in chat:
|
||||
if message['role'] == 'user':
|
||||
print('-' * terminal_width)
|
||||
if len(message['content']) > terminal_width-len(message['role'])-2:
|
||||
print(f"{message['role'].upper()}:")
|
||||
print(message['content'])
|
||||
else:
|
||||
print(f"{message['role'].upper()}: {message['content']}")
|
||||
return chat, question, tags
|
||||
|
||||
|
||||
def handle_question(args: argparse.Namespace,
|
||||
config: dict,
|
||||
dump: bool = False
|
||||
) -> None:
|
||||
chat, tags = process_and_display_chat(args, config, dump)
|
||||
chat, question, tags = process_and_display_chat(args, config, dump)
|
||||
otags = args.output_tags or []
|
||||
answers, usage = ai(chat, config, args.number)
|
||||
save_answers(args.question, answers, tags, otags)
|
||||
print("-" * terminal_width)
|
||||
save_answers(question, answers, tags, otags)
|
||||
print("-" * terminal_width())
|
||||
print(f"Usage: {usage}")
|
||||
|
||||
|
||||
def save_answers(question: str,
|
||||
answers: list[str],
|
||||
tags: list[str],
|
||||
otags: Optional[list[str]]
|
||||
) -> None:
|
||||
wtags = otags or tags
|
||||
for num, answer in enumerate(answers, start=1):
|
||||
title = f'-- ANSWER {num} '
|
||||
title_end = '-' * (terminal_width - len(title))
|
||||
print(f'{title}{title_end}')
|
||||
print(answer)
|
||||
with open(f"{num:02d}.yaml", "w") as fd:
|
||||
with io.StringIO() as f:
|
||||
yaml.dump({'question': question},
|
||||
f,
|
||||
default_style="|",
|
||||
default_flow_style=False)
|
||||
fd.write(f.getvalue().replace('"question":', "question:", 1))
|
||||
with io.StringIO() as f:
|
||||
yaml.dump({'answer': answer},
|
||||
f,
|
||||
default_style="|",
|
||||
default_flow_style=False)
|
||||
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
|
||||
yaml.dump({'tags': wtags},
|
||||
fd,
|
||||
default_flow_style=False)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
default_config = '.config.yaml'
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument('-p', '--print', help='YAML file to print')
|
||||
group.add_argument('-q', '--question', help='Question to ask')
|
||||
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
|
||||
group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
||||
group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true')
|
||||
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
|
||||
@ -168,6 +74,7 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
||||
parser.add_argument('-M', '--model', help='Model to use')
|
||||
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3)
|
||||
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
|
||||
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
|
||||
tags_arg.completer = tags_completer # type: ignore
|
||||
extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS')
|
||||
@ -185,7 +92,7 @@ def main() -> int:
|
||||
with open(args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
openai.api_key = config['openai']['api_key']
|
||||
openai_api_key(config['openai']['api_key'])
|
||||
|
||||
if args.max_tokens:
|
||||
config['openai']['max_tokens'] = args.max_tokens
|
||||
@ -208,22 +115,5 @@ def main() -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def tags_completer(prefix, parsed_args, **kwargs):
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for tag in data.get('tags', []):
|
||||
if prefix and len(prefix) > 0:
|
||||
if tag.startswith(prefix):
|
||||
result.append(tag)
|
||||
else:
|
||||
result.append(tag)
|
||||
return list(set(result))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
|
||||
57
chatmastermind/storage.py
Normal file
57
chatmastermind/storage.py
Normal file
@ -0,0 +1,57 @@
|
||||
import yaml
|
||||
import io
|
||||
import pathlib
|
||||
from .utils import terminal_width, append_message, message_to_chat
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
|
||||
def save_answers(question: str,
|
||||
answers: list[str],
|
||||
tags: list[str],
|
||||
otags: Optional[list[str]]
|
||||
) -> None:
|
||||
wtags = otags or tags
|
||||
for num, answer in enumerate(answers, start=1):
|
||||
title = f'-- ANSWER {num} '
|
||||
title_end = '-' * (terminal_width() - len(title))
|
||||
print(f'{title}{title_end}')
|
||||
print(answer)
|
||||
with open(f"{num:02d}.yaml", "w") as fd:
|
||||
with io.StringIO() as f:
|
||||
yaml.dump({'question': question},
|
||||
f,
|
||||
default_style="|",
|
||||
default_flow_style=False)
|
||||
fd.write(f.getvalue().replace('"question":', "question:", 1))
|
||||
with io.StringIO() as f:
|
||||
yaml.dump({'answer': answer},
|
||||
f,
|
||||
default_style="|",
|
||||
default_flow_style=False)
|
||||
fd.write(f.getvalue().replace('"answer":', "answer:", 1))
|
||||
yaml.dump({'tags': wtags},
|
||||
fd,
|
||||
default_flow_style=False)
|
||||
|
||||
|
||||
def create_chat(question: Optional[str],
|
||||
tags: Optional[List[str]],
|
||||
extags: Optional[List[str]],
|
||||
config: Dict[str, Any]
|
||||
) -> List[Dict[str, str]]:
|
||||
chat = []
|
||||
append_message(chat, 'system', config['system'].strip())
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data_tags = set(data.get('tags', []))
|
||||
tags_match = \
|
||||
not tags or data_tags.intersection(tags)
|
||||
extags_do_not_match = \
|
||||
not extags or not data_tags.intersection(extags)
|
||||
if tags_match and extags_do_not_match:
|
||||
message_to_chat(data, chat)
|
||||
if question:
|
||||
append_message(chat, 'user', question)
|
||||
return chat
|
||||
65
chatmastermind/utils.py
Normal file
65
chatmastermind/utils.py
Normal file
@ -0,0 +1,65 @@
|
||||
import shutil
|
||||
import yaml
|
||||
import pathlib
|
||||
from pprint import PrettyPrinter
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def terminal_width() -> int:
|
||||
return shutil.get_terminal_size().columns
|
||||
|
||||
|
||||
def pp(*args, **kwargs) -> None:
|
||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def process_tags(config: dict, tags: list, extags: list) -> None:
|
||||
print(f"Tags: {', '.join(tags)}")
|
||||
if len(extags) > 0:
|
||||
print(f"Excluding tags: {', '.join(extags)}")
|
||||
print()
|
||||
|
||||
|
||||
def append_message(chat: List[Dict[str, str]],
|
||||
role: str,
|
||||
content: str
|
||||
) -> None:
|
||||
chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||
|
||||
|
||||
def message_to_chat(message: Dict[str, str],
|
||||
chat: List[Dict[str, str]]
|
||||
) -> None:
|
||||
append_message(chat, 'user', message['question'])
|
||||
append_message(chat, 'assistant', message['answer'])
|
||||
|
||||
|
||||
def display_chat(chat, dump=False) -> None:
|
||||
if dump:
|
||||
pp(chat)
|
||||
return
|
||||
for message in chat:
|
||||
if message['role'] == 'user':
|
||||
print('-' * (terminal_width()))
|
||||
if len(message['content']) > terminal_width() - len(message['role']) - 2:
|
||||
print(f"{message['role'].upper()}:")
|
||||
print(message['content'])
|
||||
else:
|
||||
print(f"{message['role'].upper()}: {message['content']}")
|
||||
|
||||
|
||||
def tags_completer(prefix, parsed_args, **kwargs):
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for tag in data.get('tags', []):
|
||||
if prefix and len(prefix) > 0:
|
||||
if tag.startswith(prefix):
|
||||
result.append(tag)
|
||||
else:
|
||||
result.append(tag)
|
||||
return list(set(result))
|
||||
@ -3,10 +3,12 @@ import io
|
||||
import os
|
||||
import yaml
|
||||
import argparse
|
||||
import chatmastermind.main
|
||||
from chatmastermind.main import create_chat, ai, handle_question, save_answers
|
||||
from chatmastermind.utils import terminal_width
|
||||
from chatmastermind.main import create_parser, handle_question
|
||||
from chatmastermind.api_client import ai
|
||||
from chatmastermind.storage import create_chat, save_answers
|
||||
from unittest import mock
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
|
||||
|
||||
class TestCreateChat(unittest.TestCase):
|
||||
@ -86,11 +88,13 @@ class TestCreateChat(unittest.TestCase):
|
||||
class TestHandleQuestion(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.question = "test question"
|
||||
self.args = argparse.Namespace(
|
||||
tags=['tag1'],
|
||||
extags=['extag1'],
|
||||
output_tags=None,
|
||||
question='test question',
|
||||
question=[self.question],
|
||||
source=None,
|
||||
number=3
|
||||
)
|
||||
self.config = {
|
||||
@ -100,20 +104,19 @@ class TestHandleQuestion(unittest.TestCase):
|
||||
|
||||
@patch("chatmastermind.main.create_chat", return_value="test_chat")
|
||||
@patch("chatmastermind.main.process_tags")
|
||||
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"],
|
||||
"test_usage"))
|
||||
@patch("chatmastermind.main.pp")
|
||||
@patch("chatmastermind.main.print")
|
||||
@patch("chatmastermind.main.yaml.dump")
|
||||
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
|
||||
@patch("chatmastermind.utils.pp")
|
||||
@patch("builtins.print")
|
||||
@patch("chatmastermind.storage.yaml.dump")
|
||||
def test_handle_question(self, _, mock_print, mock_pp, mock_ai,
|
||||
mock_process_tags, mock_create_chat):
|
||||
open_mock = MagicMock()
|
||||
with patch("chatmastermind.main.open", open_mock):
|
||||
with patch("chatmastermind.storage.open", open_mock):
|
||||
handle_question(self.args, self.config, True)
|
||||
mock_process_tags.assert_called_once_with(self.config,
|
||||
self.args.tags,
|
||||
self.args.extags)
|
||||
mock_create_chat.assert_called_once_with(self.args.question,
|
||||
mock_create_chat.assert_called_once_with(self.question,
|
||||
self.args.tags,
|
||||
self.args.extags,
|
||||
self.config)
|
||||
@ -124,15 +127,14 @@ class TestHandleQuestion(unittest.TestCase):
|
||||
expected_calls = []
|
||||
for num, answer in enumerate(mock_ai.return_value[0], start=1):
|
||||
title = f'-- ANSWER {num} '
|
||||
title_end = '-' * (chatmastermind.main.terminal_width - len(title))
|
||||
title_end = '-' * (terminal_width() - len(title))
|
||||
expected_calls.append(((f'{title}{title_end}',),))
|
||||
expected_calls.append(((answer,),))
|
||||
expected_calls.append((("-" * chatmastermind.main.terminal_width,),))
|
||||
expected_calls.append((("-" * terminal_width(),),))
|
||||
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
|
||||
open_mock.assert_has_calls([
|
||||
mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)
|
||||
] + [mock.call().__enter__(),
|
||||
mock.call().__exit__(None, None, None)] * 3,
|
||||
open_mock.assert_has_calls(
|
||||
[mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [
|
||||
mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3,
|
||||
any_order=True)
|
||||
self.assertEqual(mock_print.call_args_list, expected_calls)
|
||||
|
||||
@ -152,9 +154,9 @@ class TestSaveAnswers(unittest.TestCase):
|
||||
|
||||
def test_save_answers(self):
|
||||
try:
|
||||
self.assert_stdout(f"-- ANSWER 1 {'-'*(chatmastermind.main.terminal_width-12)}\n"
|
||||
self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n"
|
||||
"AI is Artificial Intelligence\n"
|
||||
f"-- ANSWER 2 {'-'*(chatmastermind.main.terminal_width-12)}\n"
|
||||
f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n"
|
||||
"AI is a simulation of human intelligence\n")
|
||||
for idx, answer in enumerate(self.answers, start=1):
|
||||
with open(f"{idx:02d}.yaml", "r") as file:
|
||||
@ -198,3 +200,19 @@ class TestAI(unittest.TestCase):
|
||||
expected_result = (['response_text_1', 'response_text_2'],
|
||||
{'tokens': 10})
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
class TestCreateParser(unittest.TestCase):
|
||||
def test_create_parser(self):
|
||||
with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
|
||||
mock_group = Mock()
|
||||
mock_add_mutually_exclusive_group.return_value = mock_group
|
||||
parser = create_parser()
|
||||
self.assertIsInstance(parser, argparse.ArgumentParser)
|
||||
mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
|
||||
mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print')
|
||||
mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
|
||||
mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
|
||||
mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
|
||||
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
||||
self.assertEqual(parser.get_default('number'), 3)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user