Compare commits

...

5 Commits

4 changed files with 416 additions and 71 deletions

View File

@ -1,3 +1,4 @@
import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
@ -51,7 +52,7 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
Creates a new message from the given arguments and writes it Create a new message from the given arguments and write it
to the cache directory. to the cache directory.
""" """
question_parts = [] question_parts = []
@ -73,10 +74,37 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME tags=args.output_tags, # FIXME
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
chat.cache_add([message]) # only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the given AI, chat history, message and arguments.
Write the response(s) to the cache directory, without appending it to the
given chat history. Then print the response(s).
"""
# print history and message question before making the request
ai.print()
chat.print(paged=False)
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
@ -95,28 +123,29 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance # create the correct AI instance
ai: AI = create_ai(args, config) ai: AI = create_ai(args, config)
# === ASK ===
if args.ask: if args.ask:
ai.print() make_request(ai, chat, message, args)
chat.print(paged=False) # === REPEAT ===
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.msg_update([response.messages[0]])
chat.cache_add(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None: elif args.repeat is not None:
lmessage = chat.msg_latest() lmessage = chat.msg_latest(loc='cache')
assert lmessage if lmessage is None:
# TODO: repeat either the last question or the print("No message found to repeat!")
# one(s) given in 'args.repeat' (overwrite sys.exit(1)
# existing ones if 'args.overwrite' is True) else:
pass print(f"Repeating message '{lmessage.msg_id()}':")
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
elif args.process is not None: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an
# answer or the one(s) given in 'args.process' # answer or the one(s) given in 'args.process'

View File

@ -222,12 +222,36 @@ class Message():
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __post_init__(self) -> None:
# convert some types that are often set wrong
if self.tags is not None and not isinstance(self.tags, set):
self.tags = set(self.tags)
if self.file_path is not None and not isinstance(self.file_path, pathlib.Path):
self.file_path = pathlib.Path(self.file_path)
def __hash__(self) -> int: def __hash__(self) -> int:
""" """
The hash value is computed based on immutable members. The hash value is computed based on immutable members.
""" """
return hash((self.question, self.answer)) return hash((self.question, self.answer))
def equals(self, other: MessageInst, tags: bool = True, ai: bool = True,
model: bool = True, file_path: bool = True, verbose: bool = False) -> bool:
"""
Compare this message with another one, including the metadata.
Return True if everything is identical, False otherwise.
"""
equal: bool = ((not tags or (self.tags == other.tags))
and (not ai or (self.ai == other.ai)) # noqa: W503
and (not model or (self.model == other.model)) # noqa: W503
and (not file_path or (self.file_path == other.file_path)) # noqa: W503
and (self == other)) # noqa: W503
if not equal and verbose:
print("Messages not equal:")
print(self)
print(other)
return equal
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """
@ -418,9 +442,6 @@ class Message():
output.append(self.answer) output.append(self.answer)
return '\n'.join(output) return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
""" """
Write a Message to the given file. Type is determined based on the suffix. Write a Message to the given file. Type is determined based on the suffix.
@ -540,6 +561,9 @@ class Message():
if self.tags: if self.tags:
self.tags = rename_tags(self.tags, tags_rename) self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str: def msg_id(self) -> str:
""" """
Returns an ID that is unique throughout all messages in the same (DB) directory. Returns an ID that is unique throughout all messages in the same (DB) directory.

View File

@ -10,7 +10,18 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError from chatmastermind.chat import Chat, ChatDB, ChatError
class TestChat(unittest.TestCase): class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestChat(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@ -26,24 +37,24 @@ class TestChat(unittest.TestCase):
def test_unique_id(self) -> None: def test_unique_id(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_id() self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None: def test_unique_content(self) -> None:
# test with two identical messages # test with two identical messages
self.chat.msg_add([self.message1, self.message1]) self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1]) self.assert_messages_equal(self.chat.messages, [self.message1])
# test with two different messages # test with two different messages
self.chat.msg_add([self.message2]) self.chat.msg_add([self.message2])
self.chat.msg_unique_content() self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2]) self.assert_messages_equal(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
@ -150,7 +161,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase): class TestChatDB(TestChatBase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()
@ -569,7 +580,7 @@ class TestChatDB(unittest.TestCase):
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)] search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3] expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all') result = chat_db.msg_find(search_names, loc='all')
self.assertSequenceEqual(result, expected_result) self.assert_messages_equal(result, expected_result)
def test_msg_latest(self) -> None: def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@ -595,47 +606,47 @@ class TestChatDB(unittest.TestCase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# add a new message, but only to the internal list # add a new message, but only to the internal list
new_message = Message(Question("What?")) new_message = Message(Question("What?"))
all_messages_mem = all_messages + [new_message] all_messages_mem = all_messages + [new_message]
chat_db.msg_add([new_message]) chat_db.msg_add([new_message])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages_mem)
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages_mem) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages_mem)
# the nr. of messages on disk did not change -> expect old result # the nr. of messages on disk did not change -> expect old result
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# test with MessageFilter # test with MessageFilter
self.assertSequenceEqual(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})), self.assert_messages_equal(chat_db.msg_gather(loc='all', mfilter=MessageFilter(tags_or={Tag('tag1')})),
[self.message1]) [self.message1])
self.assertSequenceEqual(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})), self.assert_messages_equal(chat_db.msg_gather(loc='disk', mfilter=MessageFilter(tags_or={Tag('tag2')})),
[self.message2]) [self.message2])
self.assertSequenceEqual(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})), self.assert_messages_equal(chat_db.msg_gather(loc='cache', mfilter=MessageFilter(tags_or={Tag('tag3')})),
[]) [])
self.assertSequenceEqual(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")), self.assert_messages_equal(chat_db.msg_gather(loc='mem', mfilter=MessageFilter(question_contains="What")),
[new_message]) [new_message])
def test_msg_move_and_gather(self) -> None: def test_msg_move_and_gather(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
all_messages = [self.message1, self.message2, self.message3, self.message4] all_messages = [self.message1, self.message2, self.message3, self.message4]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
# move first message to the cache # move first message to the cache
chat_db.cache_move(self.message1) chat_db.cache_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), [self.message1]) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [self.message1])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4]) self.assert_messages_equal(chat_db.msg_gather(loc='db'), [self.message2, self.message3, self.message4])
self.assertSequenceEqual(chat_db.msg_gather(loc='all'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='all'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='disk'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='disk'), all_messages)
self.assertSequenceEqual(chat_db.msg_gather(loc='mem'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='mem'), all_messages)
# now move first message back to the DB # now move first message back to the DB
chat_db.db_move(self.message1) chat_db.db_move(self.message1)
self.assertSequenceEqual(chat_db.msg_gather(loc='cache'), []) self.assert_messages_equal(chat_db.msg_gather(loc='cache'), [])
self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(self.message1.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
self.assertSequenceEqual(chat_db.msg_gather(loc='db'), all_messages) self.assert_messages_equal(chat_db.msg_gather(loc='db'), all_messages)

View File

@ -3,23 +3,39 @@ import unittest
import argparse import argparse
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock from unittest import mock
from chatmastermind.commands.question import create_message from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class TestMessageCreate(unittest.TestCase): class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
Compare messages using more than just Question and Answer.
"""
self.assertEqual(len(msg1), len(msg2))
for m1, m2 in zip(msg1, msg2):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
class TestMessageCreate(TestQuestionCmdBase):
""" """
Test if messages created by the 'question' command have Test if messages created by the 'question' command have
the correct format. the correct format.
""" """
def setUp(self) -> None: def setUp(self) -> None:
# create ChatDB structure # create ChatDB structure
self.db_path = tempfile.TemporaryDirectory() self.db_dir = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_path.name)) db_path=Path(self.db_dir.name))
# create some messages # create some messages
self.message_text = Message(Question("What is this?"), self.message_text = Message(Question("What is this?"),
Answer("It is pure text")) Answer("It is pure text"))
@ -74,6 +90,7 @@ Aaaand again some text."""
os.remove(self.source_file1.name) os.remove(self.source_file1.name)
os.remove(self.source_file2.name) os.remove(self.source_file2.name)
os.remove(self.source_file3.name) os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
@ -81,10 +98,10 @@ Aaaand again some text."""
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args) create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0]) message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
@ -193,3 +210,267 @@ This is embedded source code.
It is embedded code It is embedded code
``` ```
""")) """))
class TestQuestionCmd(TestQuestionCmdBase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='openai',
model='gpt-3.5-turbo',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None,
overwrite=None
)
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors.
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version).
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_with_error(self, mock_create_ai: MagicMock) -> None:
"""
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_overwrite(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question and overwrite the old one.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (WITH overwriting)
# -> expect a single message afterwards
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question_after_error(self, mock_create_ai: MagicMock) -> None:
"""
Repeat a single question after an error.
"""
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# 2. repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
self.ai.request.side_effect = self.mock_request
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)