Compare commits

..

6 Commits

2 changed files with 21 additions and 3 deletions

View File

@ -320,7 +320,7 @@ class ChatDB(Chat):
def add_to_db(self, messages: list[Message], write: bool = True) -> None: def add_to_db(self, messages: list[Message], write: bool = True) -> None:
""" """
Adds the given new messages and sets the file_path to the DB directory. Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages): if any(m.file_path is not None for m in messages):
@ -338,7 +338,7 @@ class ChatDB(Chat):
def add_to_cache(self, messages: list[Message], write: bool = True) -> None: def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
""" """
Adds the given new messages and sets the file_path to the cache directory. Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages): if any(m.file_path is not None for m in messages):
@ -353,3 +353,16 @@ class ChatDB(Chat):
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages self.messages += messages
self.sort() self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()

View File

@ -5,7 +5,7 @@ from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from .test_main import CmmTestCase from .test_main import CmmTestCase
@ -384,3 +384,8 @@ class TestChatDB(CmmTestCase):
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
# TODO: add testcase for "ChatDB.write_messages()"