Compare commits
5 Commits
451cda6bfa
...
2548ae5a52
| Author | SHA1 | Date | |
|---|---|---|---|
| 2548ae5a52 | |||
| 2dc8e1d6b2 | |||
| b83cbb719b | |||
| 8e1cdee3bf | |||
| 73d2a9ea3b |
@ -5,9 +5,10 @@ import shutil
|
||||
import pathlib
|
||||
from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any
|
||||
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
|
||||
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
|
||||
from .tags import Tag
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
@ -29,6 +30,58 @@ def print_paged(text: str) -> None:
|
||||
pager(text)
|
||||
|
||||
|
||||
def read_dir(dir_path: pathlib.Path,
|
||||
glob: Optional[str] = None,
|
||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||
"""
|
||||
Reads the messages from the given folder.
|
||||
Parameters:
|
||||
* 'dir_path': source directory
|
||||
* 'glob': if specified, files will be filtered using 'path.glob()',
|
||||
otherwise it uses 'path.iterdir()'.
|
||||
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||
when reading them.
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
message = Message.from_file(file_path, mfilter)
|
||||
if message:
|
||||
messages.append(message)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
return messages
|
||||
|
||||
|
||||
def write_dir(dir_path: pathlib.Path,
|
||||
messages: list[Message],
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> None:
|
||||
"""
|
||||
Write all messages to the given directory. If a message has no file_path,
|
||||
a new one will be created. If message.file_path exists, it will be modified
|
||||
to point to the given directory.
|
||||
Parameters:
|
||||
* 'dir_path': destination directory
|
||||
* 'messages': list of messages to write
|
||||
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
|
||||
* 'next_fid': callable that returns the next file ID
|
||||
"""
|
||||
for message in messages:
|
||||
file_path = message.file_path
|
||||
# message has no file_path: create one
|
||||
if not file_path:
|
||||
fid = next_fid()
|
||||
fname = f"{fid:04d}{file_suffix}"
|
||||
file_path = dir_path / fname
|
||||
# file_path does not point to given directory: modify it
|
||||
elif not file_path.parent.samefile(dir_path):
|
||||
file_path = dir_path / file_path.name
|
||||
message.to_file(file_path)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chat:
|
||||
"""
|
||||
@ -67,6 +120,15 @@ class Chat:
|
||||
self.messages += msgs
|
||||
self.sort()
|
||||
|
||||
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Get the tags of all messages, optionally filtered by prefix or substring.
|
||||
"""
|
||||
tags: set[Tag] = set()
|
||||
for m in self.messages:
|
||||
tags |= m.filter_tags(prefix, contain)
|
||||
return tags
|
||||
|
||||
def print(self, dump: bool = False, source_code_only: bool = False,
|
||||
with_tags: bool = False, with_file: bool = False,
|
||||
paged: bool = True) -> None:
|
||||
@ -113,8 +175,10 @@ class ChatDB(Chat):
|
||||
file_suffix: str = default_file_suffix
|
||||
# the glob pattern for all messages
|
||||
glob: Optional[str] = None
|
||||
# set containing all file names of the current messages
|
||||
message_files: set[str] = field(default_factory=set, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# contains the latest message ID
|
||||
self.next_fname = self.db_path / '.next'
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls: Type[ChatDBInst],
|
||||
@ -128,94 +192,78 @@ class ChatDB(Chat):
|
||||
Parameters:
|
||||
* 'cache_path': path to the directory for temporary messages
|
||||
* 'db_path': path to the directory for persistent messages
|
||||
* 'glob' fs specified, files will be filtered using 'path.glob()',
|
||||
* 'glob': if specified, files will be filtered using 'path.glob()',
|
||||
otherwise it uses 'path.iterdir()'.
|
||||
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||
when reading them.
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
message_files: set[str] = set()
|
||||
file_iter = db_path.glob(glob) if glob else db_path.iterdir()
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
message = Message.from_file(file_path, mfilter)
|
||||
if message:
|
||||
messages.append(message)
|
||||
message_files.add(file_path.name)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
messages = read_dir(db_path, glob, mfilter)
|
||||
return cls(messages, cache_path, db_path, mfilter,
|
||||
cls.default_file_suffix, glob, message_files)
|
||||
cls.default_file_suffix, glob)
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls: Type[ChatDBInst],
|
||||
cache_path: pathlib.Path,
|
||||
db_path: pathlib.Path,
|
||||
messages: list[Message],
|
||||
mfilter: Optional[MessageFilter]) -> ChatDBInst:
|
||||
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
||||
"""
|
||||
Create a ChatDB instance from the given message list. Note that the next
|
||||
call to 'dump()' will write all files in order to synchronize the messages.
|
||||
Similarly, 'update()' will read all messages, so you may end up with a lot
|
||||
of duplicates when using 'update()' first.
|
||||
Create a ChatDB instance from the given message list.
|
||||
"""
|
||||
return cls(messages, cache_path, db_path, mfilter)
|
||||
|
||||
def get_next_fid(self) -> int:
|
||||
next_fname = self.db_path / '.next'
|
||||
try:
|
||||
with open(next_fname, 'r') as f:
|
||||
return int(f.read()) + 1
|
||||
with open(self.next_fname, 'r') as f:
|
||||
next_fid = int(f.read()) + 1
|
||||
self.set_next_fid(next_fid)
|
||||
return next_fid
|
||||
except Exception:
|
||||
self.set_next_fid(1)
|
||||
return 1
|
||||
|
||||
def set_next_fid(self, fid: int) -> None:
|
||||
next_fname = self.db_path / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
with open(self.next_fname, 'w') as f:
|
||||
f.write(f'{fid}')
|
||||
|
||||
def dump(self, to_db: bool = False, force_all: bool = False) -> None:
|
||||
def read_db(self) -> None:
|
||||
"""
|
||||
Write all messages to 'cache_path' (or 'db_path' if 'to_db' is True). If a message
|
||||
has no file_path, a new one will be created. By default, only messages that have
|
||||
not been written (or read) before will be dumped. Use 'force_all' to force writing
|
||||
all message files.
|
||||
Reads new messages from the DB directory. New ones are added to the internal list,
|
||||
existing ones are kept or replaced. A message is determined as 'existing' if a
|
||||
message with the same base filename (i. e. 'file_path.name') is already in the list.
|
||||
"""
|
||||
for message in self.messages:
|
||||
# skip messages that we have already written (or read)
|
||||
if message.file_path and message.file_path in self.message_files and not force_all:
|
||||
continue
|
||||
file_path = message.file_path
|
||||
if not file_path:
|
||||
fid = self.get_next_fid()
|
||||
fname = f"{fid:04d}{self.file_suffix}"
|
||||
file_path = self.db_path / fname if to_db else self.cache_path / fname
|
||||
self.set_next_fid(fid)
|
||||
message.to_file(file_path)
|
||||
|
||||
def update(self, from_cache: bool = False, force_all: bool = False) -> None:
|
||||
"""
|
||||
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
|
||||
By default, only messages that have not been read (or written) before will
|
||||
be read. Use 'force_all' to force reading all messages (existing messages
|
||||
are discarded).
|
||||
"""
|
||||
if from_cache:
|
||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||
else:
|
||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||
if force_all:
|
||||
self.messages = []
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file():
|
||||
if file_path.name in self.message_files and not force_all:
|
||||
continue
|
||||
try:
|
||||
message = Message.from_file(file_path, self.mfilter)
|
||||
if message:
|
||||
self.messages.append(message)
|
||||
self.message_files.add(file_path.name)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
|
||||
# remove all messages from self.messages that are in the new list
|
||||
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
|
||||
# copy the messages from the temporary list to self.messages and sort them
|
||||
self.messages += new_messages
|
||||
self.sort()
|
||||
|
||||
def read_cache(self) -> None:
|
||||
"""
|
||||
Reads new messages from the cache directory. New ones are added to the internal list,
|
||||
existing ones are kept or replaced. A message is determined as 'existing' if a
|
||||
message with the same base filename (i. e. 'file_path.name') is already in the list.
|
||||
"""
|
||||
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
|
||||
# remove all messages from self.messages that are in the new list
|
||||
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
|
||||
# copy the messages from the temporary list to self.messages and sort them
|
||||
self.messages += new_messages
|
||||
self.sort()
|
||||
|
||||
def write_db(self) -> None:
|
||||
"""
|
||||
Write all messages to the DB directory. If a message has no file_path,
|
||||
a new one will be created. If message.file_path exists, it will be modified
|
||||
to point to the DB directory.
|
||||
"""
|
||||
write_dir(self.db_path, self.messages, self.file_suffix, self.get_next_fid)
|
||||
|
||||
def write_cache(self) -> None:
|
||||
"""
|
||||
Write all messages to the cache directory. If a message has no file_path,
|
||||
a new one will be created. If message.file_path exists, it will be modified
|
||||
to point to the cache directory.
|
||||
"""
|
||||
write_dir(self.cache_path, self.messages, self.file_suffix, self.get_next_fid)
|
||||
|
||||
@ -3,7 +3,7 @@ Module implementing message related functions and classes.
|
||||
"""
|
||||
import pathlib
|
||||
import yaml
|
||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal
|
||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from .tags import Tag, TagLine, TagError, match_tags
|
||||
|
||||
@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]:
|
||||
return code_sections
|
||||
|
||||
|
||||
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
|
||||
"""
|
||||
Searches the given message list for a message with the same file
|
||||
name as the given one (i. e. it compares Message.file_path.name).
|
||||
If the given message has no file_path, False is returned.
|
||||
"""
|
||||
if not message.file_path:
|
||||
return False
|
||||
for m in messages:
|
||||
if m.file_path and m.file_path.name == message.file_path.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MessageFilter:
|
||||
"""
|
||||
@ -436,13 +450,14 @@ class Message():
|
||||
Filter tags based on their prefix (i. e. the tag starts with a given string)
|
||||
or some contained string.
|
||||
"""
|
||||
res_tags = self.tags
|
||||
if res_tags:
|
||||
if not self.tags:
|
||||
return set()
|
||||
res_tags = self.tags.copy()
|
||||
if prefix and len(prefix) > 0:
|
||||
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
|
||||
if contain and len(contain) > 0:
|
||||
res_tags -= {tag for tag in res_tags if contain not in tag}
|
||||
return res_tags or set()
|
||||
return res_tags
|
||||
|
||||
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
|
||||
"""
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from chatmastermind.tags import TagLine
|
||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||
from chatmastermind.chat import Chat, terminal_width
|
||||
from chatmastermind.chat import Chat, ChatDB, terminal_width
|
||||
from .test_main import CmmTestCase
|
||||
|
||||
|
||||
@ -12,11 +14,11 @@ class TestChat(CmmTestCase):
|
||||
self.chat = Chat([])
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('tag1')},
|
||||
{Tag('atag1')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('tag2')},
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
|
||||
def test_filter(self) -> None:
|
||||
@ -42,11 +44,19 @@ class TestChat(CmmTestCase):
|
||||
|
||||
def test_add_msgs(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
|
||||
self.assertEqual(len(self.chat.messages), 2)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
|
||||
def test_tags(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
tags_all = self.chat.tags()
|
||||
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
|
||||
tags_pref = self.chat.tags(prefix='a')
|
||||
self.assertSetEqual(tags_pref, {Tag('atag1')})
|
||||
tags_cont = self.chat.tags(contain='2')
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
@ -73,14 +83,152 @@ Answer 2
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
{TagLine.prefix} tag1
|
||||
{TagLine.prefix} atag1
|
||||
FILE: 0001.txt
|
||||
{'-'*terminal_width()}
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
Answer 2
|
||||
{TagLine.prefix} tag2
|
||||
{TagLine.prefix} btag2
|
||||
FILE: 0002.txt
|
||||
"""
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
|
||||
class TestChatDB(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.db_path = tempfile.TemporaryDirectory()
|
||||
self.cache_path = tempfile.TemporaryDirectory()
|
||||
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('tag1')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('tag2')},
|
||||
file_path=pathlib.Path('0002.yaml'))
|
||||
self.message3 = Message(Question('Question 3'),
|
||||
Answer('Answer 3'),
|
||||
{Tag('tag3')},
|
||||
file_path=pathlib.Path('0003.txt'))
|
||||
self.message4 = Message(Question('Question 4'),
|
||||
Answer('Answer 4'),
|
||||
{Tag('tag4')},
|
||||
file_path=pathlib.Path('0004.yaml'))
|
||||
|
||||
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.db_path.cleanup()
|
||||
self.cache_path.cleanup()
|
||||
pass
|
||||
|
||||
def test_chat_db_from_dir(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
# check that the files are sorted
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path,
|
||||
pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_chat_db_from_dir_glob(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
glob='*.txt')
|
||||
self.assertEqual(len(chat_db.messages), 2)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
|
||||
def test_chat_db_filter(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(answer_contains='Answer 2'))
|
||||
self.assertEqual(len(chat_db.messages), 1)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||
|
||||
def test_chat_db_from_messges(self) -> None:
|
||||
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
messages=[self.message1, self.message2,
|
||||
self.message3, self.message4])
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
|
||||
def test_chat_db_fids(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.get_next_fid(), 1)
|
||||
self.assertEqual(chat_db.get_next_fid(), 2)
|
||||
self.assertEqual(chat_db.get_next_fid(), 3)
|
||||
with open(chat_db.next_fname, 'r') as f:
|
||||
self.assertEqual(f.read(), '3')
|
||||
|
||||
def test_chat_db_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# check that Message.file_path is correct
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
|
||||
# check that Message.file_path has been correctly updated
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
||||
|
||||
# check the timestamp of the files in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
|
||||
# overwrite the messages in the db directory
|
||||
time.sleep(0.05)
|
||||
chat_db.write_db()
|
||||
# check if the written files are in the DB directory
|
||||
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
||||
# check if all files in the DB dir have actually been overwritten
|
||||
for file in db_dir_files:
|
||||
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
|
||||
# check that Message.file_path has been correctly updated (again)
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
@ -2,7 +2,7 @@ import pathlib
|
||||
import tempfile
|
||||
from typing import cast
|
||||
from .test_main import CmmTestCase
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
||||
from chatmastermind.tags import Tag, TagLine
|
||||
|
||||
|
||||
@ -594,6 +594,12 @@ This is an answer.
|
||||
self.file_path_txt.unlink()
|
||||
self.file_yaml.close()
|
||||
self.file_path_yaml.unlink()
|
||||
self.file_txt_no_tags.close
|
||||
self.file_path_txt_no_tags.unlink()
|
||||
self.file_txt_tags_empty.close
|
||||
self.file_path_txt_tags_empty.unlink()
|
||||
self.file_yaml_no_tags.close()
|
||||
self.file_path_yaml_no_tags.unlink()
|
||||
|
||||
def test_tags_from_file_txt(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt)
|
||||
@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase):
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.temp_dir.cleanup()
|
||||
self.temp_dir_no_tags.cleanup()
|
||||
|
||||
def test_tags_from_dir(self) -> None:
|
||||
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
|
||||
@ -739,3 +746,32 @@ class MessageTagsStrTestCase(CmmTestCase):
|
||||
|
||||
def test_tags_str(self) -> None:
|
||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
||||
|
||||
|
||||
class MessageFilterTagsTestCase(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_filter_tags(self) -> None:
|
||||
tags_all = self.message.filter_tags()
|
||||
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
|
||||
tags_pref = self.message.filter_tags(prefix='a')
|
||||
self.assertSetEqual(tags_pref, {Tag('atag1')})
|
||||
tags_cont = self.message.filter_tags(contain='2')
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
|
||||
class MessageInTestCase(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message1 = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
self.message2 = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/bla/foo'))
|
||||
|
||||
def test_message_in(self) -> None:
|
||||
self.assertTrue(message_in(self.message1, [self.message1]))
|
||||
self.assertFalse(message_in(self.message1, [self.message2]))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user