Compare commits

..

5 Commits

3 changed files with 222 additions and 47 deletions

View File

@ -222,12 +222,36 @@ class Message():
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@ -418,9 +442,6 @@ class Message():
output.append(self.answer) output.append(self.answer)
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Type is determined based on the suffix.

View File

@ -10,7 +10,18 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase): class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@ -26,24 +37,24 @@ class TestChat(unittest.TestCase):
def test_unique_id(self) -> None: def test_unique_id(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None: def test_unique_content(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
@ -150,7 +161,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase): class TestChatDB(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
@ -569,7 +580,7 @@ class TestChatDB(unittest.TestCase):
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all') result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result) self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@ -595,47 +606,47 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list # add a new message, but only to the internal list
new_message = Message(Question("What?")) new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message] all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message]) chat_db.msg_add([new_message])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result # the nr. of messages on disk did not change -> expect old result
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter # test with MessageFilter
self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1]) [self.message1])
self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2]) [self.message2])
self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[]) [])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message]) [new_message])
def test_msg_move_and_gather(self) -> None: def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache # move first message to the cache
chat_db.cache_move(self.message1) chat_db.cache_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1]) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB # now move first message back to the DB
chat_db.db_move(self.message1) chat_db.db_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)

View File

@ -11,10 +11,21 @@ from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import Chat, ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class TestMessageCreate(unittest.TestCase): class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestMessageCreate(TestQuestionCmdBase):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
@ -201,7 +212,7 @@ It is embedded code
""")) """))
class TestQuestionCmd(unittest.TestCase): class TestQuestionCmd(TestQuestionCmdBase):
def setUp(self) -> None: def setUp(self) -> None:
# create DB and cache # create DB and cache
@ -225,7 +236,8 @@ class TestQuestionCmd(unittest.TestCase):
source_code=None, source_code=None,
create=None, create=None,
repeat=None, repeat=None,
process=None process=None,
overwrite=None
) )
# create a mock AI instance # create a mock AI instance
self.ai = MagicMock(spec=AI) self.ai = MagicMock(spec=AI)
@ -254,7 +266,7 @@ class TestQuestionCmd(unittest.TestCase):
Mock the 'ai.request()' function Mock the 'ai.request()' function
""" """
question.answer = Answer("Answer 0") question.answer = Answer("Answer 0")
question.tags = otags question.tags = set(otags) if otags else None
question.ai = 'FakeAI' question.ai = 'FakeAI'
question.model = 'FakeModel' question.model = 'FakeModel'
answers: list[Message] = [question] answers: list[Message] = [question]
@ -273,7 +285,7 @@ class TestQuestionCmd(unittest.TestCase):
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
""" """
Test single answer with no errors Test single answer with no errors.
""" """
mock_create_ai.return_value = self.ai mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
@ -295,13 +307,13 @@ class TestQuestionCmd(unittest.TestCase):
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assertSequenceEqual(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
""" """
Test single answer with no errors (mocked ChatDB version) Test single answer with no errors (mocked ChatDB version).
""" """
chat = MagicMock(spec=ChatDB) chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
@ -331,3 +343,134 @@ class TestQuestionCmd(unittest.TestCase):
# check that the messages have not been added to the internal message list # check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called() chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (WITH overwriting)
# -> expect a single message afterwards
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# 2. repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.ai.request.side_effect = self.mock_request
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)