Compare commits

..

7 Commits

3 changed files with 155 additions and 179 deletions

View File

@ -44,7 +44,7 @@ class OpenAI(AI):
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = set(otags) if otags is not None else None
question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]

View File

@ -106,7 +106,7 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(response.tokens)
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
@ -118,7 +118,8 @@ def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespac
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
or args.output_tags is None # noqa: W503
or len(args.output_tags) == 0): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
@ -135,23 +136,22 @@ def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namesp
"""
ai: AI
for msg in messages:
msg_args = create_msg_args(msg, args)
ai = create_ai(msg_args, config)
ai = create_ai(make_msg_args(msg, args), config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
if ((msg.answer is None or msg_args.overwrite is True)
if ((msg.answer is None or args.overwrite is True)
and (not chat.msg_in_db(msg))): # noqa: W503
msg.clear_answer()
make_request(ai, chat, msg, msg_args)
make_request(ai, chat, msg, args)
# otherwise create a new one
else:
msg_args.ask = [msg.question]
message = create_message(chat, msg_args)
make_request(ai, chat, message, msg_args)
args.ask = [msg.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
def invert_input_tag_args(args: argparse.Namespace) -> None:
def modify_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
@ -165,13 +165,17 @@ def invert_input_tag_args(args: argparse.Namespace) -> None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
if args.exclude_tags is None:
args.exclude_tags = set()
elif len(args.exclude_tags) == 0:
args.exclude_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
invert_input_tag_args(args)
modify_tag_args(args)
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags)

View File

@ -4,9 +4,9 @@ import argparse
import tempfile
from pathlib import Path
from unittest import mock
from unittest.mock import MagicMock, call
from typing import Optional, Union
from chatmastermind.configuration import Config, AIConfig
from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
@ -14,56 +14,6 @@ from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
@ -74,18 +24,6 @@ class TestQuestionCmdBase(unittest.TestCase):
# exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
class TestMessageCreate(TestQuestionCmdBase):
"""
@ -289,8 +227,8 @@ class TestQuestionCmd(TestQuestionCmdBase):
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='FakeAI',
model='FakeModel',
AI='openai',
model='gpt-3.5-turbo',
or_tags=None,
and_tags=None,
exclude_tags=None,
@ -301,17 +239,44 @@ class TestQuestionCmd(TestQuestionCmdBase):
process=None,
overwrite=None
)
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
message = Message(Question(args.ask[0]),
tags=set(args.output_tags) if args.output_tags is not None else None,
ai=args.AI,
model=args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
if with_answer:
message.answer = Answer('Answer 0')
message.to_file()
return message
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
@ -325,20 +290,21 @@ class TestQuestionCmdAsk(TestQuestionCmd):
"""
Test single answer with no errors.
"""
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
@ -355,20 +321,22 @@ class TestQuestionCmdAsk(TestQuestionCmd):
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
mock_create_ai.side_effect = self.mock_create_ai
expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
@ -385,16 +353,19 @@ class TestQuestionCmdAsk(TestQuestionCmd):
Provoke an error during the AI request and verify that the question
has been correctly stored in the cache.
"""
mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
# execute the command
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
@ -410,26 +381,28 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a message
message = self.create_single_message(self.args)
# repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
expected_responses = expected_response + expected_response
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
expected_responses += expected_responses
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses)
@ -438,29 +411,31 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a message
message = self.create_single_message(self.args)
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# repeat the last question (WITH overwriting)
# 2. repeat the last question (WITH overwriting)
# -> expect a single message afterwards
self.args.ask = None
self.args.repeat = []
self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_response)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -470,31 +445,35 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
question = self.create_single_message(self.args, with_answer=False)
# 1. ask a question and provoke an error
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
self.ai.request.side_effect = AIError
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# repeat the last question (without overwriting)
# 2. repeat the last question (without overwriting)
# -> expect a single message because if the original has
# no answer, it should be overwritten by default
self.args.ask = None
self.args.repeat = []
self.args.overwrite = False
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
self.ai.request.side_effect = self.mock_request
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_response)
self.assert_messages_equal(cached_msg, expected_responses)
# also check that the file ID has not been changed
assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -504,15 +483,21 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
"""
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
# create a message
message = self.create_single_message(self.args)
# 1. ask a question
mock_create_ai.return_value = self.ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# repeat the last question with new arguments (without overwriting)
# 2. repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata
self.args.ask = None
self.args.repeat = []
@ -520,26 +505,25 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.args.output_tags = ['newtag']
self.args.AI = 'newai'
self.args.model = 'newmodel'
new_expected_question = Message(question=Question(message.question),
new_expected_question = Message(question=Question(expected_question.question),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model)
fake_ai = self.mock_create_ai(self.args, self.config)
new_expected_response = fake_ai.request(new_expected_question,
expected_responses += self.mock_request(new_expected_question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, [message] + new_expected_response)
self.assert_messages_equal(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
# 1. === create three questions ===
# 1. create some questions / messages
# cached message without an answer
message1 = Message(Question('Question 1'),
ai='foo',
@ -547,7 +531,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question('Question 2'),
Answer('Answer 0'),
Answer('Answer 2'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.cache_dir.name) / '0002.txt')
@ -559,16 +543,10 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
message1.to_file()
message2.to_file()
message3.to_file()
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
expected_responses += fake_ai.request(question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# chat = ChatDB.from_dir(Path(self.cache_dir.name),
# Path(self.db_dir.name))
# 2. === repeat all three questions (without overwriting) ===
# 2. repeat all three questions (without overwriting)
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
@ -576,13 +554,7 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
# -> the cached message wihtout answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
print(f"Cached: {cached_msg}")
print(f"Expected: {expected_cache_messages}")
self.assert_messages_equal(cached_msg, expected_cache_messages)
# FIXME: also compare actual content!