test_question_cmd: modified tests to use '.msg' file suffix

This commit is contained in:
juk0de 2023-09-26 18:24:36 +02:00
parent 3bc5f7cd63
commit 3c1c9860a0

View File

@ -14,6 +14,9 @@ from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI from .test_common import TestWithFakeAI
msg_suffix = Message.file_suffix_write
class TestMessageCreate(TestWithFakeAI): class TestMessageCreate(TestWithFakeAI):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
@ -83,7 +86,7 @@ Aaaand again some text."""
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*')) return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
@ -231,7 +234,7 @@ class TestQuestionCmd(TestWithFakeAI):
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
class TestQuestionCmdAsk(TestQuestionCmd): class TestQuestionCmdAsk(TestQuestionCmd):
@ -330,14 +333,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question. Repeat a single question.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message # create a message
message = Message(Question(self.args.ask[0]), message = Message(Question(self.args.ask[0]),
Answer('Old Answer'), Answer('Old Answer'),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
message.to_file() chat.msg_write([message])
# repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path) # -> expect two identical messages (except for the file_path)
@ -353,8 +358,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# we expect the original message + the one with the new response # we expect the original message + the one with the new response
expected_responses = [message] + [expected_response] expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir)) print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
@ -366,16 +369,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question and overwrite the old one. Repeat a single question and overwrite the old one.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message # create a message
message = Message(Question(self.args.ask[0]), message = Message(Question(self.args.ask[0]),
Answer('Old Answer'), Answer('Old Answer'),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
message.to_file() chat.msg_write([message])
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@ -405,16 +408,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question after an error. Repeat a single question after an error.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer # create a question WITHOUT an answer
# -> just like after an error, which is tested above # -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]), message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
message.to_file() chat.msg_write([message])
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
@ -445,16 +448,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question with new arguments. Repeat a single question with new arguments.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message # create a message
message = Message(Question(self.args.ask[0]), message = Message(Question(self.args.ask[0]),
Answer('Old Answer'), Answer('Old Answer'),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
message.to_file() chat.msg_write([message])
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
@ -483,16 +486,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question with new arguments, overwriting the old one. Repeat a single question with new arguments, overwriting the old one.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message # create a message
message = Message(Question(self.args.ask[0]), message = Message(Question(self.args.ask[0]),
Answer('Old Answer'), Answer('Old Answer'),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
message.to_file() chat.msg_write([message])
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
@ -520,29 +523,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat multiple questions. Repeat multiple questions.
""" """
mock_create_ai.side_effect = self.mock_create_ai mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions === # 1. === create three questions ===
# cached message without an answer # cached message without an answer
message1 = Message(Question(self.args.ask[0]), message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags, tags=self.args.output_tags,
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0001.txt') file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
# cached message with an answer # cached message with an answer
message2 = Message(Question(self.args.ask[0]), message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'), Answer('Old Answer'),
tags=self.args.output_tags, tags=self.args.output_tags,
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.cache_dir.name) / '0002.txt') file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
# DB message without an answer # DB message without an answer
message3 = Message(Question(self.args.ask[0]), message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags, tags=self.args.output_tags,
ai=self.args.AI, ai=self.args.AI,
model=self.args.model, model=self.args.model,
file_path=Path(self.db_dir.name) / '0003.txt') file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
message1.to_file() chat.msg_write([message1, message2, message3])
message2.to_file()
message3.to_file()
questions = [message1, message2, message3] questions = [message1, message2, message3]
expected_responses: list[Message] = [] expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config) fake_ai = self.mock_create_ai(self.args, self.config)
@ -566,8 +569,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.assertEqual(len(self.message_list(self.cache_dir)), 4) self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1) self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]] expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages) self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all # check that the DB message has not been modified at all