Compare commits

..

8 Commits

3 changed files with 17 additions and 71 deletions

View File

@ -2,7 +2,7 @@
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
from pathlib import Path
import pathlib
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
@ -30,7 +30,7 @@ def print_paged(text: str) -> None:
pager(text)
def read_dir(dir_path: Path,
def read_dir(dir_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
@ -55,9 +55,9 @@ def read_dir(dir_path: Path,
return messages
def make_file_path(dir_path: Path,
def make_file_path(dir_path: pathlib.Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
next_fid: Callable[[], int]) -> pathlib.Path:
"""
Create a file_path for the given directory using the
given file_suffix and ID generator function.
@ -65,7 +65,7 @@ def make_file_path(dir_path: Path,
return dir_path / f"{next_fid():04d}{file_suffix}"
def write_dir(dir_path: Path,
def write_dir(dir_path: pathlib.Path,
messages: list[Message],
file_suffix: str,
next_fid: Callable[[], int]) -> None:
@ -90,7 +90,7 @@ def write_dir(dir_path: Path,
message.to_file(file_path)
def clear_dir(dir_path: Path,
def clear_dir(dir_path: pathlib.Path,
glob: Optional[str] = None) -> None:
"""
Deletes all Message files in the given directory.
@ -139,34 +139,6 @@ class Chat:
self.messages += messages
self.sort()
def latest_message(self) -> Optional[Message]:
"""
Returns the last added message (according to the file ID).
"""
if len(self.messages) > 0:
self.sort()
return self.messages[-1]
else:
return None
def find_messages(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
caller should check the result if he requires all messages).
"""
return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths.
"""
self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
@ -220,8 +192,8 @@ class ChatDB(Chat):
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path
db_path: Path
cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
@ -237,8 +209,8 @@ class ChatDB(Chat):
@classmethod
def from_dir(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
cache_path: pathlib.Path,
db_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
@ -258,8 +230,8 @@ class ChatDB(Chat):
@classmethod
def from_messages(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
cache_path: pathlib.Path,
db_path: pathlib.Path,
messages: list[Message],
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""

View File

@ -13,7 +13,6 @@ from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter, MessageError, Question
from .ai_factory import create_ai
from .ai import AI, AIResponse
from itertools import zip_longest
from typing import Any
@ -104,20 +103,17 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
ai = create_ai(args, config)
if args.ask:
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.otags) # FIXME
assert response
ai.request(message,
chat,
args.num_answers, # FIXME
args.otags) # FIXME
# TODO:
# * add answer to the message above (and create
# more messages for any additional answers)
pass
elif args.repeat:
lmessage = chat.latest_message()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)

View File

@ -62,28 +62,6 @@ class TestChat(CmmTestCase):
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3])
# find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.remove_messages(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2])