Compare commits
3 Commits
4e4e6f56b0
...
5ebb9f3295
| Author | SHA1 | Date | |
|---|---|---|---|
| 5ebb9f3295 | |||
| 9682463af1 | |||
| c7b99fe9b4 |
@ -2,12 +2,11 @@
|
|||||||
Module implementing various chat classes and functions for managing a chat history.
|
Module implementing various chat classes and functions for managing a chat history.
|
||||||
"""
|
"""
|
||||||
import shutil
|
import shutil
|
||||||
import pathlib
|
|
||||||
from pprint import PrettyPrinter
|
from pprint import PrettyPrinter
|
||||||
from pydoc import pager
|
import pathlib
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypeVar, Type, Optional, ClassVar, Any
|
from typing import TypeVar, Type, Optional, ClassVar, Any
|
||||||
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code
|
from .message import Message, MessageFilter, MessageError
|
||||||
|
|
||||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||||
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||||
@ -25,10 +24,6 @@ def pp(*args: Any, **kwargs: Any) -> None:
|
|||||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def print_paged(text: str) -> None:
|
|
||||||
pager(text)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Chat:
|
class Chat:
|
||||||
"""
|
"""
|
||||||
@ -67,31 +62,22 @@ class Chat:
|
|||||||
self.messages += msgs
|
self.messages += msgs
|
||||||
self.sort()
|
self.sort()
|
||||||
|
|
||||||
def print(self, dump: bool = False, source_code_only: bool = False,
|
def print(self, dump: bool = False) -> None:
|
||||||
with_tags: bool = False, with_file: bool = False,
|
|
||||||
paged: bool = True) -> None:
|
|
||||||
if dump:
|
if dump:
|
||||||
pp(self)
|
pp(self)
|
||||||
return
|
return
|
||||||
output: list[str] = []
|
# for message in self.messages:
|
||||||
for message in self.messages:
|
# text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
|
||||||
if source_code_only:
|
# if source_code:
|
||||||
output.extend(source_code(message.question, include_delims=True))
|
# display_source_code(message['content'])
|
||||||
continue
|
# continue
|
||||||
output.append('-' * terminal_width())
|
# if message['role'] == 'user':
|
||||||
output.append(Question.txt_header)
|
# print('-' * terminal_width())
|
||||||
output.append(message.question)
|
# if text_too_long:
|
||||||
if message.answer:
|
# print(f"{message['role'].upper()}:")
|
||||||
output.append(Answer.txt_header)
|
# print(message['content'])
|
||||||
output.append(message.answer)
|
# else:
|
||||||
if with_tags:
|
# print(f"{message['role'].upper()}: {message['content']}")
|
||||||
output.append(message.tags_str())
|
|
||||||
if with_file:
|
|
||||||
output.append('FILE: ' + str(message.file_path))
|
|
||||||
if paged:
|
|
||||||
print_paged('\n'.join(output))
|
|
||||||
else:
|
|
||||||
print(*output, sep='\n')
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -198,15 +184,12 @@ class ChatDB(Chat):
|
|||||||
"""
|
"""
|
||||||
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
|
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
|
||||||
By default, only messages that have not been read (or written) before will
|
By default, only messages that have not been read (or written) before will
|
||||||
be read. Use 'force_all' to force reading all messages (existing messages
|
be read. Use 'force_all' to force reading all messages.
|
||||||
are discarded).
|
|
||||||
"""
|
"""
|
||||||
if from_cache:
|
if from_cache:
|
||||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||||
else:
|
else:
|
||||||
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
|
||||||
if force_all:
|
|
||||||
self.messages = []
|
|
||||||
for file_path in sorted(file_iter):
|
for file_path in sorted(file_iter):
|
||||||
if file_path.is_file():
|
if file_path.is_file():
|
||||||
if file_path.name in self.message_files and not force_all:
|
if file_path.name in self.message_files and not force_all:
|
||||||
|
|||||||
@ -444,16 +444,6 @@ class Message():
|
|||||||
res_tags -= {tag for tag in res_tags if contain not in tag}
|
res_tags -= {tag for tag in res_tags if contain not in tag}
|
||||||
return res_tags or set()
|
return res_tags or set()
|
||||||
|
|
||||||
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
|
|
||||||
"""
|
|
||||||
Returns all tags as a string with the TagLine prefix. Optionally filtered
|
|
||||||
using 'Message.filter_tags()'.
|
|
||||||
"""
|
|
||||||
if self.tags:
|
|
||||||
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
|
|
||||||
else:
|
|
||||||
return str(TagLine.from_set(set()))
|
|
||||||
|
|
||||||
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
|
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
|
||||||
"""
|
"""
|
||||||
Matches the current Message to the given filter atttributes.
|
Matches the current Message to the given filter atttributes.
|
||||||
|
|||||||
@ -1,64 +0,0 @@
|
|||||||
import pathlib
|
|
||||||
from io import StringIO
|
|
||||||
from unittest.mock import patch
|
|
||||||
from chatmastermind.tags import TagLine
|
|
||||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
|
||||||
from chatmastermind.chat import Chat, terminal_width
|
|
||||||
from .test_main import CmmTestCase
|
|
||||||
|
|
||||||
|
|
||||||
class TestChat(CmmTestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.chat = Chat([])
|
|
||||||
self.message1 = Message(Question('Question 1'),
|
|
||||||
Answer('Answer 1'),
|
|
||||||
{Tag('tag1')},
|
|
||||||
file_path=pathlib.Path('0001.txt'))
|
|
||||||
self.message2 = Message(Question('Question 2'),
|
|
||||||
Answer('Answer 2'),
|
|
||||||
{Tag('tag2')},
|
|
||||||
file_path=pathlib.Path('0002.txt'))
|
|
||||||
|
|
||||||
def test_filter(self) -> None:
|
|
||||||
self.chat.add_msgs([self.message1, self.message2])
|
|
||||||
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
|
|
||||||
|
|
||||||
self.assertEqual(len(self.chat.messages), 1)
|
|
||||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
|
||||||
|
|
||||||
def test_sort(self) -> None:
|
|
||||||
self.chat.add_msgs([self.message2, self.message1])
|
|
||||||
self.chat.sort()
|
|
||||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
|
||||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
|
||||||
self.chat.sort(reverse=True)
|
|
||||||
self.assertEqual(self.chat.messages[0].question, 'Question 2')
|
|
||||||
self.assertEqual(self.chat.messages[1].question, 'Question 1')
|
|
||||||
|
|
||||||
def test_clear(self) -> None:
|
|
||||||
self.chat.add_msgs([self.message1])
|
|
||||||
self.chat.clear()
|
|
||||||
self.assertEqual(len(self.chat.messages), 0)
|
|
||||||
|
|
||||||
def test_add_msgs(self) -> None:
|
|
||||||
self.chat.add_msgs([self.message1, self.message2])
|
|
||||||
|
|
||||||
self.assertEqual(len(self.chat.messages), 2)
|
|
||||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
|
||||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
|
||||||
|
|
||||||
@patch('sys.stdout', new_callable=StringIO)
|
|
||||||
def test_print(self, mock_stdout: StringIO) -> None:
|
|
||||||
self.chat.add_msgs([self.message1, self.message2])
|
|
||||||
self.chat.print(paged=False)
|
|
||||||
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
|
|
||||||
{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\nAnswer 2\n"
|
|
||||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
|
||||||
|
|
||||||
@patch('sys.stdout', new_callable=StringIO)
|
|
||||||
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
|
|
||||||
self.chat.print(paged=False, with_tags=True, with_file=True)
|
|
||||||
expected_output = f"{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 1\n{Answer.txt_header}\nAnswer 1\n\
|
|
||||||
\n{TagLine.prefix} tag1\nFILE: 0001.txt\n{'-'*terminal_width()}\n{Question.txt_header}\nQuestion 2\n{Answer.txt_header}\n\
|
|
||||||
Answer 2\n{TagLine.prefix} tag2\nFILE: 0002.txt\n"
|
|
||||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
|
||||||
@ -729,13 +729,3 @@ class MessageHashTestCase(CmmTestCase):
|
|||||||
self.assertEqual(len(msgs), 3)
|
self.assertEqual(len(msgs), 3)
|
||||||
for msg in [self.message1, self.message2, self.message3]:
|
for msg in [self.message1, self.message2, self.message3]:
|
||||||
self.assertIn(msg, msgs)
|
self.assertIn(msg, msgs)
|
||||||
|
|
||||||
|
|
||||||
class MessageTagsStrTestCase(CmmTestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.message = Message(Question('This is a question.'),
|
|
||||||
tags={Tag('tag1')},
|
|
||||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
|
||||||
|
|
||||||
def test_tags_str(self) -> None:
|
|
||||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user