Compare commits

...

4 Commits

6 changed files with 508 additions and 14 deletions

221
chatmastermind/chat.py Normal file
View File

@ -0,0 +1,221 @@
"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
import pathlib
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass, field
from typing import TypeVar, Type, Optional, ClassVar, Any
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
class ChatError(Exception):
pass
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_paged(text: str) -> None:
pager(text)
@dataclass
class Chat:
"""
A class containing a complete chat history.
"""
messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
try:
# the message may not have an ID if it doesn't have a file_path
self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
except MessageError:
pass
def clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def add_msgs(self, msgs: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += msgs
self.sort()
def print(self, dump: bool = False, source_code_only: bool = False,
with_tags: bool = False, with_file: bool = False,
paged: bool = True) -> None:
if dump:
pp(self)
return
output: list[str] = []
for message in self.messages:
if source_code_only:
output.extend(source_code(message.question, include_delims=True))
continue
output.append('-' * terminal_width())
output.append(Question.txt_header)
output.append(message.question)
if message.answer:
output.append(Answer.txt_header)
output.append(message.answer)
if with_tags:
output.append(message.tags_str())
if with_file:
output.append('FILE: ' + str(message.file_path))
if paged:
print_paged('\n'.join(output))
else:
print(*output, sep='\n')
@dataclass
class ChatDB(Chat):
"""
A 'Chat' class that is bound to a given directory structure. Supports reading
and writing messages from / to that structure. Such a structure consists of
two directories: a 'cache directory', where all messages are temporarily
stored, and a 'DB' directory, where selected messages can be stored
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
# set containing all file names of the current messages
message_files: set[str] = field(default_factory=set, repr=False)
@classmethod
def from_dir(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob' fs specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
message_files: set[str] = set()
file_iter = db_path.glob(glob) if glob else db_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
message_files.add(file_path.name)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob, message_files)
@classmethod
def from_messages(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
messages: list[Message],
mfilter: Optional[MessageFilter]) -> ChatDBInst:
"""
Create a ChatDB instance from the given message list. Note that the next
call to 'dump()' will write all files in order to synchronize the messages.
Similarly, 'update()' will read all messages, so you may end up with a lot
of duplicates when using 'update()' first.
"""
return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int:
next_fname = self.db_path / '.next'
try:
with open(next_fname, 'r') as f:
return int(f.read()) + 1
except Exception:
return 1
def set_next_fid(self, fid: int) -> None:
next_fname = self.db_path / '.next'
with open(next_fname, 'w') as f:
f.write(f'{fid}')
def dump(self, to_db: bool = False, force_all: bool = False) -> None:
"""
Write all messages to 'cache_path' (or 'db_path' if 'to_db' is True). If a message
has no file_path, a new one will be created. By default, only messages that have
not been written (or read) before will be dumped. Use 'force_all' to force writing
all message files.
"""
for message in self.messages:
# skip messages that we have already written (or read)
if message.file_path and message.file_path in self.message_files and not force_all:
continue
file_path = message.file_path
if not file_path:
fid = self.get_next_fid()
fname = f"{fid:04d}{self.file_suffix}"
file_path = self.db_path / fname if to_db else self.cache_path / fname
self.set_next_fid(fid)
message.to_file(file_path)
def update(self, from_cache: bool = False, force_all: bool = False) -> None:
"""
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
By default, only messages that have not been read (or written) before will
be read. Use 'force_all' to force reading all messages (existing messages
are discarded).
"""
if from_cache:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
else:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
if force_all:
self.messages = []
for file_path in sorted(file_iter):
if file_path.is_file():
if file_path.name in self.message_files and not force_all:
continue
try:
message = Message.from_file(file_path, self.mfilter)
if message:
self.messages.append(message)
self.message_files.add(file_path.name)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
self.sort()

View File

@ -219,21 +219,57 @@ class Message():
file_path=data.get(cls.file_yaml_key, None)) file_path=data.get(cls.file_yaml_key, None))
@classmethod @classmethod
def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: def tags_from_file(cls: Type[MessageInst],
file_path: pathlib.Path,
prefix: Optional[str] = None,
contain: Optional[str] = None) -> set[Tag]:
""" """
Return only the tags from the given Message file. Return only the tags from the given Message file,
optionally filtered based on prefix or contained string.
""" """
tags: set[Tag] = set()
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes: if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt': if file_path.suffix == '.txt':
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags() try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml' else: # '.yaml'
with open(file_path, "r") as fd: try:
data = yaml.load(fd, Loader=yaml.FullLoader) message = cls.from_file(file_path)
tags = set(sorted(data[cls.tags_yaml_key])) if message:
msg_tags = message.filter_tags(prefix=prefix, contain=contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
if msg_tags:
tags = msg_tags
return tags
@classmethod
def tags_from_dir(cls: Type[MessageInst],
path: pathlib.Path,
glob: Optional[str] = None,
prefix: Optional[str] = None,
contain: Optional[str] = None) -> set[Tag]:
"""
Return only the tags from message files in the given directory.
The files can be filtered using 'glob', the tags by using 'prefix'
and 'contain'.
"""
tags: set[Tag] = set()
file_iter = path.glob(glob) if glob else path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
tags |= cls.tags_from_file(file_path, prefix, contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return tags return tags
@classmethod @classmethod
@ -395,6 +431,29 @@ class Message():
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, fd, sort_keys=False) yaml.dump(data, fd, sort_keys=False)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Filter tags based on their prefix (i. e. the tag starts with a given string)
or some contained string.
"""
res_tags = self.tags
if res_tags:
if prefix and len(prefix) > 0:
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
if contain and len(contain) > 0:
res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags or set()
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
"""
Returns all tags as a string with the TagLine prefix. Optionally filtered
using 'Message.filter_tags()'.
"""
if self.tags:
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
else:
return str(TagLine.from_set(set()))
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
""" """
Matches the current Message to the given filter atttributes. Matches the current Message to the given filter atttributes.

View File

@ -118,11 +118,14 @@ class TagLine(str):
""" """
return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) return cls(' '.join([cls.prefix] + sorted([t for t in tags])))
def tags(self) -> set[Tag]: def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Returns all tags contained in this line as a set. Returns all tags contained in this line as a set, optionally
filtered based on prefix or contained string.
""" """
tagstr = self[len(self.prefix):].strip() tagstr = self[len(self.prefix):].strip()
if tagstr == '':
return set() # no tags, only prefix
separator = Tag.default_separator separator = Tag.default_separator
# look for alternative separators and use the first one found # look for alternative separators and use the first one found
# -> we don't support different separators in the same TagLine # -> we don't support different separators in the same TagLine
@ -130,7 +133,12 @@ class TagLine(str):
if s in tagstr: if s in tagstr:
separator = s separator = s
break break
return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)]))
if prefix and len(prefix) > 0:
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
if contain and len(contain) > 0:
res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags or set()
def merge(self, taglines: set['TagLine']) -> 'TagLine': def merge(self, taglines: set['TagLine']) -> 'TagLine':
""" """

64
tests/test_chat.py Normal file
View File

@ -0,0 +1,64 @@
import pathlib
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, terminal_width
from .test_main import CmmTestCase
class TestChat(CmmTestCase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_msgs([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_msgs([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_msgs(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\nAnswer 2\n"
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.print(paged=False, with_tags=True, with_file=True)
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
\n{TagLine.prefix} tag1\nFILE: 0001.txt\n{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\n\
Answer 2\n{TagLine.prefix} tag2\nFILE: 0002.txt\n"
self.assertEqual(mock_stdout.getvalue(), expected_output)

View File

@ -543,7 +543,24 @@ class TagsFromFileTestCase(CmmTestCase):
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd: with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS:
{Question.txt_header} {Question.txt_header}
This is a question. This is a question.
{Answer.txt_header} {Answer.txt_header}
@ -560,6 +577,16 @@ This is an answer.
{Message.tags_yaml_key}: {Message.tags_yaml_key}:
- tag1 - tag1
- tag2 - tag2
- ptag3
""")
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
""") """)
def tearDown(self) -> None: def tearDown(self) -> None:
@ -570,11 +597,94 @@ This is an answer.
def test_tags_from_file_txt(self) -> None: def test_tags_from_file_txt(self) -> None:
tags = Message.tags_from_file(self.file_path_txt) tags = Message.tags_from_file(self.file_path_txt)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_txt_no_tags(self) -> None:
tags = Message.tags_from_file(self.file_path_txt_no_tags)
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_tags_empty(self) -> None:
tags = Message.tags_from_file(self.file_path_txt_tags_empty)
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml(self) -> None: def test_tags_from_file_yaml(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml) tags = Message.tags_from_file(self.file_path_yaml)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_yaml_no_tags(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml_no_tags)
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_txt, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_txt, prefix='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_yaml, prefix='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_contain(self) -> None:
tags = Message.tags_from_file(self.file_path_txt, contain='3')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_txt, contain='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml_contain(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml, contain='3')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_yaml, contain='R')
self.assertSetEqual(tags, set())
class TagsFromDirTestCase(CmmTestCase):
def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
self.tag_sets = [
{Tag('atag1'), Tag('atag2')},
{Tag('btag3'), Tag('btag4')},
{Tag('ctag5'), Tag('ctag6')}
]
self.files = [
pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, 'file3.txt')
]
self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
]
for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags)
message.to_file(file)
for file in self.files_no_tags:
message = Message(Question('This is a question.'),
Answer('This is an answer.'))
message.to_file(file)
def tearDown(self) -> None:
self.temp_dir.cleanup()
def test_tags_from_dir(self) -> None:
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2]
self.assertEqual(all_tags, expected_tags)
def test_tags_from_dir_prefix(self) -> None:
atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a')
expected_tags = self.tag_sets[0]
self.assertEqual(atags, expected_tags)
def test_tags_from_dir_no_tags(self) -> None:
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name))
self.assertSetEqual(all_tags, set())
class MessageIDTestCase(CmmTestCase): class MessageIDTestCase(CmmTestCase):
@ -619,3 +729,13 @@ class MessageHashTestCase(CmmTestCase):
self.assertEqual(len(msgs), 3) self.assertEqual(len(msgs), 3)
for msg in [self.message1, self.message2, self.message3]: for msg in [self.message1, self.message2, self.message3]:
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(CmmTestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('tag1')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_tags_str(self) -> None:
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')

View File

@ -40,15 +40,37 @@ class TestTagLine(CmmTestCase):
self.assertEqual(tagline, 'TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_tags(self) -> None: def test_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2') tagline = TagLine('TAGS: atag1 btag2')
tags = tagline.tags() tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) self.assertEqual(tags, {Tag('atag1'), Tag('btag2')})
def test_tags_empty(self) -> None:
tagline = TagLine('TAGS:')
self.assertSetEqual(tagline.tags(), set())
def test_tags_with_newline(self) -> None: def test_tags_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2') tagline = TagLine('TAGS: tag1\n tag2')
tags = tagline.tags() tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_tags_prefix(self) -> None:
tagline = TagLine('TAGS: atag1 stag2 stag3')
tags = tagline.tags(prefix='a')
self.assertSetEqual(tags, {Tag('atag1')})
tags = tagline.tags(prefix='s')
self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')})
tags = tagline.tags(prefix='R')
self.assertSetEqual(tags, set())
def test_tags_contain(self) -> None:
tagline = TagLine('TAGS: atag1 stag2 stag3')
tags = tagline.tags(contain='t')
self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')})
tags = tagline.tags(contain='1')
self.assertSetEqual(tags, {Tag('atag1')})
tags = tagline.tags(contain='R')
self.assertSetEqual(tags, set())
def test_merge(self) -> None: def test_merge(self) -> None:
tagline1 = TagLine('TAGS: tag1 tag2') tagline1 = TagLine('TAGS: tag1 tag2')
tagline2 = TagLine('TAGS: tag2 tag3') tagline2 = TagLine('TAGS: tag2 tag3')