Compare commits
8 Commits
e66a20a473
...
a47094f1d9
| Author | SHA1 | Date | |
|---|---|---|---|
| a47094f1d9 | |||
| 2ca4e45f4a | |||
| 2b1d4f248b | |||
| 7e3cd13304 | |||
| ec0371c492 | |||
| ce90ba2cbb | |||
| d089300862 | |||
| 173a46a9b5 |
@ -4,7 +4,7 @@ Module implementing message related functions and classes.
|
|||||||
import pathlib
|
import pathlib
|
||||||
import yaml
|
import yaml
|
||||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
|
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict, field
|
||||||
from .tags import Tag, TagLine, TagError, match_tags
|
from .tags import Tag, TagLine, TagError, match_tags
|
||||||
|
|
||||||
QuestionInst = TypeVar('QuestionInst', bound='Question')
|
QuestionInst = TypeVar('QuestionInst', bound='Question')
|
||||||
@ -188,10 +188,11 @@ class Message():
|
|||||||
"""
|
"""
|
||||||
question: Question
|
question: Question
|
||||||
answer: Optional[Answer] = None
|
answer: Optional[Answer] = None
|
||||||
tags: Optional[set[Tag]] = None
|
# metadata, ignored when comparing messages
|
||||||
ai: Optional[str] = None
|
tags: Optional[set[Tag]] = field(default=None, compare=False)
|
||||||
model: Optional[str] = None
|
ai: Optional[str] = field(default=None, compare=False)
|
||||||
file_path: Optional[pathlib.Path] = None
|
model: Optional[str] = field(default=None, compare=False)
|
||||||
|
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
|
||||||
# class variables
|
# class variables
|
||||||
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
|
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
|
||||||
tags_yaml_key: ClassVar[str] = 'tags'
|
tags_yaml_key: ClassVar[str] = 'tags'
|
||||||
@ -199,6 +200,12 @@ class Message():
|
|||||||
ai_yaml_key: ClassVar[str] = 'ai'
|
ai_yaml_key: ClassVar[str] = 'ai'
|
||||||
model_yaml_key: ClassVar[str] = 'model'
|
model_yaml_key: ClassVar[str] = 'model'
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""
|
||||||
|
The hash value is computed based on immutable members.
|
||||||
|
"""
|
||||||
|
return hash((self.question, self.answer))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
|
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -595,3 +595,24 @@ class MessageIDTestCase(CmmTestCase):
|
|||||||
def test_msg_id_txt_exception(self) -> None:
|
def test_msg_id_txt_exception(self) -> None:
|
||||||
with self.assertRaises(MessageError):
|
with self.assertRaises(MessageError):
|
||||||
self.message_no_file_path.msg_id()
|
self.message_no_file_path.msg_id()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHashTestCase(CmmTestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.message1 = Message(Question('This is a question.'),
|
||||||
|
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||||
|
self.message2 = Message(Question('This is a new question.'),
|
||||||
|
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||||
|
self.message3 = Message(Question('This is a question.'),
|
||||||
|
Answer('This is an answer.'),
|
||||||
|
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||||
|
# message4 is a copy of message1, because 'file_path'
|
||||||
|
# is ignored for hashing and comparison
|
||||||
|
self.message4 = Message(Question('This is a question.'),
|
||||||
|
file_path=pathlib.Path('foobla'))
|
||||||
|
|
||||||
|
def test_set_hashing(self) -> None:
|
||||||
|
msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4}
|
||||||
|
self.assertEqual(len(msgs), 3)
|
||||||
|
for msg in [self.message1, self.message2, self.message3]:
|
||||||
|
self.assertIn(msg, msgs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user