Compare commits

..

4 Commits

4 changed files with 117 additions and 16 deletions

View File

@ -2,11 +2,12 @@
Module implementing various chat classes and functions for managing a chat history. Module implementing various chat classes and functions for managing a chat history.
""" """
import shutil import shutil
from pprint import PrettyPrinter
import pathlib import pathlib
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TypeVar, Type, Optional, ClassVar, Any from typing import TypeVar, Type, Optional, ClassVar, Any
from .message import Message, MessageFilter, MessageError from .message import Question, Answer, Message, MessageFilter, MessageError, source_code
ChatInst = TypeVar('ChatInst', bound='Chat') ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
@ -24,6 +25,10 @@ def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_paged(text: str) -> None:
pager(text)
@dataclass @dataclass
class Chat: class Chat:
""" """
@ -62,22 +67,31 @@ class Chat:
self.messages += msgs self.messages += msgs
self.sort() self.sort()
def print(self, dump: bool = False) -> None: def print(self, dump: bool = False, source_code_only: bool = False,
with_tags: bool = False, with_file: bool = False,
paged: bool = True) -> None:
if dump: if dump:
pp(self) pp(self)
return return
# for message in self.messages: output: list[str] = []
# text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 for message in self.messages:
# if source_code: if source_code_only:
# display_source_code(message['content']) output.extend(source_code(message.question, include_delims=True))
# continue continue
# if message['role'] == 'user': output.append('-' * terminal_width())
# print('-' * terminal_width()) output.append(Question.txt_header)
# if text_too_long: output.append(message.question)
# print(f"{message['role'].upper()}:") if message.answer:
# print(message['content']) output.append(Answer.txt_header)
# else: output.append(message.answer)
# print(f"{message['role'].upper()}: {message['content']}") if with_tags:
output.append(message.tags_str())
if with_file:
output.append('FILE: ' + str(message.file_path))
if paged:
print_paged('\n'.join(output))
else:
print(*output, sep='\n')
@dataclass @dataclass
@ -184,12 +198,15 @@ class ChatDB(Chat):
""" """
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true). Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
By default, only messages that have not been read (or written) before will By default, only messages that have not been read (or written) before will
be read. Use 'force_all' to force reading all messages. be read. Use 'force_all' to force reading all messages (existing messages
are discarded).
""" """
if from_cache: if from_cache:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
else: else:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir() file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
if force_all:
self.messages = []
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if file_path.is_file(): if file_path.is_file():
if file_path.name in self.message_files and not force_all: if file_path.name in self.message_files and not force_all:

View File

@ -444,6 +444,16 @@ class Message():
res_tags -= {tag for tag in res_tags if contain not in tag} res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags or set() return res_tags or set()
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
"""
Returns all tags as a string with the TagLine prefix. Optionally filtered
using 'Message.filter_tags()'.
"""
if self.tags:
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
else:
return str(TagLine.from_set(set()))
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
""" """
Matches the current Message to the given filter atttributes. Matches the current Message to the given filter atttributes.

64
tests/test_chat.py Normal file
View File

@ -0,0 +1,64 @@
import pathlib
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, terminal_width
from .test_main import CmmTestCase
class TestChat(CmmTestCase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_msgs([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_msgs([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_msgs(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\nAnswer 2\n"
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.print(paged=False, with_tags=True, with_file=True)
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
\n{TagLine.prefix} tag1\nFILE: 0001.txt\n{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\n\
Answer 2\n{TagLine.prefix} tag2\nFILE: 0002.txt\n"
self.assertEqual(mock_stdout.getvalue(), expected_output)

View File

@ -729,3 +729,13 @@ class MessageHashTestCase(CmmTestCase):
self.assertEqual(len(msgs), 3) self.assertEqual(len(msgs), 3)
for msg in [self.message1, self.message2, self.message3]: for msg in [self.message1, self.message2, self.message3]:
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(CmmTestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('tag1')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_tags_str(self) -> None:
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')