Compare commits
7 Commits
3b605abd14
...
3735e080c3
| Author | SHA1 | Date | |
|---|---|---|---|
| 3735e080c3 | |||
| 8e593a5fe2 | |||
| 3be42c0f18 | |||
| 1e92749bf3 | |||
| ab4d13de32 | |||
| ad81776b62 | |||
| e4f8520fda |
64
chatmastermind/ai.py
Normal file
64
chatmastermind/ai.py
Normal file
@ -0,0 +1,64 @@
|
||||
from dataclasses import dataclass
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol, Optional, Union
|
||||
from .configuration import AIConfig
|
||||
from .message import Message
|
||||
from .chat import Chat
|
||||
|
||||
|
||||
class AIError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tokens:
|
||||
prompt: int = 0
|
||||
completion: int = 0
|
||||
total: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIResponse:
|
||||
"""
|
||||
The response to an AI request. Consists of one or more messages
|
||||
(each containing the question and a single answer) and the nr.
|
||||
of used tokens.
|
||||
"""
|
||||
messages: list[Message]
|
||||
tokens: Optional[Tokens] = None
|
||||
|
||||
|
||||
class AI(Protocol):
|
||||
"""
|
||||
The base class for AI clients.
|
||||
"""
|
||||
|
||||
name: str
|
||||
config: AIConfig
|
||||
|
||||
@abstractmethod
|
||||
def request(self,
|
||||
question: Message,
|
||||
context: Chat,
|
||||
num_answers: int = 1) -> AIResponse:
|
||||
"""
|
||||
Make an AI request, asking the given question with the given
|
||||
context (i. e. chat history). The nr. of requested answers
|
||||
corresponds to the nr. of messages in the 'AIResponse'.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def models(self) -> list[str]:
|
||||
"""
|
||||
Return all models supported by this AI.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||
"""
|
||||
Computes the nr. of AI language tokens for the given message
|
||||
or chat. Note that the computation may not be 100% accurate
|
||||
and is not implemented for all AIs.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
90
chatmastermind/ais/openai.py
Normal file
90
chatmastermind/ais/openai.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""
|
||||
Implements the OpenAI client classes and functions.
|
||||
"""
|
||||
import openai
|
||||
from typing import Optional
|
||||
from ..tags import Tag
|
||||
from ..message import Message, Answer
|
||||
from ..chat import Chat
|
||||
from ..ai import AI, AIResponse, Tokens
|
||||
from ..config import OpenAIConfig
|
||||
|
||||
ChatType = list[dict[str, str]]
|
||||
|
||||
|
||||
class OpenAI(AI):
|
||||
"""
|
||||
The OpenAI AI client.
|
||||
"""
|
||||
|
||||
config: OpenAIConfig
|
||||
|
||||
def request(self,
|
||||
question: Message,
|
||||
chat: Chat,
|
||||
num_answers: int = 1,
|
||||
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||
"""
|
||||
Make an AI request, asking the given question with the given
|
||||
chat history. The nr. of requested answers corresponds to the
|
||||
nr. of messages in the 'AIResponse'.
|
||||
"""
|
||||
oai_chat = self.openai_chat(chat, self.config.system, question)
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.config.model,
|
||||
messages=oai_chat,
|
||||
temperature=self.config.temperature,
|
||||
max_tokens=self.config.max_tokens,
|
||||
top_p=self.config.top_p,
|
||||
n=num_answers,
|
||||
frequency_penalty=self.config.frequency_penalty,
|
||||
presence_penalty=self.config.presence_penalty)
|
||||
answers: list[Message] = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
answers.append(Message(question=question.question,
|
||||
answer=Answer(choice['message']['content']),
|
||||
tags=otags,
|
||||
ai=self.name,
|
||||
model=self.config.model))
|
||||
return AIResponse(answers, Tokens(response['usage']['prompt'],
|
||||
response['usage']['completion'],
|
||||
response['usage']['total']))
|
||||
|
||||
def models(self) -> list[str]:
|
||||
"""
|
||||
Return all models supported by this AI.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def print_models(self) -> None:
|
||||
"""
|
||||
Print all models supported by the current AI.
|
||||
"""
|
||||
not_ready = []
|
||||
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
|
||||
if engine['ready']:
|
||||
print(engine['id'])
|
||||
else:
|
||||
not_ready.append(engine['id'])
|
||||
if len(not_ready) > 0:
|
||||
print('\nNot ready: ' + ', '.join(not_ready))
|
||||
|
||||
def openai_chat(self, chat: Chat, system: str,
|
||||
question: Optional[Message] = None) -> ChatType:
|
||||
"""
|
||||
Create a chat history with system message in OpenAI format.
|
||||
Optionally append a new question.
|
||||
"""
|
||||
oai_chat: ChatType = []
|
||||
|
||||
def append(role: str, content: str) -> None:
|
||||
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||
|
||||
append('system', system)
|
||||
for message in chat.messages:
|
||||
if message.answer:
|
||||
append('user', message.question)
|
||||
append('assistant', message.answer)
|
||||
if question:
|
||||
append('user', question.question)
|
||||
return oai_chat
|
||||
@ -7,7 +7,7 @@ from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
|
||||
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
|
||||
from .message import Message, MessageFilter, MessageError, message_in
|
||||
from .tags import Tag
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
@ -170,18 +170,10 @@ class Chat:
|
||||
output: list[str] = []
|
||||
for message in self.messages:
|
||||
if source_code_only:
|
||||
output.extend(source_code(message.question, include_delims=True))
|
||||
output.append(message.to_str(source_code_only=True))
|
||||
continue
|
||||
output.append('-' * terminal_width())
|
||||
if with_tags:
|
||||
output.append(message.tags_str())
|
||||
if with_files:
|
||||
output.append('FILE: ' + str(message.file_path))
|
||||
output.append(Question.txt_header)
|
||||
output.append(message.question)
|
||||
if message.answer:
|
||||
output.append(Answer.txt_header)
|
||||
output.append(message.answer)
|
||||
output.append(message.to_str(with_tags, with_files))
|
||||
output.append('\n' + ('-' * terminal_width()) + '\n')
|
||||
if paged:
|
||||
print_paged('\n'.join(output))
|
||||
else:
|
||||
|
||||
@ -2,15 +2,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
|
||||
import yaml
|
||||
import sys
|
||||
import argcomplete
|
||||
import argparse
|
||||
import pathlib
|
||||
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
|
||||
from pathlib import Path
|
||||
from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType
|
||||
from .storage import save_answers, create_chat_hist
|
||||
from .api_client import ai, openai_api_key, print_models
|
||||
from .configuration import Config
|
||||
from .chat import ChatDB
|
||||
from .message import Message, MessageFilter, MessageError
|
||||
from itertools import zip_longest
|
||||
from typing import Any
|
||||
|
||||
@ -18,9 +19,8 @@ default_config = '.config.yaml'
|
||||
|
||||
|
||||
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return get_tags_unique(config, prefix)
|
||||
config = Config.from_file(parsed_args.config)
|
||||
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
|
||||
|
||||
|
||||
def create_question_with_hist(args: argparse.Namespace,
|
||||
@ -31,11 +31,11 @@ def create_question_with_hist(args: argparse.Namespace,
|
||||
by the specified tags.
|
||||
"""
|
||||
tags = args.tags or []
|
||||
extags = args.extags or []
|
||||
etags = args.etags or []
|
||||
otags = args.output_tags or []
|
||||
|
||||
if not args.only_source_code:
|
||||
print_tag_args(tags, extags, otags)
|
||||
if not args.source_code_only:
|
||||
print_tag_args(tags, etags, otags)
|
||||
|
||||
question_parts = []
|
||||
question_list = args.question if args.question is not None else []
|
||||
@ -52,17 +52,24 @@ def create_question_with_hist(args: argparse.Namespace,
|
||||
question_parts.append(f"```\n{r.read().strip()}\n```")
|
||||
|
||||
full_question = '\n\n'.join(question_parts)
|
||||
chat = create_chat_hist(full_question, tags, extags, config,
|
||||
args.match_all_tags, False, False)
|
||||
chat = create_chat_hist(full_question, tags, etags, config,
|
||||
match_all_tags=True if args.atags else False, # FIXME
|
||||
with_tags=False,
|
||||
with_file=False)
|
||||
return chat, full_question, tags
|
||||
|
||||
|
||||
def tag_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'tag' command.
|
||||
Handler for the 'tags' command.
|
||||
"""
|
||||
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||
db_path=Path(config.db))
|
||||
if args.list:
|
||||
print_tags_frequency(get_tags(config, None))
|
||||
tags_freq = chat.tags_frequency(args.prefix, args.contain)
|
||||
for tag, freq in tags_freq.items():
|
||||
print(f"- {tag}: {freq}")
|
||||
# TODO: add renaming
|
||||
|
||||
|
||||
def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
@ -89,7 +96,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
if args.model:
|
||||
config.openai.model = args.model
|
||||
chat, question, tags = create_question_with_hist(args, config)
|
||||
print_chat_hist(chat, False, args.only_source_code)
|
||||
print_chat_hist(chat, False, args.source_code_only)
|
||||
otags = args.output_tags or []
|
||||
answers, usage = ai(chat, config, args.number)
|
||||
save_answers(question, answers, tags, otags, config)
|
||||
@ -101,33 +108,32 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'hist' command.
|
||||
"""
|
||||
tags = args.tags or []
|
||||
extags = args.extags or []
|
||||
|
||||
chat = create_chat_hist(None, tags, extags, config,
|
||||
args.match_all_tags,
|
||||
mfilter = MessageFilter(tags_or=args.tags,
|
||||
tags_and=args.atags,
|
||||
tags_not=args.etags,
|
||||
question_contains=args.question,
|
||||
answer_contains=args.answer)
|
||||
chat = ChatDB.from_dir(Path('.'),
|
||||
Path(config.db),
|
||||
mfilter=mfilter)
|
||||
chat.print(args.source_code_only,
|
||||
args.with_tags,
|
||||
args.with_files)
|
||||
print_chat_hist(chat, args.dump, args.only_source_code)
|
||||
|
||||
|
||||
def print_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'print' command.
|
||||
"""
|
||||
fname = pathlib.Path(args.file)
|
||||
if fname.suffix == '.yaml':
|
||||
with open(args.file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
elif fname.suffix == '.txt':
|
||||
data = read_file(fname)
|
||||
else:
|
||||
print(f"Unknown file type: {args.file}")
|
||||
fname = Path(args.file)
|
||||
try:
|
||||
message = Message.from_file(fname)
|
||||
if message:
|
||||
print(message.to_str(source_code_only=args.source_code_only))
|
||||
except MessageError:
|
||||
print(f"File is not a valid message: {args.file}")
|
||||
sys.exit(1)
|
||||
if args.only_source_code:
|
||||
display_source_code(data['answer'])
|
||||
else:
|
||||
print(dump_data(data).strip())
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
@ -144,18 +150,17 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
# a parent parser for all commands that support tag selection
|
||||
tag_parser = argparse.ArgumentParser(add_help=False)
|
||||
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
|
||||
help='List of tag names', metavar='TAGS')
|
||||
help='List of tag names (one must match)', metavar='TAGS')
|
||||
tag_arg.completer = tags_completer # type: ignore
|
||||
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
|
||||
help='List of tag names to exclude', metavar='EXTAGS')
|
||||
extag_arg.completer = tags_completer # type: ignore
|
||||
atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+',
|
||||
help='List of tag names (all must match)', metavar='TAGS')
|
||||
atag_arg.completer = tags_completer # type: ignore
|
||||
etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+',
|
||||
help='List of tag names to exclude', metavar='ETAGS')
|
||||
etag_arg.completer = tags_completer # type: ignore
|
||||
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
||||
help='List of output tag names, default is input', metavar='OTAGS')
|
||||
otag_arg.completer = tags_completer # type: ignore
|
||||
tag_parser.add_argument('-a', '--match-all-tags',
|
||||
help="All given tags must match when selecting chat history entries",
|
||||
action='store_true')
|
||||
# enable autocompletion for tags
|
||||
|
||||
# 'ask' command parser
|
||||
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
|
||||
@ -170,7 +175,7 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
|
||||
default=1)
|
||||
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
|
||||
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
|
||||
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
|
||||
action='store_true')
|
||||
|
||||
# 'hist' command parser
|
||||
@ -178,23 +183,25 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
help="Print chat history.",
|
||||
aliases=['h'])
|
||||
hist_cmd_parser.set_defaults(func=hist_cmd)
|
||||
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
|
||||
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
|
||||
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
|
||||
|
||||
# 'tag' command parser
|
||||
tag_cmd_parser = cmdparser.add_parser('tag',
|
||||
# 'tags' command parser
|
||||
tags_cmd_parser = cmdparser.add_parser('tags',
|
||||
help="Manage tags.",
|
||||
aliases=['t'])
|
||||
tag_cmd_parser.set_defaults(func=tag_cmd)
|
||||
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||
tags_cmd_parser.set_defaults(func=tags_cmd)
|
||||
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||
action='store_true')
|
||||
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
|
||||
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
|
||||
|
||||
# 'config' command parser
|
||||
config_cmd_parser = cmdparser.add_parser('config',
|
||||
@ -210,11 +217,11 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
|
||||
# 'print' command parser
|
||||
print_cmd_parser = cmdparser.add_parser('print',
|
||||
help="Print files.",
|
||||
help="Print message files.",
|
||||
aliases=['p'])
|
||||
print_cmd_parser.set_defaults(func=print_cmd)
|
||||
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
|
||||
print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
|
||||
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)',
|
||||
action='store_true')
|
||||
|
||||
argcomplete.autocomplete(parser)
|
||||
|
||||
@ -392,6 +392,30 @@ class Message():
|
||||
data[cls.file_yaml_key] = file_path
|
||||
return cls.from_dict(data)
|
||||
|
||||
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
|
||||
"""
|
||||
Return the current Message as a string.
|
||||
"""
|
||||
output: list[str] = []
|
||||
if source_code_only:
|
||||
# use the source code from answer only
|
||||
if self.answer:
|
||||
output.extend(self.answer.source_code(include_delims=True))
|
||||
return '\n'.join(output) if len(output) > 1 else ''
|
||||
if with_tags:
|
||||
output.append(self.tags_str())
|
||||
if with_file:
|
||||
output.append('FILE: ' + str(self.file_path))
|
||||
output.append(Question.txt_header)
|
||||
output.append(self.question)
|
||||
if self.answer:
|
||||
output.append(Answer.txt_header)
|
||||
output.append(self.answer)
|
||||
return '\n'.join(output)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_str(False, False, False)
|
||||
|
||||
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
|
||||
"""
|
||||
Write a Message to the given file. Type is determined based on the suffix.
|
||||
|
||||
@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
|
||||
print(message['content'])
|
||||
else:
|
||||
print(f"{message['role'].upper()}: {message['content']}")
|
||||
|
||||
|
||||
def print_tags_frequency(tags: list[str]) -> None:
|
||||
for tag in sorted(set(tags)):
|
||||
print(f"- {tag}: {tags.count(tag)}")
|
||||
|
||||
@ -66,16 +66,20 @@ class TestChat(CmmTestCase):
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False)
|
||||
expected_output = f"""{'-'*terminal_width()}
|
||||
{Question.txt_header}
|
||||
expected_output = f"""{Question.txt_header}
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
|
||||
{'-'*terminal_width()}
|
||||
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
Answer 2
|
||||
|
||||
{'-'*terminal_width()}
|
||||
|
||||
"""
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
@ -83,20 +87,24 @@ Answer 2
|
||||
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False, with_tags=True, with_files=True)
|
||||
expected_output = f"""{'-'*terminal_width()}
|
||||
{TagLine.prefix} atag1 btag2
|
||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||
FILE: 0001.txt
|
||||
{Question.txt_header}
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
|
||||
{'-'*terminal_width()}
|
||||
|
||||
{TagLine.prefix} btag2
|
||||
FILE: 0002.txt
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
Answer 2
|
||||
|
||||
{'-'*terminal_width()}
|
||||
|
||||
"""
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
|
||||
@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase):
|
||||
self.question = "test question"
|
||||
self.args = argparse.Namespace(
|
||||
tags=['tag1'],
|
||||
extags=['extag1'],
|
||||
atags=None,
|
||||
etags=['etag1'],
|
||||
output_tags=None,
|
||||
question=[self.question],
|
||||
source=None,
|
||||
only_source_code=False,
|
||||
source_code_only=False,
|
||||
number=3,
|
||||
max_tokens=None,
|
||||
temperature=None,
|
||||
@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase):
|
||||
with patch("chatmastermind.storage.open", open_mock):
|
||||
ask_cmd(self.args, self.config)
|
||||
mock_print_tag_args.assert_called_once_with(self.args.tags,
|
||||
self.args.extags,
|
||||
self.args.etags,
|
||||
[])
|
||||
mock_create_chat_hist.assert_called_once_with(self.question,
|
||||
self.args.tags,
|
||||
self.args.extags,
|
||||
self.args.etags,
|
||||
self.config,
|
||||
False, False, False)
|
||||
match_all_tags=False,
|
||||
with_tags=False,
|
||||
with_file=False)
|
||||
mock_print_chat_hist.assert_called_once_with('test_chat',
|
||||
False,
|
||||
self.args.only_source_code)
|
||||
self.args.source_code_only)
|
||||
mock_ai.assert_called_with("test_chat",
|
||||
self.config,
|
||||
self.args.number)
|
||||
@ -227,7 +230,7 @@ class TestCreateParser(CmmTestCase):
|
||||
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
|
||||
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
|
||||
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
|
||||
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY)
|
||||
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
|
||||
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
|
||||
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
|
||||
self.assertTrue('.config.yaml' in parser.get_default('config'))
|
||||
|
||||
@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase):
|
||||
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
|
||||
self.assertIsNotNone(self.message.tags)
|
||||
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
|
||||
|
||||
|
||||
class MessageToStrTestCase(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_to_str(self) -> None:
|
||||
expected_output = f"""{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer."""
|
||||
self.assertEqual(self.message.to_str(), expected_output)
|
||||
|
||||
def test_to_str_with_tags_and_file(self) -> None:
|
||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||
FILE: /tmp/foo/bla
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer."""
|
||||
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user