From 6f71a2ff691105b25593ae00d5053443a1ab768b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:56:50 +0200 Subject: [PATCH] message: to_file() now uses intermediate temporary file --- chatmastermind/message.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index df59ed6..64929a3 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,6 +3,8 @@ Module implementing message related functions and classes. """ import pathlib import yaml +import tempfile +import shutil from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -445,16 +447,18 @@ class Message(): * Answer.txt_header * Answer """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) if self.tags: - fd.write(f'{TagLine.from_set(self.tags)}\n') + temp_fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: - fd.write(f'{AILine.from_ai(self.ai)}\n') + temp_fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: - fd.write(f'{ModelLine.from_model(self.model)}\n') - fd.write(f'{Question.txt_header}\n{self.question}\n') + temp_fd.write(f'{ModelLine.from_model(self.model)}\n') + temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ @@ -466,7 +470,8 @@ class Message(): * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) @@ -476,7 +481,8 @@ class Message(): data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) - yaml.dump(data, fd, sort_keys=False) + yaml.dump(data, temp_fd, sort_keys=False) + shutil.move(temp_file_path, file_path) def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """