Compare commits

...

9 Commits

5 changed files with 298 additions and 160 deletions

View File

@ -3,18 +3,20 @@ Creates different AI instances, based on the given configuration.
""" """
import argparse import argparse
from typing import cast from typing import cast, Optional
from .configuration import Config, AIConfig, OpenAIConfig from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError from .ai import AI, AIError
from .ais.openai import OpenAI from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 def create_ai(args: argparse.Namespace, config: Config, # noqa: 11
def_ai: Optional[str] = None,
def_model: Optional[str] = None) -> AI:
""" """
Creates an AI subclass instance from the given arguments Creates an AI subclass instance from the given arguments and configuration file.
and configuration file. If AI has not been set in the If AI has not been set in the arguments, it searches for the ID 'default'. If
arguments, it searches for the ID 'default'. If that that is not found, it uses the first AI in the list. It's also possible to
is not found, it uses the first AI in the list. specify a default AI and model using 'def_ai' and 'def_model'.
""" """
ai_conf: AIConfig ai_conf: AIConfig
if hasattr(args, 'AI') and args.AI: if hasattr(args, 'AI') and args.AI:
@ -22,6 +24,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai_conf = config.ais[args.AI] ai_conf = config.ais[args.AI]
except KeyError: except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif def_ai:
ai_conf = config.ais[def_ai]
elif 'default' in config.ais: elif 'default' in config.ais:
ai_conf = config.ais['default'] ai_conf = config.ais['default']
else: else:
@ -34,6 +38,8 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
if hasattr(args, 'model') and args.model: if hasattr(args, 'model') and args.model:
ai.config.model = args.model ai.config.model = args.model
elif def_model:
ai.config.model = def_model
if hasattr(args, 'max_tokens') and args.max_tokens: if hasattr(args, 'max_tokens') and args.max_tokens:
ai.config.max_tokens = args.max_tokens ai.config.max_tokens = args.max_tokens
if hasattr(args, 'temperature') and args.temperature: if hasattr(args, 'temperature') and args.temperature:

View File

@ -44,7 +44,7 @@ class OpenAI(AI):
frequency_penalty=self.config.frequency_penalty, frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty) presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content']) question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags question.tags = set(otags) if otags is not None else None
question.ai = self.ID question.ai = self.ID
question.model = self.config.model question.model = self.config.model
answers: list[Message] = [question] answers: list[Message] = [question]

View File

@ -2,6 +2,7 @@ import sys
import argparse import argparse
from pathlib import Path from pathlib import Path
from itertools import zip_longest from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config from ..configuration import Config
from ..chat import ChatDB from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code from ..message import Message, MessageFilter, MessageError, Question, source_code
@ -105,13 +106,75 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
print(response.tokens) print(response.tokens)
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
"""
Takes an existing message and CLI arguments, and returns modified args based
on the members of the given message. Used e.g. when repeating messages, where
it's necessary to determine the correct AI, module and output tags to use
(either from the existing message or the given args).
"""
msg_args = args
# if AI, model or output tags have not been specified,
# use those from the original message
if (args.AI is None
or args.model is None # noqa: W503
or args.output_tags is None): # noqa: W503
msg_args = deepcopy(args)
if args.AI is None and msg.ai is not None:
msg_args.AI = msg.ai
if args.model is None and msg.model is not None:
msg_args.model = msg.model
if args.output_tags is None and msg.tags is not None:
msg_args.output_tags = msg.tags
return msg_args
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
"""
Repeat the given messages using the given arguments.
"""
ai: AI
for msg in messages:
msg_args = create_msg_args(msg, args)
ai = create_ai(msg_args, config)
print(f"--------- Repeating message '{msg.msg_id()}': ---------")
# overwrite the latest message if requested or empty
# -> but not if it's in the DB!
if ((msg.answer is None or msg_args.overwrite is True)
and (not chat.msg_in_db(msg))): # noqa: W503
msg.clear_answer()
make_request(ai, chat, msg, msg_args)
# otherwise create a new one
else:
msg_args.ask = [msg.question]
message = create_message(chat, msg_args)
make_request(ai, chat, message, msg_args)
def invert_input_tag_args(args: argparse.Namespace) -> None:
"""
Changes the semantics of the INPUT tags for this command:
* not tags specified on the CLI -> no tags are selected
* empty tags specified on the CLI -> all tags are selected
"""
if args.or_tags is None:
args.or_tags = set()
elif len(args.or_tags) == 0:
args.or_tags = None
if args.and_tags is None:
args.and_tags = set()
elif len(args.and_tags) == 0:
args.and_tags = None
def question_cmd(args: argparse.Namespace, config: Config) -> None: def question_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'question' command. Handler for the 'question' command.
""" """
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), invert_input_tag_args(args)
tags_and=args.and_tags if args.and_tags is not None else set(), mfilter = MessageFilter(tags_or=args.or_tags,
tags_not=args.exclude_tags if args.exclude_tags is not None else set()) tags_and=args.and_tags,
tags_not=args.exclude_tags)
chat = ChatDB.from_dir(cache_path=Path(config.cache), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db), db_path=Path(config.db),
mfilter=mfilter) mfilter=mfilter)
@ -121,30 +184,24 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
if args.create: if args.create:
return return
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK === # === ASK ===
if args.ask: if args.ask:
ai: AI = create_ai(args, config)
make_request(ai, chat, message, args) make_request(ai, chat, message, args)
# === REPEAT === # === REPEAT ===
elif args.repeat is not None: elif args.repeat is not None:
repeat_msgs: list[Message] = []
# repeat latest message
if len(args.repeat) == 0:
lmessage = chat.msg_latest(loc='cache') lmessage = chat.msg_latest(loc='cache')
if lmessage is None: if lmessage is None:
print("No message found to repeat!") print("No message found to repeat!")
sys.exit(1) sys.exit(1)
repeat_msgs.append(lmessage)
# repeat given message(s)
else: else:
print(f"Repeating message '{lmessage.msg_id()}':") repeat_msgs = chat.msg_find(args.repeat, loc='disk')
# overwrite the latest message if requested or empty repeat_messages(repeat_msgs, chat, args, config)
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS === # === PROCESS ===
elif args.process is not None: elif args.process is not None:
# TODO: process either all questions without an # TODO: process either all questions without an

View File

@ -34,13 +34,13 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection # a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False) tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='*',
help='List of tags (one must match)', metavar='OTAGS') help='List of tags (one must match)', metavar='OTAGS')
tag_arg.completer = tags_completer # type: ignore tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='*',
help='List of tags (all must match)', metavar='ATAGS') help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='*',
help='List of tags to exclude', metavar='XTAGS') help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',

View File

@ -4,9 +4,9 @@ import argparse
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, call, ANY from unittest.mock import MagicMock, call
from typing import Optional from typing import Optional, Union
from chatmastermind.configuration import Config from chatmastermind.configuration import Config, AIConfig
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
@ -14,6 +14,56 @@ from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens, AIError from chatmastermind.ai import AI, AIResponse, Tokens, AIError
class FakeAI(AI):
"""
A mocked version of the 'AI' class.
"""
ID: str
name: str
config: AIConfig
def models(self) -> list[str]:
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
return 123
def print(self) -> None:
pass
def print_models(self) -> None:
pass
def __init__(self, ID: str, model: str, error: bool = False):
self.ID = ID
self.model = model
self.error = error
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function by either returning fake
answers or raising an exception.
"""
if self.error:
raise AIError
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags is not None else None
question.ai = self.ID
question.model = self.model
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai=self.ID,
model=self.model))
return AIResponse(answers, Tokens(10, 10, 20))
class TestQuestionCmdBase(unittest.TestCase): class TestQuestionCmdBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None: def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
""" """
@ -24,6 +74,18 @@ class TestQuestionCmdBase(unittest.TestCase):
# exclude the file_path, compare only Q, A and metadata # exclude the file_path, compare only Q, A and metadata
self.assertTrue(m1.equals(m2, file_path=False, verbose=True)) self.assertTrue(m1.equals(m2, file_path=False, verbose=True))
def mock_create_ai(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model)
def mock_create_ai_with_error(self, args: argparse.Namespace, config: Config) -> AI:
"""
Mocked 'create_ai' that returns a 'FakeAI' instance.
"""
return FakeAI(args.AI, args.model, error=True)
class TestMessageCreate(TestQuestionCmdBase): class TestMessageCreate(TestQuestionCmdBase):
""" """
@ -227,8 +289,8 @@ class TestQuestionCmd(TestQuestionCmdBase):
ask=['What is the meaning of life?'], ask=['What is the meaning of life?'],
num_answers=1, num_answers=1,
output_tags=['science'], output_tags=['science'],
AI='openai', AI='FakeAI',
model='gpt-3.5-turbo', model='FakeModel',
or_tags=None, or_tags=None,
and_tags=None, and_tags=None,
exclude_tags=None, exclude_tags=None,
@ -239,57 +301,37 @@ class TestQuestionCmd(TestQuestionCmdBase):
process=None, process=None,
overwrite=None overwrite=None
) )
# create a mock AI instance
self.ai = MagicMock(spec=AI)
self.ai.request.side_effect = self.mock_request
def input_message(self, args: argparse.Namespace) -> Message: def create_single_message(self, args: argparse.Namespace, with_answer: bool = True) -> Message:
""" message = Message(Question(args.ask[0]),
Create the expected input message for a question using the tags=set(args.output_tags) if args.output_tags is not None else None,
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI, ai=args.AI,
model=args.model) model=args.model,
file_path=Path(self.cache_dir.name) / '0001.txt')
def mock_request(self, if with_answer:
question: Message, message.answer = Answer('Answer 0')
chat: Chat, message.to_file()
num_answers: int = 1, return message
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = set(otags) if otags else None
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next' # exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')]) return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
class TestQuestionCmdAsk(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None: def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
""" """
Test single answer with no errors. Test single answer with no errors.
""" """
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
expected_responses = self.mock_request(expected_question, tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
self.args.output_tags).messages self.args.output_tags).messages
@ -297,11 +339,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
# execute the command # execute the command
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@ -318,9 +355,13 @@ class TestQuestionCmd(TestQuestionCmdBase):
chat = MagicMock(spec=ChatDB) chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
expected_responses = self.mock_request(expected_question, tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = fake_ai.request(expected_question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
self.args.output_tags).messages self.args.output_tags).messages
@ -328,12 +369,6 @@ class TestQuestionCmd(TestQuestionCmdBase):
# execute the command # execute the command
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls: # check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request) # - initial question has been written (prior to the actual request)
# - responses have been written (after the request) # - responses have been written (after the request)
@ -350,19 +385,16 @@ class TestQuestionCmd(TestQuestionCmdBase):
Provoke an error during the AI request and verify that the question Provoke an error during the AI request and verify that the question
has been correctly stored in the cache. has been correctly stored in the cache.
""" """
mock_create_ai.return_value = self.ai mock_create_ai.side_effect = self.mock_create_ai_with_error
expected_question = self.input_message(self.args) expected_question = Message(Question(self.args.ask[0]),
self.ai.request.side_effect = AIError tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model)
# execute the command # execute the command
with self.assertRaises(AIError): with self.assertRaises(AIError):
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
# check for correct request call
self.ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files # check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
@ -370,33 +402,34 @@ class TestQuestionCmd(TestQuestionCmdBase):
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question]) self.assert_messages_equal(cached_msg, [expected_question])
class TestQuestionCmdRepeat(TestQuestionCmd):
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None: def test_repeat_single_question(self, mock_create_ai: MagicMock) -> None:
""" """
Repeat a single question. Repeat a single question.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path) # -> expect two identical messages (except for the file_path)
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
expected_responses += expected_responses fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
expected_responses = expected_response + expected_response
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_responses)
@ -405,31 +438,29 @@ class TestQuestionCmd(TestQuestionCmdBase):
""" """
Repeat a single question and overwrite the old one. Repeat a single question and overwrite the old one.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question (WITH overwriting) # repeat the last question (WITH overwriting)
# -> expect a single message afterwards # -> expect a single message afterwards
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = True self.args.overwrite = True
fake_ai = self.mock_create_ai(self.args, self.config)
expected_response = fake_ai.request(message,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_response)
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -439,35 +470,31 @@ class TestQuestionCmd(TestQuestionCmdBase):
""" """
Repeat a single question after an error. Repeat a single question after an error.
""" """
# 1. ask a question and provoke an error mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a question WITHOUT an answer
expected_question = self.input_message(self.args) # -> just like after an error, which is tested above
self.ai.request.side_effect = AIError question = self.create_single_message(self.args, with_answer=False)
with self.assertRaises(AIError):
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem cached_msg_file_id = cached_msg[0].file_path.stem
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, [expected_question])
# 2. repeat the last question (without overwriting) # repeat the last question (without overwriting)
# -> expect a single message because if the original has # -> expect a single message because if the original has
# no answer, it should be overwritten by default # no answer, it should be overwritten by default
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
self.args.overwrite = False self.args.overwrite = False
self.ai.request.side_effect = self.mock_request fake_ai = self.mock_create_ai(self.args, self.config)
expected_responses = self.mock_request(expected_question, expected_response = fake_ai.request(question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
self.args.output_tags).messages self.args.output_tags).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, expected_response)
# also check that the file ID has not been changed # also check that the file ID has not been changed
assert cached_msg[0].file_path assert cached_msg[0].file_path
self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem) self.assertEqual(cached_msg_file_id, cached_msg[0].file_path.stem)
@ -477,21 +504,15 @@ class TestQuestionCmd(TestQuestionCmdBase):
""" """
Repeat a single question with new arguments. Repeat a single question with new arguments.
""" """
# 1. ask a question mock_create_ai.side_effect = self.mock_create_ai
mock_create_ai.return_value = self.ai # create a message
expected_question = self.input_message(self.args) message = self.create_single_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name), chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name)) Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1) assert cached_msg[0].file_path
self.assert_messages_equal(cached_msg, expected_responses)
# 2. repeat the last question with new arguments (without overwriting) # repeat the last question with new arguments (without overwriting)
# -> expect two messages with identical question and answer, but different metadata # -> expect two messages with identical question and answer, but different metadata
self.args.ask = None self.args.ask = None
self.args.repeat = [] self.args.repeat = []
@ -499,15 +520,69 @@ class TestQuestionCmd(TestQuestionCmdBase):
self.args.output_tags = ['newtag'] self.args.output_tags = ['newtag']
self.args.AI = 'newai' self.args.AI = 'newai'
self.args.model = 'newmodel' self.args.model = 'newmodel'
new_expected_question = Message(question=Question(expected_question.question), new_expected_question = Message(question=Question(message.question),
tags=set(self.args.output_tags), tags=set(self.args.output_tags),
ai=self.args.AI, ai=self.args.AI,
model=self.args.model) model=self.args.model)
expected_responses += self.mock_request(new_expected_question, fake_ai = self.mock_create_ai(self.args, self.config)
new_expected_response = fake_ai.request(new_expected_question,
Chat([]), Chat([]),
self.args.num_answers, self.args.num_answers,
set(self.args.output_tags)).messages set(self.args.output_tags)).messages
question_cmd(self.args, self.config) question_cmd(self.args, self.config)
cached_msg = chat.msg_gather(loc='cache') cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 2) self.assertEqual(len(self.message_list(self.cache_dir)), 2)
self.assert_messages_equal(cached_msg, expected_responses) self.assert_messages_equal(cached_msg, [message] + new_expected_response)
@mock.patch('chatmastermind.commands.question.create_ai')
def test_repeat_multiple_questions(self, mock_create_ai: MagicMock) -> None:
"""
Repeat multiple questions.
"""
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question('Question 1'),
ai='foo',
model='bla',
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question('Question 2'),
Answer('Answer 0'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question('Question 3'),
ai='openai',
model='gpt-3.5-turbo',
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
for question in questions:
expected_responses += fake_ai.request(question,
Chat([]),
self.args.num_answers,
set(self.args.output_tags)).messages
# 2. === repeat all three questions (without overwriting) ===
self.args.ask = None
self.args.repeat = ['0001', '0002', '0003']
self.args.overwrite = False
question_cmd(self.args, self.config)
# two new files should be in the cache directory
# * the repeated cached message with answer
# * the repeated DB message
# -> the cached message without answer should be overwritten
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
print(f"Cached: {cached_msg}")
print(f"Expected: {expected_cache_messages}")
self.assert_messages_equal(cached_msg, expected_cache_messages)