Compare commits
2 Commits
68ac4bd60d
...
4afd586e5c
| Author | SHA1 | Date | |
|---|---|---|---|
| 4afd586e5c | |||
| 25303aba7e |
@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'question' command.
|
||||
"""
|
||||
mfilter = MessageFilter(tags_or=args.or_tags,
|
||||
tags_and=args.and_tags,
|
||||
tags_not=args.exclude_tags)
|
||||
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
|
||||
tags_and=args.and_tags if args.and_tags is not None else set(),
|
||||
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
|
||||
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||
db_path=Path(config.db),
|
||||
mfilter=mfilter)
|
||||
|
||||
@ -312,7 +312,7 @@ class Message():
|
||||
mfilter.tags_not if mfilter else None)
|
||||
else:
|
||||
message = cls.__from_file_yaml(file_path)
|
||||
if message and (not mfilter or (mfilter and message.match(mfilter))):
|
||||
if message and (mfilter is None or message.match(mfilter)):
|
||||
return message
|
||||
else:
|
||||
return None
|
||||
@ -508,7 +508,7 @@ class Message():
|
||||
Return True if all attributes match, else False.
|
||||
"""
|
||||
mytags = self.tags or set()
|
||||
if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
|
||||
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
|
||||
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
|
||||
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
|
||||
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
|
||||
|
||||
@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
|
||||
def test_chat_db_filter(self) -> None:
|
||||
def test_chat_db_from_dir_filter_tags(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertEqual(len(chat_db.messages), 1)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
|
||||
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or=set(),
|
||||
tags_and=set(),
|
||||
tags_not=set()))
|
||||
self.assertEqual(len(chat_db.messages), 0)
|
||||
|
||||
def test_chat_db_from_dir_filter_answer(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(answer_contains='Answer 2'))
|
||||
@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase):
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||
|
||||
def test_chat_db_from_messges(self) -> None:
|
||||
def test_chat_db_from_messages(self) -> None:
|
||||
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
messages=[self.message1, self.message2,
|
||||
|
||||
@ -300,6 +300,12 @@ This is a question.
|
||||
MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_txt_empty_tags_dont_match(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(tags_or=set(),
|
||||
tags_and=set()))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_txt_no_tags_match_tags_not(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(tags_not={Tag('tag1')}))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user