Compare commits
8 Commits
b88a6b9a92
...
f5b185505e
| Author | SHA1 | Date | |
|---|---|---|---|
| f5b185505e | |||
| 07b8f955da | |||
| 4e2af55b7c | |||
| de4c6e3b4a | |||
| c16afc6c11 | |||
| 68a7315044 | |||
| 6e7f39de2a | |||
| 89d19ee9d6 |
@ -60,13 +60,17 @@ class Chat:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatDir(Chat):
|
class ChatDir(Chat):
|
||||||
"""
|
"""
|
||||||
A Chat class that is bound to a given directory. Supports reading
|
A 'Chat' class that is bound to a given directory structure. Supports reading
|
||||||
and writing messages from / to that directory.
|
and writing messages from / to that structure. Such a structure consists of
|
||||||
|
two directories: a 'cache directory', where all messages are temporarily
|
||||||
|
stored, and a 'DB' directory, where selected messages can be stored
|
||||||
|
persistently.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default_file_suffix: ClassVar[str] = '.txt'
|
default_file_suffix: ClassVar[str] = '.txt'
|
||||||
|
|
||||||
directory: pathlib.Path
|
cache_path: pathlib.Path
|
||||||
|
db_path: pathlib.Path
|
||||||
# a MessageFilter that all messages must match (if given)
|
# a MessageFilter that all messages must match (if given)
|
||||||
mfilter: Optional[MessageFilter] = None
|
mfilter: Optional[MessageFilter] = None
|
||||||
file_suffix: str = default_file_suffix
|
file_suffix: str = default_file_suffix
|
||||||
@ -75,17 +79,24 @@ class ChatDir(Chat):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dir(cls: Type[ChatDirInst],
|
def from_dir(cls: Type[ChatDirInst],
|
||||||
path: pathlib.Path,
|
cache_path: pathlib.Path,
|
||||||
|
db_path: pathlib.Path,
|
||||||
glob: Optional[str] = None,
|
glob: Optional[str] = None,
|
||||||
mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
|
mfilter: Optional[MessageFilter] = None) -> ChatDirInst:
|
||||||
"""
|
"""
|
||||||
Create a ChatDir instance from the given directory. If 'glob' is specified,
|
Create a 'ChatDir' instance from the given directory structure.
|
||||||
files will be filtered using 'path.glob()', otherwise it uses 'path.iterdir()'.
|
Reads all messages from 'db_path' into the local message list.
|
||||||
Messages are created using 'Message.from_file()' and the optional MessageFilter.
|
Parameters:
|
||||||
|
* 'cache_path': path to the directory for temporary messages
|
||||||
|
* 'db_path': path to the directory for persistent messages
|
||||||
|
* 'glob' fs specified, files will be filtered using 'path.glob()',
|
||||||
|
otherwise it uses 'path.iterdir()'.
|
||||||
|
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||||
|
when reading them.
|
||||||
"""
|
"""
|
||||||
messages: list[Message] = []
|
messages: list[Message] = []
|
||||||
message_files: set[str] = set()
|
message_files: set[str] = set()
|
||||||
file_iter = path.glob(glob) if glob else path.iterdir()
|
file_iter = db_path.glob(glob) if glob else db_path.iterdir()
|
||||||
for file_path in sorted(file_iter):
|
for file_path in sorted(file_iter):
|
||||||
if file_path.is_file():
|
if file_path.is_file():
|
||||||
try:
|
try:
|
||||||
@ -95,11 +106,12 @@ class ChatDir(Chat):
|
|||||||
message_files.add(file_path.name)
|
message_files.add(file_path.name)
|
||||||
except MessageError as e:
|
except MessageError as e:
|
||||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||||
return cls(messages, path, mfilter, cls.default_file_suffix, message_files)
|
return cls(messages, cache_path, db_path, mfilter, cls.default_file_suffix, message_files)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(cls: Type[ChatDirInst],
|
def from_messages(cls: Type[ChatDirInst],
|
||||||
path: pathlib.Path,
|
cache_path: pathlib.Path,
|
||||||
|
db_path: pathlib.Path,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
mfilter: Optional[MessageFilter]) -> ChatDirInst:
|
mfilter: Optional[MessageFilter]) -> ChatDirInst:
|
||||||
"""
|
"""
|
||||||
@ -108,10 +120,10 @@ class ChatDir(Chat):
|
|||||||
in order to synchronize the messages. 'update()' is not
|
in order to synchronize the messages. 'update()' is not
|
||||||
supported until after the first 'dump()'.
|
supported until after the first 'dump()'.
|
||||||
"""
|
"""
|
||||||
return cls(messages, path, mfilter)
|
return cls(messages, cache_path, db_path, mfilter)
|
||||||
|
|
||||||
def get_next_fid(self) -> int:
|
def get_next_fid(self) -> int:
|
||||||
next_fname = self.directory / '.next'
|
next_fname = self.db_path / '.next'
|
||||||
try:
|
try:
|
||||||
with open(next_fname, 'r') as f:
|
with open(next_fname, 'r') as f:
|
||||||
return int(f.read()) + 1
|
return int(f.read()) + 1
|
||||||
@ -119,18 +131,16 @@ class ChatDir(Chat):
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
def set_next_fid(self, fid: int) -> None:
|
def set_next_fid(self, fid: int) -> None:
|
||||||
next_fname = self.directory / '.next'
|
next_fname = self.db_path / '.next'
|
||||||
with open(next_fname, 'w') as f:
|
with open(next_fname, 'w') as f:
|
||||||
f.write(f'{fid}')
|
f.write(f'{fid}')
|
||||||
|
|
||||||
def dump(self, force_all: bool = False) -> None:
|
def dump(self, to_db: bool = False, force_all: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Writes all messages to the bound directory. If a message has no file_path,
|
Writes all messages to the 'cache_path' or 'db_path'. If a message has no file_path,
|
||||||
it will create a new one. By default, only messages that have not been
|
it will create a new one. By default, only messages that have not been written
|
||||||
written (or read) before will be dumped. Use 'force_all' to force writing
|
(or read) before will be dumped. Use 'force_all' to force writing all message files.
|
||||||
all message files.
|
|
||||||
"""
|
"""
|
||||||
# FIXME: write to 'db' subfolder or given folder
|
|
||||||
for message in self.messages:
|
for message in self.messages:
|
||||||
# skip messages that we have already written (or read)
|
# skip messages that we have already written (or read)
|
||||||
if message.file_path and message.file_path in self.message_files and not force_all:
|
if message.file_path and message.file_path in self.message_files and not force_all:
|
||||||
@ -138,6 +148,7 @@ class ChatDir(Chat):
|
|||||||
file_path = message.file_path
|
file_path = message.file_path
|
||||||
if not file_path:
|
if not file_path:
|
||||||
fid = self.get_next_fid()
|
fid = self.get_next_fid()
|
||||||
file_path = self.directory / f"{fid:04d}{self.file_suffix}"
|
fname = f"{fid:04d}{self.file_suffix}"
|
||||||
|
file_path = self.db_path / fname if to_db else self.cache_path / fname
|
||||||
self.set_next_fid(fid)
|
self.set_next_fid(fid)
|
||||||
message.to_file(file_path)
|
message.to_file(file_path)
|
||||||
|
|||||||
@ -187,7 +187,7 @@ class Message():
|
|||||||
and a file path.
|
and a file path.
|
||||||
"""
|
"""
|
||||||
question: Question
|
question: Question
|
||||||
answer: Optional[Answer] = None # FIXME: support multiple answers
|
answer: Optional[Answer] = None
|
||||||
tags: Optional[set[Tag]] = None
|
tags: Optional[set[Tag]] = None
|
||||||
ai: Optional[str] = None
|
ai: Optional[str] = None
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
@ -409,5 +409,15 @@ class Message():
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def file_id(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns an ID that is unique within the directory of this message.
|
||||||
|
Currently this is simply the file name.
|
||||||
|
"""
|
||||||
|
if self.file_path:
|
||||||
|
return self.file_path.name
|
||||||
|
else:
|
||||||
|
raise MessageError("Can't create file ID without a file path")
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, Any]:
|
def as_dict(self) -> dict[str, Any]:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|||||||
@ -575,3 +575,23 @@ This is an answer.
|
|||||||
def test_tags_from_file_yaml(self) -> None:
|
def test_tags_from_file_yaml(self) -> None:
|
||||||
tags = Message.tags_from_file(self.file_path_yaml)
|
tags = Message.tags_from_file(self.file_path_yaml)
|
||||||
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')})
|
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')})
|
||||||
|
|
||||||
|
|
||||||
|
class MessageFileIDTxtTestCase(CmmTestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||||
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
|
self.message = Message(Question('This is a question.'),
|
||||||
|
file_path=self.file_path)
|
||||||
|
self.message_no_file_path = Message(Question('This is a question.'))
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.file.close()
|
||||||
|
self.file_path.unlink()
|
||||||
|
|
||||||
|
def test_file_id_txt(self) -> None:
|
||||||
|
self.assertEqual(self.message.file_id(), self.file_path.name)
|
||||||
|
|
||||||
|
def test_file_id_txt_exception(self) -> None:
|
||||||
|
with self.assertRaises(MessageError):
|
||||||
|
self.message_no_file_path.file_id()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user