chat: added tags_frequency() function and test
This commit is contained in:
parent
8e0a158ac9
commit
45393e3a15
@ -127,7 +127,16 @@ class Chat:
|
|||||||
tags: set[Tag] = set()
|
tags: set[Tag] = set()
|
||||||
for m in self.messages:
|
for m in self.messages:
|
||||||
tags |= m.filter_tags(prefix, contain)
|
tags |= m.filter_tags(prefix, contain)
|
||||||
return tags
|
return set(sorted(tags))
|
||||||
|
|
||||||
|
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
|
||||||
|
"""
|
||||||
|
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
|
||||||
|
"""
|
||||||
|
tags: list[Tag] = []
|
||||||
|
for m in self.messages:
|
||||||
|
tags += [tag for tag in m.filter_tags(prefix, contain)]
|
||||||
|
return {tag: tags.count(tag) for tag in sorted(tags)}
|
||||||
|
|
||||||
def tokens(self) -> int:
|
def tokens(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -14,7 +14,7 @@ class TestChat(CmmTestCase):
|
|||||||
self.chat = Chat([])
|
self.chat = Chat([])
|
||||||
self.message1 = Message(Question('Question 1'),
|
self.message1 = Message(Question('Question 1'),
|
||||||
Answer('Answer 1'),
|
Answer('Answer 1'),
|
||||||
{Tag('atag1')},
|
{Tag('atag1'), Tag('btag2')},
|
||||||
file_path=pathlib.Path('0001.txt'))
|
file_path=pathlib.Path('0001.txt'))
|
||||||
self.message2 = Message(Question('Question 2'),
|
self.message2 = Message(Question('Question 2'),
|
||||||
Answer('Answer 2'),
|
Answer('Answer 2'),
|
||||||
@ -57,6 +57,11 @@ class TestChat(CmmTestCase):
|
|||||||
tags_cont = self.chat.tags(contain='2')
|
tags_cont = self.chat.tags(contain='2')
|
||||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||||
|
|
||||||
|
def test_tags_frequency(self) -> None:
|
||||||
|
self.chat.add_msgs([self.message1, self.message2])
|
||||||
|
tags_freq = self.chat.tags_frequency()
|
||||||
|
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
||||||
|
|
||||||
@patch('sys.stdout', new_callable=StringIO)
|
@patch('sys.stdout', new_callable=StringIO)
|
||||||
def test_print(self, mock_stdout: StringIO) -> None:
|
def test_print(self, mock_stdout: StringIO) -> None:
|
||||||
self.chat.add_msgs([self.message1, self.message2])
|
self.chat.add_msgs([self.message1, self.message2])
|
||||||
@ -83,7 +88,7 @@ Answer 2
|
|||||||
Question 1
|
Question 1
|
||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
Answer 1
|
Answer 1
|
||||||
{TagLine.prefix} atag1
|
{TagLine.prefix} atag1 btag2
|
||||||
FILE: 0001.txt
|
FILE: 0001.txt
|
||||||
{'-'*terminal_width()}
|
{'-'*terminal_width()}
|
||||||
{Question.txt_header}
|
{Question.txt_header}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user