Compare commits

..

7 Commits

2 changed files with 219 additions and 29 deletions

View File

@ -3,7 +3,7 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from .tags import Tag, TagLine, TagError, match_tags from .tags import Tag, TagLine, TagError, match_tags
@ -57,6 +57,23 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
return code_sections return code_sections
@dataclass(kw_only=True)
class MessageFilter:
"""
Various filters for a Message.
"""
tags_or: Optional[set[Tag]] = None
tags_and: Optional[set[Tag]] = None
tags_not: Optional[set[Tag]] = None
ai: Optional[str] = None
model: Optional[str] = None
question_contains: Optional[str] = None
answer_contains: Optional[str] = None
answer_state: Optional[Literal['available', 'missing']] = None
ai_state: Optional[Literal['available', 'missing']] = None
model_state: Optional[Literal['available', 'missing']] = None
class AILine(str): class AILine(str):
""" """
A line that represents the AI name in a '.txt' file.. A line that represents the AI name in a '.txt' file..
@ -214,12 +231,11 @@ class Message():
@classmethod @classmethod
def from_file(cls: Type[MessageInst], file_path: pathlib.Path, def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
tags_or: Optional[set[Tag]] = None, mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
tags_and: Optional[set[Tag]] = None,
tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
""" """
Create a Message from the given file. Returns 'None' if the message does Create a Message from the given file. Returns 'None' if the message does
not fulfill the tag requirements. not fulfill the filter requirements. For TXT files, the tags are matched
before building the whole message. The other filters are applied afterwards.
""" """
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
@ -227,9 +243,16 @@ class Message():
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
if file_path.suffix == '.txt': if file_path.suffix == '.txt':
return cls.__from_file_txt(file_path, tags_or, tags_and, tags_not) message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None)
else: else:
return cls.__from_file_yaml(file_path, tags_or, tags_and, tags_not) message = cls.__from_file_yaml(file_path)
if message and (not mfilter or (mfilter and message.match(mfilter))):
return message
else:
return None
@classmethod @classmethod
def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11
@ -289,10 +312,7 @@ class Message():
return cls(question, answer, tags, ai, model, file_path) return cls(question, answer, tags, ai, model, file_path)
@classmethod @classmethod
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path, def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
tags_or: Optional[set[Tag]] = None,
tags_and: Optional[set[Tag]] = None,
tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
""" """
Create a Message from the given YAML file. Expects the following file structures: Create a Message from the given YAML file. Expects the following file structures:
* Question.yaml_key: single or multiline string * Question.yaml_key: single or multiline string
@ -300,18 +320,9 @@ class Message():
* Message.tags_yaml_key: list of strings [Optional] * Message.tags_yaml_key: list of strings [Optional]
* Message.ai_yaml_key: str [Optional] * Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional]
Returns 'None' if the message does not fulfill the tag requirements.
""" """
tags: set[Tag] = set()
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
if tags_or or tags_and or tags_not:
if Message.tags_yaml_key in data:
tags = set([Tag(tag) for tag in data[Message.tags_yaml_key]])
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
data[cls.file_yaml_key] = file_path data[cls.file_yaml_key] = file_path
return cls.from_dict(data) return cls.from_dict(data)
@ -377,5 +388,26 @@ class Message():
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd, sort_keys=False) yaml.dump(data, fd, sort_keys=False)
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
"""
Matches the current Message to the given filter atttributes.
Return True if all attributes match, else False.
"""
mytags = self.tags or set()
if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503
or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503
or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503
or (mfilter.model_state == 'missing' and self.model)): # noqa: W503
return False
return True
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -2,7 +2,7 @@ import pathlib
import tempfile import tempfile
from typing import cast from typing import cast
from .test_main import CmmTestCase from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
@ -215,6 +215,8 @@ class MessageFromFileTxtTestCase(CmmTestCase):
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd: with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 fd.write(f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header} {Question.txt_header}
This is a question. This is a question.
{Answer.txt_header} {Answer.txt_header}
@ -244,6 +246,8 @@ This is a question.
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path) self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_min(self) -> None: def test_from_file_txt_min(self) -> None:
@ -259,7 +263,8 @@ This is a question.
self.assertIsNone(message.answer) self.assertIsNone(message.answer)
def test_from_file_txt_tags_match(self) -> None: def test_from_file_txt_tags_match(self) -> None:
message = Message.from_file(self.file_path, tags_or={Tag('tag1')}) message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug if message: # mypy bug
@ -269,15 +274,18 @@ This is a question.
self.assertEqual(message.file_path, self.file_path) self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_tags_dont_match(self) -> None: def test_from_file_txt_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path, tags_or={Tag('tag3')}) message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_no_tags_dont_match(self) -> None: def test_from_file_txt_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min, tags_or={Tag('tag1')}) message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None: def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, tags_not={Tag('tag1')}) message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug if message: # mypy bug
@ -291,6 +299,77 @@ This is a question.
Message.from_file(file_not_exists) Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_txt_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_txt_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_txt_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class MessageFromFileYamlTestCase(CmmTestCase): class MessageFromFileYamlTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None:
@ -302,6 +381,8 @@ class MessageFromFileYamlTestCase(CmmTestCase):
This is a question. This is a question.
{Answer.yaml_key}: |- {Answer.yaml_key}: |-
This is an answer. This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}: {Message.tags_yaml_key}:
- tag1 - tag1
- tag2 - tag2
@ -331,6 +412,8 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertEqual(message.question, 'This is a question.') self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.') self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path) self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_min(self) -> None: def test_from_file_yaml_min(self) -> None:
@ -353,7 +436,8 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_yaml_tags_match(self) -> None: def test_from_file_yaml_tags_match(self) -> None:
message = Message.from_file(self.file_path, tags_or={Tag('tag1')}) message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug if message: # mypy bug
@ -363,15 +447,18 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertEqual(message.file_path, self.file_path) self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_tags_dont_match(self) -> None: def test_from_file_yaml_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path, tags_or={Tag('tag3')}) message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_yaml_no_tags_dont_match(self) -> None: def test_from_file_yaml_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min, tags_or={Tag('tag1')}) message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message) self.assertIsNone(message)
def test_from_file_yaml_no_tags_match_tags_not(self) -> None: def test_from_file_yaml_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min, tags_not={Tag('tag1')}) message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message) self.assertIsNotNone(message)
self.assertIsInstance(message, Message) self.assertIsInstance(message, Message)
if message: # mypy bug if message: # mypy bug
@ -379,6 +466,77 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertSetEqual(cast(set[Tag], message.tags), set()) self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min) self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_yaml_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_yaml_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_yaml_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class TagsFromFileTestCase(CmmTestCase): class TagsFromFileTestCase(CmmTestCase):
def setUp(self) -> None: def setUp(self) -> None: