Compare commits

...

2 Commits

Author SHA1 Message Date
a596918d7f question cmd: added test module 2023-09-09 08:51:44 +02:00
c0b7d17587 question_cmd: fixes 2023-09-09 08:51:44 +02:00
2 changed files with 79 additions and 2 deletions

View File

@ -1,5 +1,6 @@
import argparse
from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, Question
@ -11,8 +12,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates (and writes) a new message from the given arguments.
"""
# FIXME: add sources to the question
message = Message(question=Question(args.question),
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
# FIXME: don't surround all sourced files with ```
# -> do it only if '--source-code-only' is True and no source code
# could be extracted from that file
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question)
elif source is not None:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME
ai=args.ai,
model=args.model)

View File

@ -0,0 +1,58 @@
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB
class TestMessageCreate(unittest.TestCase):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source = None
self.args.ai = None
self.args.model = None
self.args.output_tags = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None:
self.args.question = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
def test_single_question(self) -> None:
self.args.question = ["What is this?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?"))
self.assertEqual(len(message.question.source_code()), 0)
def test_multipart_question(self) -> None:
self.args.question = ["What is this", "'bard' thing?", "Is it good?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("""What is this
'bard' thing?
Is it good?"""))