Compare commits
4 Commits
5ebb9f3295
...
4e4e6f56b0
| Author | SHA1 | Date | |
|---|---|---|---|
| 4e4e6f56b0 | |||
| a7fa316487 | |||
| 169f1bb458 | |||
| 7f91a2b567 |
@ -2,11 +2,12 @@
|
||||
Module implementing various chat classes and functions for managing a chat history.
|
||||
"""
|
||||
import shutil
|
||||
from pprint import PrettyPrinter
|
||||
import pathlib
|
||||
from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any
|
||||
from .message import Message, MessageFilter, MessageError
|
||||
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
@ -24,6 +25,10 @@ def pp(*args: Any, **kwargs: Any) -> None:
|
||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def print_paged(text: str) -> None:
|
||||
pager(text)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chat:
|
||||
"""
|
||||
@ -62,22 +67,31 @@ class Chat:
|
||||
self.messages += msgs
|
||||
self.sort()
|
||||
|
||||
def print(self, dump: bool = False) -> None:
|
||||
def print(self, dump: bool = False, source_code_only: bool = False,
|
||||
with_tags: bool = False, with_file: bool = False,
|
||||
paged: bool = True) -> None:
|
||||
if dump:
|
||||
pp(self)
|
||||
return
|
||||
# for message in self.messages:
|
||||
# text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
|
||||
# if source_code:
|
||||
# display_source_code(message['content'])
|
||||
# continue
|
||||
# if message['role'] == 'user':
|
||||
# print('-' * terminal_width())
|
||||
# if text_too_long:
|
||||
# print(f"{message['role'].upper()}:")
|
||||
# print(message['content'])
|
||||
# else:
|
||||
# print(f"{message['role'].upper()}: {message['content']}")
|
||||
output: list[str] = []
|
||||
for message in self.messages:
|
||||
if source_code_only:
|
||||
output.extend(source_code(message.question, include_delims=True))
|
||||
continue
|
||||
output.append('-' * terminal_width())
|
||||
output.append(Question.txt_header)
|
||||
output.append(message.question)
|
||||
if message.answer:
|
||||
output.append(Answer.txt_header)
|
||||
output.append(message.answer)
|
||||
if with_tags:
|
||||
output.append(message.tags_str())
|
||||
if with_file:
|
||||
output.append('FILE: ' + str(message.file_path))
|
||||
if paged:
|
||||
print_paged('\n'.join(output))
|
||||
else:
|
||||
print(*output, sep='\n')
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -184,12 +198,15 @@ class ChatDB(Chat):
|
||||
"""
|
||||
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
|
||||
By default, only messages that have not been read (or written) before will
|
||||
be read. Use 'force_all' to force reading all messages.
|
||||
be read. Use 'force_all' to force reading all messages (existing messages
|
||||
are discarded).
|
||||
"""
|
||||
if from_cache:
|
||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||
else:
|
||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||
if force_all:
|
||||
self.messages = []
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file():
|
||||
if file_path.name in self.message_files and not force_all:
|
||||
|
||||
@ -444,6 +444,16 @@ class Message():
|
||||
res_tags -= {tag for tag in res_tags if contain not in tag}
|
||||
return res_tags or set()
|
||||
|
||||
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns all tags as a string with the TagLine prefix. Optionally filtered
|
||||
using 'Message.filter_tags()'.
|
||||
"""
|
||||
if self.tags:
|
||||
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
|
||||
else:
|
||||
return str(TagLine.from_set(set()))
|
||||
|
||||
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
|
||||
"""
|
||||
Matches the current Message to the given filter atttributes.
|
||||
|
||||
64
tests/test_chat.py
Normal file
64
tests/test_chat.py
Normal file
@ -0,0 +1,64 @@
|
||||
import pathlib
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from chatmastermind.tags import TagLine
|
||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||
from chatmastermind.chat import Chat, terminal_width
|
||||
from .test_main import CmmTestCase
|
||||
|
||||
|
||||
class TestChat(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.chat = Chat([])
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('tag1')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('tag2')},
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
|
||||
def test_filter(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
|
||||
|
||||
self.assertEqual(len(self.chat.messages), 1)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
|
||||
def test_sort(self) -> None:
|
||||
self.chat.add_msgs([self.message2, self.message1])
|
||||
self.chat.sort()
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
self.chat.sort(reverse=True)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 2')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 1')
|
||||
|
||||
def test_clear(self) -> None:
|
||||
self.chat.add_msgs([self.message1])
|
||||
self.chat.clear()
|
||||
self.assertEqual(len(self.chat.messages), 0)
|
||||
|
||||
def test_add_msgs(self) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
|
||||
self.assertEqual(len(self.chat.messages), 2)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_msgs([self.message1, self.message2])
|
||||
self.chat.print(paged=False)
|
||||
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
|
||||
{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\nAnswer 2\n"
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.print(paged=False, with_tags=True, with_file=True)
|
||||
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
|
||||
\n{TagLine.prefix} tag1\nFILE: 0001.txt\n{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\n\
|
||||
Answer 2\n{TagLine.prefix} tag2\nFILE: 0002.txt\n"
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
@ -729,3 +729,13 @@ class MessageHashTestCase(CmmTestCase):
|
||||
self.assertEqual(len(msgs), 3)
|
||||
for msg in [self.message1, self.message2, self.message3]:
|
||||
self.assertIn(msg, msgs)
|
||||
|
||||
|
||||
class MessageTagsStrTestCase(CmmTestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_tags_str(self) -> None:
|
||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user