Compare commits

..

3 Commits

6 changed files with 87 additions and 31 deletions

View File

@ -15,7 +15,7 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
tags_not=args.exclude_tags, tags_not=args.exclude_tags,
question_contains=args.question, question_contains=args.question,
answer_contains=args.answer) answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'), chat = ChatDB.from_dir(Path(config.cache),
Path(config.db), Path(config.db),
mfilter=mfilter) mfilter=mfilter)
chat.print(args.source_code_only, chat.print(args.source_code_only,

View File

@ -112,7 +112,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(), tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set()) tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db), db_path=Path(config.db),
mfilter=mfilter) mfilter=mfilter)
# if it's a new question, create and store it immediately # if it's a new question, create and store it immediately

View File

@ -8,7 +8,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tags' command. Handler for the 'tags' command.
""" """
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path(config.cache),
db_path=Path(config.db)) db_path=Path(config.db))
if args.list: if args.list:
tags_freq = chat.msg_tags_frequency(args.prefix, args.contain) tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)

View File

@ -116,6 +116,7 @@ class Config:
""" """
# all members have default values, so we can easily create # all members have default values, so we can easily create
# a default configuration # a default configuration
cache: str = '.'
db: str = './db/' db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@ -132,6 +133,7 @@ class Config:
ai_conf = ai_config_instance(conf['name'], conf) ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf ais[ID] = ai_conf
return cls( return cls(
cache=str(source['cache']) if 'cache' in source else '.',
db=str(source['db']), db=str(source['db']),
ais=ais ais=ais
) )

View File

@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase):
def test_from_dict_should_create_config_from_dict(self) -> None: def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = { source_dict = {
'cache': '.',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'myopenai': { 'myopenai': {
@ -73,6 +74,7 @@ class TestConfig(unittest.TestCase):
} }
} }
config = Config.from_dict(source_dict) config = Config.from_dict(source_dict)
self.assertEqual(config.cache, '.')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai') self.assertEqual(config.ais['myopenai'].name, 'openai')
@ -89,6 +91,7 @@ class TestConfig(unittest.TestCase):
def test_from_file_should_load_config_from_file(self) -> None: def test_from_file_should_load_config_from_file(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {
@ -108,6 +111,7 @@ class TestConfig(unittest.TestCase):
yaml.dump(source_dict, f) yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name) config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config) self.assertIsInstance(config, Config)
self.assertEqual(config.cache, './test_cache/')
self.assertEqual(config.db, './test_db/') self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1) self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig) self.assertIsInstance(config.ais['default'], AIConfig)
@ -115,6 +119,7 @@ class TestConfig(unittest.TestCase):
def test_to_file_should_save_config_to_file(self) -> None: def test_to_file_should_save_config_to_file(self) -> None:
config = Config( config = Config(
cache='./test_cache/',
db='./test_db/', db='./test_db/',
ais={ ais={
'myopenai': OpenAIConfig( 'myopenai': OpenAIConfig(
@ -133,12 +138,14 @@ class TestConfig(unittest.TestCase):
config.to_file(Path(self.test_file.name)) config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f: with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader) saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['cache'], './test_cache/')
self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None: def test_from_file_error_unknown_ai(self) -> None:
source_dict = { source_dict = {
'cache': './test_cache/',
'db': './test_db/', 'db': './test_db/',
'ais': { 'ais': {
'default': { 'default': {

View File

@ -4,11 +4,13 @@ import argparse
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, call from unittest.mock import MagicMock, call, ANY
from typing import Optional
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.commands.question import create_message, question_cmd from chatmastermind.commands.question import create_message, question_cmd
from chatmastermind.tags import Tag
from chatmastermind.message import Message, Question, Answer from chatmastermind.message import Message, Question, Answer
from chatmastermind.chat import ChatDB from chatmastermind.chat import Chat, ChatDB
from chatmastermind.ai import AI, AIResponse, Tokens from chatmastermind.ai import AI, AIResponse, Tokens
@ -19,10 +21,10 @@ class TestMessageCreate(unittest.TestCase):
""" """
def setUp(self) -> None: def setUp(self) -> None:
# create ChatDB structure # create ChatDB structure
self.db_path = tempfile.TemporaryDirectory() self.db_dir = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_dir = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), self.chat = ChatDB.from_dir(cache_path=Path(self.cache_dir.name),
db_path=Path(self.db_path.name)) db_path=Path(self.db_dir.name))
# create some messages # create some messages
self.message_text = Message(Question("What is this?"), self.message_text = Message(Question("What is this?"),
Answer("It is pure text")) Answer("It is pure text"))
@ -85,10 +87,10 @@ Aaaand again some text."""
def test_message_file_created(self) -> None: def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"] self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args) create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_dir)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0]) message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
@ -203,10 +205,12 @@ class TestQuestionCmd(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# create DB and cache # create DB and cache
self.db_path = tempfile.TemporaryDirectory() self.db_dir = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_dir = tempfile.TemporaryDirectory()
# create configuration # create configuration
self.config = Config() self.config = Config()
self.config.cache = self.cache_dir.name
self.config.db = self.db_dir.name
# create a mock argparse.Namespace # create a mock argparse.Namespace
self.args = argparse.Namespace( self.args = argparse.Namespace(
ask=['What is the meaning of life?'], ask=['What is the meaning of life?'],
@ -238,36 +242,79 @@ class TestQuestionCmd(unittest.TestCase):
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
def response(self, args: argparse.Namespace) -> AIResponse: def mock_request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
""" """
Create the expected AI response from the give arguments. Mock the 'ai.request()' function
""" """
input_msg = self.input_message(args) question.answer = Answer("Answer 0")
response = AIResponse(messages=[], tokens=Tokens(10, 10, 20)) question.tags = otags
for n in range(args.num_answers): question.ai = 'FakeAI'
response_msg = Message(input_msg.question, question.model = 'FakeModel'
Answer(f"Answer {n}"), answers: list[Message] = [question]
tags=input_msg.tags, for n in range(1, num_answers):
ai=input_msg.ai, answers.append(Message(question=question.question,
model=input_msg.model) answer=Answer(f"Answer {n}"),
response.messages.append(response_msg) tags=otags,
return response ai='FakeAI',
model='FakeModel'))
return AIResponse(answers, Tokens(10, 10, 20))
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
@mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock) -> None:
"""
Test single answer with no errors
"""
# create a mock AI instance
ai = MagicMock(spec=AI)
ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai
expected_question = self.input_message(self.args)
expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command
question_cmd(self.args, self.config)
# check for correct request call
ai.request.assert_called_once_with(expected_question,
ANY,
self.args.num_answers,
self.args.output_tags)
# check for the expected message files
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assertEqual(len(self.message_list(self.cache_dir)), 1)
self.assertSequenceEqual(cached_msg, expected_responses)
@mock.patch('chatmastermind.commands.question.ChatDB.from_dir') @mock.patch('chatmastermind.commands.question.ChatDB.from_dir')
@mock.patch('chatmastermind.commands.question.create_ai') @mock.patch('chatmastermind.commands.question.create_ai')
def test_ask_single_answer(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None: def test_ask_single_answer_mocked(self, mock_create_ai: MagicMock, mock_from_dir: MagicMock) -> None:
"""
# FIXME: this mock is only neccessary because the cache dir is not Test single answer with no errors (mocked ChatDB version)
# configurable in the configuration file """
chat = MagicMock(spec=ChatDB) chat = MagicMock(spec=ChatDB)
mock_from_dir.return_value = chat mock_from_dir.return_value = chat
# create a mock AI instance # create a mock AI instance
ai = MagicMock(spec=AI) ai = MagicMock(spec=AI)
ai.request.return_value = self.response(self.args) ai.request.side_effect = self.mock_request
mock_create_ai.return_value = ai mock_create_ai.return_value = ai
expected_question = self.input_message(self.args) expected_question = self.input_message(self.args)
expected_responses = ai.request.return_value.messages expected_responses = self.mock_request(expected_question,
Chat([]),
self.args.num_answers,
self.args.output_tags).messages
# execute the command # execute the command
question_cmd(self.args, self.config) question_cmd(self.args, self.config)