Compare commits

...

3 Commits

Author SHA1 Message Date
60583a27b2 added tests for 'chat.py' 2023-08-31 15:47:58 +02:00
d438ba86c6 added new module 'chat.py' 2023-08-31 15:47:58 +02:00
214a6919db tags: some clarification and new tests 2023-08-31 15:47:58 +02:00
4 changed files with 588 additions and 1 deletions

272
chatmastermind/chat.py Normal file
View File

@ -0,0 +1,272 @@
"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
import pathlib
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
class ChatError(Exception):
pass
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_paged(text: str) -> None:
pager(text)
def read_dir(dir_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Reads the messages from the given folder.
Parameters:
* 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return messages
def write_dir(dir_path: pathlib.Path,
messages: list[Message],
file_suffix: str,
next_fid: Callable[[], int]) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the given directory.
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
fid = next_fid()
fname = f"{fid:04d}{file_suffix}"
file_path = dir_path / fname
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path)
@dataclass
class Chat:
"""
A class containing a complete chat history.
"""
messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
try:
# the message may not have an ID if it doesn't have a file_path
self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
except MessageError:
pass
def clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def add_msgs(self, msgs: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += msgs
self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
"""
tags: set[Tag] = set()
for m in self.messages:
tags |= m.filter_tags(prefix, contain)
return tags
def print(self, dump: bool = False, source_code_only: bool = False,
with_tags: bool = False, with_file: bool = False,
paged: bool = True) -> None:
if dump:
pp(self)
return
output: list[str] = []
for message in self.messages:
if source_code_only:
output.extend(source_code(message.question, include_delims=True))
continue
output.append('-' * terminal_width())
output.append(Question.txt_header)
output.append(message.question)
if message.answer:
output.append(Answer.txt_header)
output.append(message.answer)
if with_tags:
output.append(message.tags_str())
if with_file:
output.append('FILE: ' + str(message.file_path))
if paged:
print_paged('\n'.join(output))
else:
print(*output, sep='\n')
@dataclass
class ChatDB(Chat):
"""
A 'Chat' class that is bound to a given directory structure. Supports reading
and writing messages from / to that structure. Such a structure consists of
two directories: a 'cache directory', where all messages are temporarily
stored, and a 'DB' directory, where selected messages can be stored
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
def __post_init__(self) -> None:
# contains the latest message ID
self.next_fname = self.db_path / '.next'
# make all paths absolute
self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute()
@classmethod
def from_dir(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
@classmethod
def from_messages(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
messages: list[Message],
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a ChatDB instance from the given message list.
"""
return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int:
try:
with open(self.next_fname, 'r') as f:
next_fid = int(f.read()) + 1
self.set_next_fid(next_fid)
return next_fid
except Exception:
self.set_next_fid(1)
return 1
def set_next_fid(self, fid: int) -> None:
with open(self.next_fname, 'w') as f:
f.write(f'{fid}')
def read_db(self) -> None:
"""
Reads new messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def read_cache(self) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self) -> None:
"""
Write all messages to the DB directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the DB directory.
"""
write_dir(self.db_path, self.messages, self.file_suffix, self.get_next_fid)
def write_cache(self) -> None:
"""
Write all messages to the cache directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the cache directory.
"""
write_dir(self.cache_path, self.messages, self.file_suffix, self.get_next_fid)

View File

@ -77,7 +77,8 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s
i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
or all of the tags in 'tags_and' but it must never contain any of the tags in
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
exclusion is still done if 'tags_not' is not 'None').
exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()),
they match no tags.
"""
required_tags_present = False
excluded_tags_missing = False

297
tests/test_chat.py Normal file
View File

@ -0,0 +1,297 @@
import pathlib
import tempfile
import time
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width
from .test_main import CmmTestCase
class TestChat(CmmTestCase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_msgs([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_msgs([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_msgs(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_file=True)
expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} atag1
FILE: 0001.txt
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{TagLine.prefix} btag2
FILE: 0002.txt
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(CmmTestCase):
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
def tearDown(self) -> None:
self.db_path.cleanup()
self.cache_path.cleanup()
pass
def test_chat_db_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_filter(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messges(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
self.message3, self.message4])
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 1)
self.assertEqual(chat_db.get_next_fid(), 2)
self.assertEqual(chat_db.get_next_fid(), 3)
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '3')
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_read(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
# create 2 new files in the DB directory
new_message1 = Message(Question('Question 5'),
Answer('Answer 5'),
{Tag('tag5')})
new_message2 = Message(Question('Question 6'),
Answer('Answer 6'),
{Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'),
Answer('Answer 5'),
{Tag('tag7')})
new_message4 = Message(Question('Question 8'),
Answer('Answer 6'),
{Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.read_cache()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
# an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))

View File

@ -144,3 +144,20 @@ class TestTagLine(CmmTestCase):
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not))
# Test case 10: 'tags_or' and 'tags_and' are empty, match no tags
self.assertFalse(tagline.match_tags(set(), set(), None))
# Test case 11: 'tags_or' is empty, match no tags
self.assertFalse(tagline.match_tags(set(), None, None))
# Test case 12: 'tags_and' is empty, match no tags
self.assertFalse(tagline.match_tags(None, set(), None))
# Test case 13: 'tags_or' is empty, match 'tags_and'
tags_and = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(None, tags_and, None))
# Test case 14: 'tags_and' is empty, match 'tags_or'
tags_or = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(tags_or, None, None))