Compare commits

..

5 Commits

Author SHA1 Message Date
2548ae5a52 added tests for 'chat.py' 2023-08-31 09:42:40 +02:00
2dc8e1d6b2 added new module 'chat.py' 2023-08-31 09:42:40 +02:00
b83cbb719b added 'message_in()' function and test 2023-08-31 09:21:51 +02:00
8e1cdee3bf fixed Message.filter_tags 2023-08-30 08:22:50 +02:00
73d2a9ea3b fixed test case file cleanup 2023-08-29 11:36:01 +02:00
4 changed files with 330 additions and 83 deletions

View File

@ -5,9 +5,10 @@ import shutil
import pathlib import pathlib
from pprint import PrettyPrinter from pprint import PrettyPrinter
from pydoc import pager from pydoc import pager
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
@ -29,6 +30,58 @@ def print_paged(text: str) -> None:
pager(text) pager(text)
def read_dir(dir_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Reads the messages from the given folder.
Parameters:
* 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return messages
def write_dir(dir_path: pathlib.Path,
messages: list[Message],
file_suffix: str,
next_fid: Callable[[], int]) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the given directory.
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
fid = next_fid()
fname = f"{fid:04d}{file_suffix}"
file_path = dir_path / fname
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path)
@dataclass @dataclass
class Chat: class Chat:
""" """
@ -67,6 +120,15 @@ class Chat:
self.messages += msgs self.messages += msgs
self.sort() self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
"""
tags: set[Tag] = set()
for m in self.messages:
tags |= m.filter_tags(prefix, contain)
return tags
def print(self, dump: bool = False, source_code_only: bool = False, def print(self, dump: bool = False, source_code_only: bool = False,
with_tags: bool = False, with_file: bool = False, with_tags: bool = False, with_file: bool = False,
paged: bool = True) -> None: paged: bool = True) -> None:
@ -113,8 +175,10 @@ class ChatDB(Chat):
file_suffix: str = default_file_suffix file_suffix: str = default_file_suffix
# the glob pattern for all messages # the glob pattern for all messages
glob: Optional[str] = None glob: Optional[str] = None
# set containing all file names of the current messages
message_files: set[str] = field(default_factory=set, repr=False) def __post_init__(self) -> None:
# contains the latest message ID
self.next_fname = self.db_path / '.next'
@classmethod @classmethod
def from_dir(cls: Type[ChatDBInst], def from_dir(cls: Type[ChatDBInst],
@ -128,94 +192,78 @@ class ChatDB(Chat):
Parameters: Parameters:
* 'cache_path': path to the directory for temporary messages * 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages * 'db_path': path to the directory for persistent messages
* 'glob' fs specified, files will be filtered using 'path.glob()', * 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'. otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages * 'mfilter': use with 'Message.from_file()' to filter messages
when reading them. when reading them.
""" """
messages: list[Message] = [] messages = read_dir(db_path, glob, mfilter)
message_files: set[str] = set()
file_iter = db_path.glob(glob) if glob else db_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
message_files.add(file_path.name)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return cls(messages, cache_path, db_path, mfilter, return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob, message_files) cls.default_file_suffix, glob)
@classmethod @classmethod
def from_messages(cls: Type[ChatDBInst], def from_messages(cls: Type[ChatDBInst],
cache_path: pathlib.Path, cache_path: pathlib.Path,
db_path: pathlib.Path, db_path: pathlib.Path,
messages: list[Message], messages: list[Message],
mfilter: Optional[MessageFilter]) -> ChatDBInst: mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
""" """
Create a ChatDB instance from the given message list. Note that the next Create a ChatDB instance from the given message list.
call to 'dump()' will write all files in order to synchronize the messages.
Similarly, 'update()' will read all messages, so you may end up with a lot
of duplicates when using 'update()' first.
""" """
return cls(messages, cache_path, db_path, mfilter) return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int: def get_next_fid(self) -> int:
next_fname = self.db_path / '.next'
try: try:
with open(next_fname, 'r') as f: with open(self.next_fname, 'r') as f:
return int(f.read()) + 1 next_fid = int(f.read()) + 1
self.set_next_fid(next_fid)
return next_fid
except Exception: except Exception:
self.set_next_fid(1)
return 1 return 1
def set_next_fid(self, fid: int) -> None: def set_next_fid(self, fid: int) -> None:
next_fname = self.db_path / '.next' with open(self.next_fname, 'w') as f:
with open(next_fname, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def dump(self, to_db: bool = False, force_all: bool = False) -> None: def read_db(self) -> None:
""" """
Write all messages to 'cache_path' (or 'db_path' if 'to_db' is True). If a message Reads new messages from the DB directory. New ones are added to the internal list,
has no file_path, a new one will be created. By default, only messages that have existing ones are kept or replaced. A message is determined as 'existing' if a
not been written (or read) before will be dumped. Use 'force_all' to force writing message with the same base filename (i. e. 'file_path.name') is already in the list.
all message files.
""" """
for message in self.messages: new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# skip messages that we have already written (or read) # remove all messages from self.messages that are in the new list
if message.file_path and message.file_path in self.message_files and not force_all: self.messages = [m for m in self.messages if not message_in(m, new_messages)]
continue # copy the messages from the temporary list to self.messages and sort them
file_path = message.file_path self.messages += new_messages
if not file_path: self.sort()
fid = self.get_next_fid()
fname = f"{fid:04d}{self.file_suffix}"
file_path = self.db_path / fname if to_db else self.cache_path / fname
self.set_next_fid(fid)
message.to_file(file_path)
def update(self, from_cache: bool = False, force_all: bool = False) -> None: def read_cache(self) -> None:
""" """
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true). Reads new messages from the cache directory. New ones are added to the internal list,
By default, only messages that have not been read (or written) before will existing ones are kept or replaced. A message is determined as 'existing' if a
be read. Use 'force_all' to force reading all messages (existing messages message with the same base filename (i. e. 'file_path.name') is already in the list.
are discarded).
""" """
if from_cache: new_messages = read_dir(self.db_path, self.glob, self.mfilter)
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() # remove all messages from self.messages that are in the new list
else: self.messages = [m for m in self.messages if not message_in(m, new_messages)]
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() # copy the messages from the temporary list to self.messages and sort them
if force_all: self.messages += new_messages
self.messages = [] self.sort()
for file_path in sorted(file_iter):
if file_path.is_file(): def write_db(self) -> None:
if file_path.name in self.message_files and not force_all: """
continue Write all messages to the DB directory. If a message has no file_path,
try: a new one will be created. If message.file_path exists, it will be modified
message = Message.from_file(file_path, self.mfilter) to point to the DB directory.
if message: """
self.messages.append(message) write_dir(self.db_path, self.messages, self.file_suffix, self.get_next_fid)
self.message_files.add(file_path.name)
except MessageError as e: def write_cache(self) -> None:
print(f"Error processing message in '{file_path}': {str(e)}") """
self.sort() Write all messages to the cache directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the cache directory.
"""
write_dir(self.cache_path, self.messages, self.file_suffix, self.get_next_fid)

View File

@ -3,7 +3,7 @@ Module implementing message related functions and classes.
""" """
import pathlib import pathlib
import yaml import yaml
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags from .tags import Tag, TagLine, TagError, match_tags
@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
return code_sections return code_sections
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
"""
Searches the given message list for a message with the same file
name as the given one (i. e. it compares Message.file_path.name).
If the given message has no file_path, False is returned.
"""
if not message.file_path:
return False
for m in messages:
if m.file_path and m.file_path.name == message.file_path.name:
return True
return False
@dataclass(kw_only=True) @dataclass(kw_only=True)
class MessageFilter: class MessageFilter:
""" """
@ -436,13 +450,14 @@ class Message():
Filter tags based on their prefix (i. e. the tag starts with a given string) Filter tags based on their prefix (i. e. the tag starts with a given string)
or some contained string. or some contained string.
""" """
res_tags = self.tags if not self.tags:
if res_tags: return set()
if prefix and len(prefix) > 0: res_tags = self.tags.copy()
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} if prefix and len(prefix) > 0:
if contain and len(contain) > 0: res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
res_tags -= {tag for tag in res_tags if contain not in tag} if contain and len(contain) > 0:
return res_tags or set() res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
""" """

View File

@ -1,9 +1,11 @@
import pathlib import pathlib
import tempfile
import time
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, terminal_width from chatmastermind.chat import Chat, ChatDB, terminal_width
from .test_main import CmmTestCase from .test_main import CmmTestCase
@ -12,11 +14,11 @@ class TestChat(CmmTestCase):
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('tag1')}, {Tag('atag1')},
file_path=pathlib.Path('0001.txt')) file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('tag2')}, {Tag('btag2')},
file_path=pathlib.Path('0002.txt')) file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None: def test_filter(self) -> None:
@ -42,11 +44,19 @@ class TestChat(CmmTestCase):
def test_add_msgs(self) -> None: def test_add_msgs(self) -> None:
self.chat.add_msgs([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
@ -73,14 +83,152 @@ Answer 2
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{TagLine.prefix} tag1 {TagLine.prefix} atag1
FILE: 0001.txt FILE: 0001.txt
{'-'*terminal_width()} {'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}
Question 2 Question 2
{Answer.txt_header} {Answer.txt_header}
Answer 2 Answer 2
{TagLine.prefix} tag2 {TagLine.prefix} btag2
FILE: 0002.txt FILE: 0002.txt
""" """
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(CmmTestCase):
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
def tearDown(self) -> None:
self.db_path.cleanup()
self.cache_path.cleanup()
pass
def test_chat_db_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_filter(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messges(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
self.message3, self.message4])
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 1)
self.assertEqual(chat_db.get_next_fid(), 2)
self.assertEqual(chat_db.get_next_fid(), 3)
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '3')
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))

View File

@ -2,7 +2,7 @@ import pathlib
import tempfile import tempfile
from typing import cast from typing import cast
from .test_main import CmmTestCase from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
@ -594,6 +594,12 @@ This is an answer.
self.file_path_txt.unlink() self.file_path_txt.unlink()
self.file_yaml.close() self.file_yaml.close()
self.file_path_yaml.unlink() self.file_path_yaml.unlink()
self.file_txt_no_tags.close
self.file_path_txt_no_tags.unlink()
self.file_txt_tags_empty.close
self.file_path_txt_tags_empty.unlink()
self.file_yaml_no_tags.close()
self.file_path_yaml_no_tags.unlink()
def test_tags_from_file_txt(self) -> None: def test_tags_from_file_txt(self) -> None:
tags = Message.tags_from_file(self.file_path_txt) tags = Message.tags_from_file(self.file_path_txt)
@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase):
def tearDown(self) -> None: def tearDown(self) -> None:
self.temp_dir.cleanup() self.temp_dir.cleanup()
self.temp_dir_no_tags.cleanup()
def test_tags_from_dir(self) -> None: def test_tags_from_dir(self) -> None:
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
@ -739,3 +746,32 @@ class MessageTagsStrTestCase(CmmTestCase):
def test_tags_str(self) -> None: def test_tags_str(self) -> None:
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(CmmTestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_filter_tags(self) -> None:
tags_all = self.message.filter_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.message.filter_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.message.filter_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(CmmTestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
self.message2 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/bla/foo'))
def test_message_in(self) -> None:
self.assertTrue(message_in(self.message1, [self.message1]))
self.assertFalse(message_in(self.message1, [self.message2]))