Compare commits
4 Commits
589b92c9b6
...
3245690d4d
| Author | SHA1 | Date | |
|---|---|---|---|
| 3245690d4d | |||
| 37341ccebe | |||
| 8031271c18 | |||
| 5e392e782e |
@ -8,7 +8,7 @@ from pydoc import pager
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
|
||||
from .configuration import default_config_file
|
||||
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in
|
||||
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in, message_valid_formats
|
||||
from .tags import Tag
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
@ -285,6 +285,8 @@ class ChatDB(Chat):
|
||||
mfilter: Optional[MessageFilter] = None
|
||||
# the glob pattern for all messages
|
||||
glob: Optional[str] = None
|
||||
# message format (for writing)
|
||||
mformat: MessageFormat = Message.default_format
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# contains the latest message ID
|
||||
@ -339,9 +341,17 @@ class ChatDB(Chat):
|
||||
with open(self.next_path, 'w') as f:
|
||||
f.write(f'{fid}')
|
||||
|
||||
def set_msg_format(self, mformat: MessageFormat) -> None:
|
||||
"""
|
||||
Set message format for writing messages.
|
||||
"""
|
||||
if mformat not in message_valid_formats:
|
||||
raise ChatError(f"Message format '{mformat}' is not supported")
|
||||
self.mformat = mformat
|
||||
|
||||
def msg_write(self,
|
||||
messages: Optional[list[Message]] = None,
|
||||
mformat: MessageFormat = Message.default_format) -> None:
|
||||
mformat: Optional[MessageFormat] = None) -> None:
|
||||
"""
|
||||
Write either the given messages or the internal ones to their CURRENT file_path.
|
||||
If messages are given, they all must have a valid file_path. When writing the
|
||||
@ -352,7 +362,7 @@ class ChatDB(Chat):
|
||||
raise ChatError("Can't write files without a valid file_path")
|
||||
msgs = iter(messages if messages else self.messages)
|
||||
while (m := next(msgs, None)):
|
||||
m.to_file(mformat=mformat)
|
||||
m.to_file(mformat=mformat if mformat else self.mformat)
|
||||
|
||||
def msg_update(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
@ -373,6 +383,7 @@ class ChatDB(Chat):
|
||||
def msg_gather(self,
|
||||
loc: msg_location,
|
||||
require_file_path: bool = False,
|
||||
glob: Optional[str] = None,
|
||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||
"""
|
||||
Gather and return messages from the given locations:
|
||||
@ -391,9 +402,9 @@ class ChatDB(Chat):
|
||||
else:
|
||||
loc_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
|
||||
if loc in ['cache', 'disk', 'all']:
|
||||
loc_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
loc_messages += read_dir(self.cache_path, glob=glob, mfilter=mfilter)
|
||||
if loc in ['db', 'disk', 'all']:
|
||||
loc_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
loc_messages += read_dir(self.db_path, glob=glob, mfilter=mfilter)
|
||||
# remove_duplicates and sort the list
|
||||
unique_messages: list[Message] = []
|
||||
for m in loc_messages:
|
||||
@ -514,7 +525,8 @@ class ChatDB(Chat):
|
||||
"""
|
||||
write_dir(self.cache_path,
|
||||
messages if messages else self.messages,
|
||||
self.get_next_fid)
|
||||
self.get_next_fid,
|
||||
self.mformat)
|
||||
|
||||
def cache_add(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
@ -526,7 +538,8 @@ class ChatDB(Chat):
|
||||
if write:
|
||||
write_dir(self.cache_path,
|
||||
messages,
|
||||
self.get_next_fid)
|
||||
self.get_next_fid,
|
||||
self.mformat)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
|
||||
@ -579,7 +592,8 @@ class ChatDB(Chat):
|
||||
"""
|
||||
write_dir(self.db_path,
|
||||
messages if messages else self.messages,
|
||||
self.get_next_fid)
|
||||
self.get_next_fid,
|
||||
self.mformat)
|
||||
|
||||
def db_add(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
@ -591,7 +605,8 @@ class ChatDB(Chat):
|
||||
if write:
|
||||
write_dir(self.db_path,
|
||||
messages,
|
||||
self.get_next_fid)
|
||||
self.get_next_fid,
|
||||
self.mformat)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.db_path, self.get_next_fid)
|
||||
|
||||
@ -51,6 +51,29 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
|
||||
question_parts.append(f"```\n{content}\n```")
|
||||
|
||||
|
||||
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
|
||||
"""
|
||||
Takes an existing message and CLI arguments, and returns modified args based
|
||||
on the members of the given message. Used e.g. when repeating messages, where
|
||||
it's necessary to determine the correct AI, module and output tags to use
|
||||
(either from the existing message or the given args).
|
||||
"""
|
||||
msg_args = args
|
||||
# if AI, model or output tags have not been specified,
|
||||
# use those from the original message
|
||||
if (args.AI is None
|
||||
or args.model is None # noqa: W503
|
||||
or args.output_tags is None): # noqa: W503
|
||||
msg_args = deepcopy(args)
|
||||
if args.AI is None and msg.ai is not None:
|
||||
msg_args.AI = msg.ai
|
||||
if args.model is None and msg.model is not None:
|
||||
msg_args.model = msg.model
|
||||
if args.output_tags is None and msg.tags is not None:
|
||||
msg_args.output_tags = msg.tags
|
||||
return msg_args
|
||||
|
||||
|
||||
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
||||
"""
|
||||
Create a new message from the given arguments and write it
|
||||
@ -106,29 +129,6 @@ def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespac
|
||||
print(response.tokens)
|
||||
|
||||
|
||||
def create_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
|
||||
"""
|
||||
Takes an existing message and CLI arguments, and returns modified args based
|
||||
on the members of the given message. Used e.g. when repeating messages, where
|
||||
it's necessary to determine the correct AI, module and output tags to use
|
||||
(either from the existing message or the given args).
|
||||
"""
|
||||
msg_args = args
|
||||
# if AI, model or output tags have not been specified,
|
||||
# use those from the original message
|
||||
if (args.AI is None
|
||||
or args.model is None # noqa: W503
|
||||
or args.output_tags is None): # noqa: W503
|
||||
msg_args = deepcopy(args)
|
||||
if args.AI is None and msg.ai is not None:
|
||||
msg_args.AI = msg.ai
|
||||
if args.model is None and msg.model is not None:
|
||||
msg_args.model = msg.model
|
||||
if args.output_tags is None and msg.tags is not None:
|
||||
msg_args.output_tags = msg.tags
|
||||
return msg_args
|
||||
|
||||
|
||||
def repeat_messages(messages: list[Message], chat: ChatDB, args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Repeat the given messages using the given arguments.
|
||||
|
||||
@ -517,6 +517,13 @@ class Message():
|
||||
yaml.dump(data, temp_fd, sort_keys=False)
|
||||
shutil.move(temp_file_path, file_path)
|
||||
|
||||
def rm_file(self) -> None:
|
||||
"""
|
||||
Delete the message file. Ignore empty file_path and not existing files.
|
||||
"""
|
||||
if self.file_path is not None:
|
||||
self.file_path.unlink(missing_ok=True)
|
||||
|
||||
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Filter tags based on their prefix (i. e. the tag starts with a given string)
|
||||
|
||||
@ -874,3 +874,22 @@ This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer."""
|
||||
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
|
||||
|
||||
|
||||
class MessageRmFileTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
self.message = Message(Question('This is a question.'),
|
||||
file_path=self.file_path)
|
||||
self.message.to_file()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file.close()
|
||||
self.file_path.unlink(missing_ok=True)
|
||||
|
||||
def test_rm_file(self) -> None:
|
||||
assert self.message.file_path
|
||||
self.assertTrue(self.message.file_path.exists())
|
||||
self.message.rm_file()
|
||||
self.assertFalse(self.message.file_path.exists())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user