Compare commits

...

2 Commits

6 changed files with 339 additions and 279 deletions

View File

@ -17,8 +17,9 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
def __init__(self, name: str, config: OpenAIConfig) -> None: def __init__(self, config: OpenAIConfig) -> None:
self.name = name self.ai_type = config.ai_type
self.name = config.name
self.config = config self.config = config
def request(self, def request(self,
@ -31,8 +32,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
# FIXME: use real 'system' message (store in OpenAIConfig) oai_chat = self.openai_chat(chat, self.config.system, question)
oai_chat = self.openai_chat(chat, "system", question)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,

View File

@ -1,16 +1,26 @@
import yaml import yaml
from typing import Type, TypeVar, Any from pathlib import Path
from dataclasses import dataclass, asdict from typing import Type, TypeVar, Any, Optional
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
class ConfigError(Exception):
pass
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
ai_type: str
name: str name: str
@ -19,13 +29,18 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
api_key: str # all members have default values, so we can easily create
model: str # a default configuration
temperature: float ai_type: str = 'openai'
max_tokens: int name: str = 'openai_1'
top_p: float system: str = 'You are an assistant'
frequency_penalty: float api_key: str = '0123456789'
presence_penalty: float model: str = 'gpt-3.5'
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@ -33,7 +48,9 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
return cls( return cls(
name='OpenAI', ai_type='openai',
name=str(source['name']),
system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
@ -43,36 +60,79 @@ class OpenAIConfig(AIConfig):
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty'])
) )
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given type.
"""
if ai_type.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"AI type '{ai_type}' is not supported")
def create_default_ai_configs() -> dict[str, AIConfig]:
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais}
@dataclass @dataclass
class Config: class Config:
""" """
The configuration file structure. The configuration file structure.
""" """
system: str # all members have default values, so we can easily create
db: str # a default configuration
openai: OpenAIConfig db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
""" """
Create Config from a dict. Create Config from a dict.
""" """
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for name, conf in source['ais'].items():
ai_conf = ai_type_instance(conf['type'], conf)
ais[name] = ai_conf
return cls( return cls(
system=str(source['system']),
db=str(source['db']), db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai']) ais=ais
) )
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI name to the config (for easy internal access)
for name, conf in source['ais'].items():
conf['name'] = name
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, path: str) -> None: def to_file(self, file_path: Path) -> None:
with open(path, 'w') as f: # remove the AI name from the config (for a cleaner format)
yaml.dump(asdict(self), f, sort_keys=False) data = self.as_dict()
for ai_name, ai_conf in data['ais'].items():
del (ai_conf['name'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -1,3 +1,4 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
@ -6,10 +7,9 @@ from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from .test_main import CmmTestCase
class TestChat(CmmTestCase): class TestChat(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@ -131,7 +131,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(CmmTestCase): class TestChatDB(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()

View File

@ -1,236 +1,236 @@
import unittest # import unittest
import io # import io
import pathlib # import pathlib
import argparse # import argparse
from chatmastermind.utils import terminal_width # from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd # from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai # from chatmastermind.api_client import ai
from chatmastermind.configuration import Config # from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data # from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock # from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY # from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase): # class CmmTestCase(unittest.TestCase):
""" # """
Base class for all cmm testcases. # Base class for all cmm testcases.
""" # """
def dummy_config(self, db: str) -> Config: # def dummy_config(self, db: str) -> Config:
""" # """
Creates a dummy configuration. # Creates a dummy configuration.
""" # """
return Config.from_dict( # return Config.from_dict(
{'system': 'dummy_system', # {'system': 'dummy_system',
'db': db, # 'db': db,
'openai': {'api_key': 'dummy_key', # 'openai': {'api_key': 'dummy_key',
'model': 'dummy_model', # 'model': 'dummy_model',
'max_tokens': 4000, # 'max_tokens': 4000,
'temperature': 1.0, # 'temperature': 1.0,
'top_p': 1, # 'top_p': 1,
'frequency_penalty': 0, # 'frequency_penalty': 0,
'presence_penalty': 0}} # 'presence_penalty': 0}}
) # )
#
#
class TestCreateChat(CmmTestCase): # class TestCreateChat(CmmTestCase):
#
def setUp(self) -> None: # def setUp(self) -> None:
self.config = self.dummy_config(db='test_files') # self.config = self.dummy_config(db='test_files')
self.question = "test question" # self.question = "test question"
self.tags = ['test_tag'] # self.tags = ['test_tag']
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] # listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( # open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer', # {'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']})) # 'tags': ['test_tag']}))
#
test_chat = create_chat_hist(self.question, self.tags, None, self.config) # test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
self.assertEqual(len(test_chat), 4) # self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) # {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], # self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'}) # {'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3], # self.assertEqual(test_chat[3],
{'role': 'user', 'content': self.question}) # {'role': 'user', 'content': self.question})
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] # listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( # open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer', # {'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']})) # 'tags': ['other_tag']}))
#
test_chat = create_chat_hist(self.question, self.tags, None, self.config) # test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
self.assertEqual(len(test_chat), 2) # self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question}) # {'role': 'user', 'content': self.question})
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] # listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = ( # open_mock.side_effect = (
io.StringIO(dump_data({'question': 'test_content', # io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer', # 'answer': 'some answer',
'tags': ['test_tag']})), # 'tags': ['test_tag']})),
io.StringIO(dump_data({'question': 'test_content2', # io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2', # 'answer': 'some answer2',
'tags': ['test_tag2']})), # 'tags': ['test_tag2']})),
) # )
#
test_chat = create_chat_hist(self.question, [], None, self.config) # test_chat = create_chat_hist(self.question, [], None, self.config)
#
self.assertEqual(len(test_chat), 6) # self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) # {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], # self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'}) # {'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3], # self.assertEqual(test_chat[3],
{'role': 'user', 'content': 'test_content2'}) # {'role': 'user', 'content': 'test_content2'})
self.assertEqual(test_chat[4], # self.assertEqual(test_chat[4],
{'role': 'assistant', 'content': 'some answer2'}) # {'role': 'assistant', 'content': 'some answer2'})
#
#
class TestHandleQuestion(CmmTestCase): # class TestHandleQuestion(CmmTestCase):
#
def setUp(self) -> None: # def setUp(self) -> None:
self.question = "test question" # self.question = "test question"
self.args = argparse.Namespace( # self.args = argparse.Namespace(
or_tags=['tag1'], # or_tags=['tag1'],
and_tags=None, # and_tags=None,
exclude_tags=['xtag1'], # exclude_tags=['xtag1'],
output_tags=None, # output_tags=None,
question=[self.question], # question=[self.question],
source=None, # source=None,
source_code_only=False, # source_code_only=False,
num_answers=3, # num_answers=3,
max_tokens=None, # max_tokens=None,
temperature=None, # temperature=None,
model=None, # model=None,
match_all_tags=False, # match_all_tags=False,
with_tags=False, # with_tags=False,
with_file=False, # with_file=False,
) # )
self.config = self.dummy_config(db='test_files') # self.config = self.dummy_config(db='test_files')
#
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat") # @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args") # @patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.main.print_chat_hist") # @patch("chatmastermind.main.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) # @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp") # @patch("chatmastermind.utils.pp")
@patch("builtins.print") # @patch("builtins.print")
def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, # def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, # mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist: MagicMock) -> None: # mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock() # open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): # with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) # ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.or_tags, # mock_print_tag_args.assert_called_once_with(self.args.or_tags,
self.args.exclude_tags, # self.args.exclude_tags,
[]) # [])
mock_create_chat_hist.assert_called_once_with(self.question, # mock_create_chat_hist.assert_called_once_with(self.question,
self.args.or_tags, # self.args.or_tags,
self.args.exclude_tags, # self.args.exclude_tags,
self.config, # self.config,
match_all_tags=False, # match_all_tags=False,
with_tags=False, # with_tags=False,
with_file=False) # with_file=False)
mock_print_chat_hist.assert_called_once_with('test_chat', # mock_print_chat_hist.assert_called_once_with('test_chat',
False, # False,
self.args.source_code_only) # self.args.source_code_only)
mock_ai.assert_called_with("test_chat", # mock_ai.assert_called_with("test_chat",
self.config, # self.config,
self.args.num_answers) # self.args.num_answers)
expected_calls = [] # expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1): # for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} ' # title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title)) # title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),)) # expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),)) # expected_calls.append(((answer,),))
expected_calls.append((("-" * terminal_width(),),)) # expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) # expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls) # self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) # open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True) # open_mock.assert_has_calls(open_expected_calls, any_order=True)
#
#
class TestSaveAnswers(CmmTestCase): # class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open') # @mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print') # @mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: # def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?" # question = "Test question?"
answers = ["Answer 1", "Answer 2"] # answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"] # tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"] # otags = ["otag1", "otag2"]
config = self.dummy_config(db='test_db') # config = self.dummy_config(db='test_db')
#
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ # with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \ # mock.patch('chatmastermind.storage.yaml.dump'), \
mock.patch('io.StringIO') as stringio_mock: # mock.patch('io.StringIO') as stringio_mock:
stringio_instance = stringio_mock.return_value # stringio_instance = stringio_mock.return_value
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] # stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
save_answers(question, answers, tags, otags, config) # save_answers(question, answers, tags, otags, config)
#
open_calls = [ # open_calls = [
mock.call(pathlib.Path('test_db/.next'), 'r'), # mock.call(pathlib.Path('test_db/.next'), 'r'),
mock.call(pathlib.Path('test_db/.next'), 'w'), # mock.call(pathlib.Path('test_db/.next'), 'w'),
] # ]
open_mock.assert_has_calls(open_calls, any_order=True) # open_mock.assert_has_calls(open_calls, any_order=True)
#
#
class TestAI(CmmTestCase): # class TestAI(CmmTestCase):
#
@patch("openai.ChatCompletion.create") # @patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock) -> None: # def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = { # mock_create.return_value = {
'choices': [ # 'choices': [
{'message': {'content': 'response_text_1'}}, # {'message': {'content': 'response_text_1'}},
{'message': {'content': 'response_text_2'}} # {'message': {'content': 'response_text_2'}}
], # ],
'usage': {'tokens': 10} # 'usage': {'tokens': 10}
} # }
#
chat = [{"role": "system", "content": "hello ai"}] # chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy') # config = self.dummy_config(db='dummy')
config.openai.model = "text-davinci-002" # config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150 # config.openai.max_tokens = 150
config.openai.temperature = 0.5 # config.openai.temperature = 0.5
#
result = ai(chat, config, 2) # result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'], # expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10}) # {'tokens': 10})
self.assertEqual(result, expected_result) # self.assertEqual(result, expected_result)
#
#
class TestCreateParser(CmmTestCase): # class TestCreateParser(CmmTestCase):
def test_create_parser(self) -> None: # def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: # with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock() # mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser # mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser() # parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser) # self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) # mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) # self.assertTrue('.config.yaml' in parser.get_default('config'))

View File

@ -1,12 +1,12 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
from typing import cast from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(CmmTestCase): class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None: def test_source_code_with_include_delims(self) -> None:
text = """ text = """
Some text before the code block Some text before the code block
@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase):
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase): class QuestionTestCase(unittest.TestCase):
def test_question_with_header(self) -> None: def test_question_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?") Question(f"{Question.txt_header}\nWhat is your name?")
@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase):
self.assertEqual(question, "What is your favorite color?") self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase): class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") Answer(f"{Answer.txt_header}\nno")
@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase):
self.assertEqual(answer, "No") self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(CmmTestCase): class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -160,7 +160,7 @@ This is a question.
self.message_complete.file_path = self.file_path self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(CmmTestCase): class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(CmmTestCase): class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -388,7 +388,7 @@ This is a question.
self.assertIsNone(message) self.assertIsNone(message)
class MessageFromFileYamlTestCase(CmmTestCase): class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertIsNone(message) self.assertIsNone(message)
class TagsFromFileTestCase(CmmTestCase): class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
@ -663,7 +663,7 @@ This is an answer.
self.assertSetEqual(tags, set()) self.assertSetEqual(tags, set())
class TagsFromDirTestCase(CmmTestCase): class TagsFromDirTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory()
@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase):
self.assertSetEqual(all_tags, set()) self.assertSetEqual(all_tags, set())
class MessageIDTestCase(CmmTestCase): class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase):
self.message_no_file_path.msg_id() self.message_no_file_path.msg_id()
class MessageHashTestCase(CmmTestCase): class MessageHashTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase):
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(CmmTestCase): class MessageTagsStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase):
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(CmmTestCase): class MessageFilterTagsTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase):
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(CmmTestCase): class MessageInTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase):
self.assertFalse(message_in(self.message1, [self.message2])) self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(CmmTestCase): class MessageRenameTagsTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase):
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(CmmTestCase): class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),

View File

@ -1,8 +1,8 @@
from .test_main import CmmTestCase import unittest
from chatmastermind.tags import Tag, TagLine, TagError from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(CmmTestCase): class TestTag(unittest.TestCase):
def test_valid_tag(self) -> None: def test_valid_tag(self) -> None:
tag = Tag('mytag') tag = Tag('mytag')
self.assertEqual(tag, 'mytag') self.assertEqual(tag, 'mytag')
@ -18,7 +18,7 @@ class TestTag(CmmTestCase):
self.assertEqual(Tag.alternative_separators, [',']) self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(CmmTestCase): class TestTagLine(unittest.TestCase):
def test_valid_tagline(self) -> None: def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2') tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2')