From 6ad404899e9c599b8b330652232be815d5b41317 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 14 Sep 2023 16:05:18 +0200 Subject: [PATCH] chat: ChatDB now correctly ignores files that contain no valid messages --- chatmastermind/chat.py | 2 +- chatmastermind/message.py | 9 ++++++--- tests/test_chat.py | 7 ++++++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 8d64f86..c1464c3 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -57,7 +57,7 @@ def read_dir(dir_path: Path, if message: messages.append(message) except MessageError as e: - print(f"Error processing message in '{file_path}': {str(e)}") + print(f"WARNING: Skipping message in '{file_path}': {str(e)}") return messages diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 64929a3..8b32ae9 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -370,7 +370,7 @@ class Message(): try: question_idx = text.index(Question.txt_header) + 1 except ValueError: - raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") + raise MessageError(f"'{file_path}' does not contain a valid message") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) @@ -390,8 +390,11 @@ class Message(): * Message.model_yaml_key: str [Optional] """ with open(file_path, "r") as fd: - data = yaml.load(fd, Loader=yaml.FullLoader) - data[cls.file_yaml_key] = file_path + try: + data = yaml.load(fd, Loader=yaml.FullLoader) + data[cls.file_yaml_key] = file_path + except Exception: + raise MessageError(f"'{file_path}' does not contain a valid message") return cls.from_dict(data) def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: diff --git a/tests/test_chat.py b/tests/test_chat.py index ff44cda..b052d19 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -156,13 +156,18 @@ class TestChatDB(unittest.TestCase): next_fname = pathlib.Path(self.db_path.name) / '.next' with open(next_fname, 'w') as f: f.write('4') + # add some "trash" in order to test if it's correctly handled / ignored + self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt'] + for file in self.trash_files: + with open(pathlib.Path(self.db_path.name) / file, 'w') as f: + f.write('test trash') def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: """ List all Message files in the given TemporaryDirectory. """ # exclude '.next' - return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) + return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files] def tearDown(self) -> None: self.db_path.cleanup()