Compare commits

...

3 Commits

7 changed files with 216 additions and 34 deletions

View File

@ -15,7 +15,7 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
chat = ChatDB.from_dir(Path(config.cache),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,

View File

@ -1,3 +1,4 @@
import sys
import argparse
from pathlib import Path
from itertools import zip_longest
@ -51,7 +52,7 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates a new message from the given arguments and writes it
Create a new message from the given arguments and write it
to the cache directory.
"""
question_parts = []
@ -73,10 +74,37 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
chat.cache_add([message])
# only write the new message to the cache,
# don't add it to the internal list
chat.cache_write([message])
return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the given AI, chat history, message and arguments.
Write the response(s) to the cache directory, without appending it to the
given chat history. Then print the response(s).
"""
# print history and message question before making the request
ai.print()
chat.print(paged=False)
print(message.to_str())
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
# only write the response messages to the cache,
# don't add them to the internal list
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
@ -84,7 +112,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'),
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately
@ -95,28 +123,29 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK ===
if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.msg_update([response.messages[0]])
chat.cache_add(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
make_request(ai, chat, message, args)
# === REPEAT ===
elif args.repeat is not None:
lmessage = chat.msg_latest()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
lmessage = chat.msg_latest(loc='cache')
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
else:
print(f"Repeating message '{lmessage.msg_id()}':")
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
elif args.process is not None:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'

View File

@ -8,7 +8,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db))
if args.list:
tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)

View File

@ -116,6 +116,7 @@ class Config:
"""
# all members have default values, so we can easily create
# a default configuration
cache: str = '.'
db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@ -132,6 +133,7 @@ class Config:
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']),
ais=ais
)

View File

@ -540,6 +540,9 @@ class Message():
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str:
"""
Returns an ID that is unique throughout all messages in the same (DB) directory.

View File

@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase):
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'cache': '.',
'db': './test_db/',
'ais': {
'myopenai': {
@ -73,6 +74,7 @@ class TestConfig(unittest.TestCase):
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
@ -89,6 +91,7 @@ class TestConfig(unittest.TestCase):
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {
@ -108,6 +111,7 @@ class TestConfig(unittest.TestCase):
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
@ -115,6 +119,7 @@ class TestConfig(unittest.TestCase):
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
cache='./test_cache/',
db='./test_db/',
ais={
'myopenai': OpenAIConfig(
@ -133,12 +138,14 @@ class TestConfig(unittest.TestCase):
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
'cache': './test_cache/',
'db': './test_db/',
'ais': {
'default': {

View File

@ -3,10 +3,15 @@ import unittest
import argparse
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from unittest import mock
from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB
from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens
class TestMessageCreate(unittest.TestCase):
@ -16,10 +21,10 @@ class TestMessageCreate(unittest.TestCase):
"""
def setUp(self) -> None:
# create ChatDB structure
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_dir.name))
# create some messages
self.message_text = Message(Question("What is this?"),
Answer("It is pure text"))
@ -74,6 +79,7 @@ Aaaand again some text."""
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
@ -81,10 +87,10 @@ Aaaand again some text."""
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path)
cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
@ -193,3 +199,138 @@ This is embedded source code.
It is embedded code
```
"""))
class TestQuestionCmd(unittest.TestCase):
def setUp(self) -> None:
# create DB and cache
self.db_dir = tempfile.TemporaryDirectory()
self.cache_dir = tempfile.TemporaryDirectory()
# create configuration
self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace
self.args = argparse.Namespace(
ask=['What is the meaning of life?'],
num_answers=1,
output_tags=['science'],
AI='openai',
model='gpt-3.5-turbo',
or_tags=None,
and_tags=None,
exclude_tags=None,
source_text=None,
source_code=None,
create=None,
repeat=None,
process=None
)
def input_message(self, args: argparse.Namespace) -> Message:
"""
Create the expected input message for a question using the
given arguments.
"""
# NOTE: we only use the first question from the "ask" list
# -> message creation using "question.create_message()" is
# tested above
# the answer is always empty for the input message
return Message(Question(args.ask[0]),
tags=args.output_tags,
ai=args.AI,
model=args.model)
def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Mock the 'ai.request()' function
"""
question.answer = Answer("Answer 0")
question.tags = otags
question.ai = 'FakeAI'
question.model = 'FakeModel'
answers: list[Message] = [question]
for n in range(1, num_answers):
answers.append(Message(question=question.question,
answer=Answer(f"Answer {n}"),
tags=otags,
ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors
"""
# create a mock AI instance
ai = MagicMock(spec=AI)
ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assertSequenceEqual(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
Test single answer with no errors (mocked ChatDB version)
"""
chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat
# create a mock AI instance
ai = MagicMock(spec=AI)
ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
ai.request.assert_called_once_with(expected_question,
chat,
self.args.num_answers,
self.args.output_tags)
# check for the correct ChatDB calls:
# - initial question has been written (prior to the actual request)
# - responses have been written (after the request)
chat.cache_write.assert_has_calls([call([expected_question]),
call(expected_responses)],
any_order=False)
# check that the messages have not been added to the internal message list
chat.cache_add.assert_not_called()