Compare commits

..

8 Commits

2 changed files with 33 additions and 5 deletions

View File

@ -4,7 +4,7 @@ Module implementing message related functions and classes.
import pathlib import pathlib
import yaml import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags from .tags import Tag, TagLine, TagError, match_tags
QuestionInst = TypeVar('QuestionInst', bound='Question') QuestionInst = TypeVar('QuestionInst', bound='Question')
@ -188,10 +188,11 @@ class Message():
""" """
question: Question question: Question
answer: Optional[Answer] = None answer: Optional[Answer] = None
tags: Optional[set[Tag]] = None # metadata, ignored when comparing messages
ai: Optional[str] = None tags: Optional[set[Tag]] = field(default=None, compare=False)
model: Optional[str] = None ai: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = None model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables # class variables
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
tags_yaml_key: ClassVar[str] = 'tags' tags_yaml_key: ClassVar[str] = 'tags'
@ -199,6 +200,12 @@ class Message():
ai_yaml_key: ClassVar[str] = 'ai' ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model' model_yaml_key: ClassVar[str] = 'model'
def __hash__(self) -> int:
"""
The hash value is computed based on immutable members.
"""
return hash((self.question, self.answer))
@classmethod @classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
""" """

View File

@ -595,3 +595,24 @@ class MessageIDTestCase(CmmTestCase):
def test_msg_id_txt_exception(self) -> None: def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
self.message_no_file_path.msg_id() self.message_no_file_path.msg_id()
class MessageHashTestCase(CmmTestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
file_path=pathlib.Path('/tmp/foo/bla'))
self.message2 = Message(Question('This is a new question.'),
file_path=pathlib.Path('/tmp/foo/bla'))
self.message3 = Message(Question('This is a question.'),
Answer('This is an answer.'),
file_path=pathlib.Path('/tmp/foo/bla'))
# message4 is a copy of message1, because 'file_path'
# is ignored for hashing and comparison
self.message4 = Message(Question('This is a question.'),
file_path=pathlib.Path('foobla'))
def test_set_hashing(self) -> None:
msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4}
self.assertEqual(len(msgs), 3)
for msg in [self.message1, self.message2, self.message3]:
self.assertIn(msg, msgs)