Compare commits
1 Commits
3c1c9860a0
...
fae835be1f
| Author | SHA1 | Date | |
|---|---|---|---|
| fae835be1f |
@ -6,9 +6,9 @@ from pathlib import Path
|
||||
from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
|
||||
from .configuration import default_config_file
|
||||
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in
|
||||
from .message import Message, MessageFilter, MessageError, message_in
|
||||
from .tags import Tag
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
@ -17,7 +17,6 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
db_next_file = '.next'
|
||||
ignored_files = [db_next_file, default_config_file]
|
||||
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
|
||||
msg_suffix = Message.file_suffix_write
|
||||
|
||||
|
||||
class ChatError(Exception):
|
||||
@ -53,7 +52,7 @@ def read_dir(dir_path: Path,
|
||||
for file_path in sorted(file_iter):
|
||||
if (file_path.is_file()
|
||||
and file_path.name not in ignored_files # noqa: W503
|
||||
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
|
||||
and file_path.suffix in Message.file_suffixes): # noqa: W503
|
||||
try:
|
||||
message = Message.from_file(file_path, mfilter)
|
||||
if message:
|
||||
@ -64,20 +63,22 @@ def read_dir(dir_path: Path,
|
||||
|
||||
|
||||
def make_file_path(dir_path: Path,
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> Path:
|
||||
"""
|
||||
Create a file_path for the given directory using the given ID generator function.
|
||||
Create a file_path for the given directory using the
|
||||
given file_suffix and ID generator function.
|
||||
"""
|
||||
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
|
||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
||||
while file_path.exists():
|
||||
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
|
||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
||||
return file_path
|
||||
|
||||
|
||||
def write_dir(dir_path: Path,
|
||||
messages: list[Message],
|
||||
next_fid: Callable[[], int],
|
||||
mformat: MessageFormat = Message.default_format) -> None:
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> None:
|
||||
"""
|
||||
Write all messages to the given directory. If a message has no file_path,
|
||||
a new one will be created. If message.file_path exists, it will be modified
|
||||
@ -85,17 +86,18 @@ def write_dir(dir_path: Path,
|
||||
Parameters:
|
||||
* 'dir_path': destination directory
|
||||
* 'messages': list of messages to write
|
||||
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
|
||||
* 'next_fid': callable that returns the next file ID
|
||||
"""
|
||||
for message in messages:
|
||||
file_path = message.file_path
|
||||
# message has no file_path: create one
|
||||
if not file_path:
|
||||
file_path = make_file_path(dir_path, next_fid)
|
||||
file_path = make_file_path(dir_path, file_suffix, next_fid)
|
||||
# file_path does not point to given directory: modify it
|
||||
elif not file_path.parent.samefile(dir_path):
|
||||
file_path = dir_path / file_path.name
|
||||
message.to_file(file_path, mformat=mformat)
|
||||
message.to_file(file_path)
|
||||
|
||||
|
||||
def clear_dir(dir_path: Path,
|
||||
@ -107,7 +109,7 @@ def clear_dir(dir_path: Path,
|
||||
for file_path in file_iter:
|
||||
if (file_path.is_file()
|
||||
and file_path.name not in ignored_files # noqa: W503
|
||||
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
|
||||
and file_path.suffix in Message.file_suffixes): # noqa: W503
|
||||
file_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@ -144,7 +146,7 @@ class Chat:
|
||||
Matching is True if:
|
||||
* 'name' matches the full 'file_path'
|
||||
* 'name' matches 'file_path.name' (i. e. including the suffix)
|
||||
* 'name' matches 'file_path.stem' (i. e. without the suffix)
|
||||
* 'name' matches 'file_path.stem' (i. e. without a suffix)
|
||||
"""
|
||||
return Path(name) == file_path or name == file_path.name or name == file_path.stem
|
||||
|
||||
@ -279,10 +281,13 @@ class ChatDB(Chat):
|
||||
persistently.
|
||||
"""
|
||||
|
||||
default_file_suffix: ClassVar[str] = '.txt'
|
||||
|
||||
cache_path: Path
|
||||
db_path: Path
|
||||
# a MessageFilter that all messages must match (if given)
|
||||
mfilter: Optional[MessageFilter] = None
|
||||
file_suffix: str = default_file_suffix
|
||||
# the glob pattern for all messages
|
||||
glob: Optional[str] = None
|
||||
|
||||
@ -312,7 +317,8 @@ class ChatDB(Chat):
|
||||
when reading them.
|
||||
"""
|
||||
messages = read_dir(db_path, glob, mfilter)
|
||||
return cls(messages, cache_path, db_path, mfilter, glob)
|
||||
return cls(messages, cache_path, db_path, mfilter,
|
||||
cls.default_file_suffix, glob)
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls: Type[ChatDBInst],
|
||||
@ -339,9 +345,7 @@ class ChatDB(Chat):
|
||||
with open(self.next_path, 'w') as f:
|
||||
f.write(f'{fid}')
|
||||
|
||||
def msg_write(self,
|
||||
messages: Optional[list[Message]] = None,
|
||||
mformat: MessageFormat = Message.default_format) -> None:
|
||||
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write either the given messages or the internal ones to their CURRENT file_path.
|
||||
If messages are given, they all must have a valid file_path. When writing the
|
||||
@ -352,7 +356,7 @@ class ChatDB(Chat):
|
||||
raise ChatError("Can't write files without a valid file_path")
|
||||
msgs = iter(messages if messages else self.messages)
|
||||
while (m := next(msgs, None)):
|
||||
m.to_file(mformat=mformat)
|
||||
m.to_file()
|
||||
|
||||
def msg_update(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
@ -514,6 +518,7 @@ class ChatDB(Chat):
|
||||
"""
|
||||
write_dir(self.cache_path,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
def cache_add(self, messages: list[Message], write: bool = True) -> None:
|
||||
@ -526,10 +531,11 @@ class ChatDB(Chat):
|
||||
if write:
|
||||
write_dir(self.cache_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
|
||||
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.msg_sort()
|
||||
|
||||
@ -579,6 +585,7 @@ class ChatDB(Chat):
|
||||
"""
|
||||
write_dir(self.db_path,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
def db_add(self, messages: list[Message], write: bool = True) -> None:
|
||||
@ -591,10 +598,11 @@ class ChatDB(Chat):
|
||||
if write:
|
||||
write_dir(self.db_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.db_path, self.get_next_fid)
|
||||
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.msg_sort()
|
||||
|
||||
|
||||
@ -5,8 +5,7 @@ import pathlib
|
||||
import yaml
|
||||
import tempfile
|
||||
import shutil
|
||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
|
||||
from typing import get_args as typing_get_args
|
||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
||||
|
||||
@ -16,9 +15,6 @@ MessageInst = TypeVar('MessageInst', bound='Message')
|
||||
AILineInst = TypeVar('AILineInst', bound='AILine')
|
||||
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
|
||||
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
|
||||
MessageFormat = Literal['txt', 'yaml']
|
||||
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
|
||||
message_default_format: Final[MessageFormat] = 'txt'
|
||||
|
||||
|
||||
class MessageError(Exception):
|
||||
@ -222,7 +218,8 @@ class Message():
|
||||
# class variables
|
||||
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
|
||||
file_suffix_write: ClassVar[str] = '.msg'
|
||||
default_format: ClassVar[MessageFormat] = message_default_format
|
||||
file_formats: ClassVar[list[str]] = ['txt', 'yaml']
|
||||
default_format: ClassVar[str] = 'txt'
|
||||
tags_yaml_key: ClassVar[str] = 'tags'
|
||||
file_yaml_key: ClassVar[str] = 'file_path'
|
||||
ai_yaml_key: ClassVar[str] = 'ai'
|
||||
@ -372,6 +369,10 @@ class Message():
|
||||
tags = TagLine(fd.readline()).tags()
|
||||
except TagError:
|
||||
fd.seek(pos)
|
||||
if tags_or or tags_and or tags_not:
|
||||
# match with an empty set if the file has no tags
|
||||
if not match_tags(tags, tags_or, tags_and, tags_not):
|
||||
return None
|
||||
# AILine (Optional)
|
||||
try:
|
||||
pos = fd.tell()
|
||||
@ -386,23 +387,17 @@ class Message():
|
||||
fd.seek(pos)
|
||||
# Question and Answer
|
||||
text = fd.read().strip().split('\n')
|
||||
try:
|
||||
question_idx = text.index(Question.txt_header) + 1
|
||||
except ValueError:
|
||||
raise MessageError(f"'{file_path}' does not contain a valid message")
|
||||
try:
|
||||
answer_idx = text.index(Answer.txt_header)
|
||||
question = Question.from_list(text[question_idx:answer_idx])
|
||||
answer = Answer.from_list(text[answer_idx + 1:])
|
||||
except ValueError:
|
||||
question = Question.from_list(text[question_idx:])
|
||||
# match tags AFTER reading the whole file
|
||||
# -> make sure it's a valid 'txt' file format
|
||||
if tags_or or tags_and or tags_not:
|
||||
# match with an empty set if the file has no tags
|
||||
if not match_tags(tags, tags_or, tags_and, tags_not):
|
||||
return None
|
||||
return cls(question, answer, tags, ai, model, file_path)
|
||||
try:
|
||||
question_idx = text.index(Question.txt_header) + 1
|
||||
except ValueError:
|
||||
raise MessageError(f"'{file_path}' does not contain a valid message")
|
||||
try:
|
||||
answer_idx = text.index(Answer.txt_header)
|
||||
question = Question.from_list(text[question_idx:answer_idx])
|
||||
answer = Answer.from_list(text[answer_idx + 1:])
|
||||
except ValueError:
|
||||
question = Question.from_list(text[question_idx:])
|
||||
return cls(question, answer, tags, ai, model, file_path)
|
||||
|
||||
@classmethod
|
||||
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
|
||||
@ -443,7 +438,7 @@ class Message():
|
||||
output.append(self.answer)
|
||||
return '\n'.join(output)
|
||||
|
||||
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
|
||||
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: str = Message.default_format) -> None: # noqa: 11
|
||||
"""
|
||||
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
|
||||
Suffix is always '.msg'.
|
||||
@ -452,14 +447,10 @@ class Message():
|
||||
self.file_path = file_path
|
||||
if not self.file_path:
|
||||
raise MessageError("Got no valid path to write message")
|
||||
if mformat not in message_valid_formats:
|
||||
if mformat not in self.file_formats:
|
||||
raise MessageError(f"File format '{mformat}' is not supported")
|
||||
# check for valid suffix
|
||||
# -> add one if it's empty
|
||||
# -> refuse old or otherwise unsupported suffixes
|
||||
if not self.file_path.suffix:
|
||||
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
|
||||
elif self.file_path.suffix != self.file_suffix_write:
|
||||
# do not write old suffixes
|
||||
if self.file_path.suffix != self.file_suffix_write:
|
||||
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
|
||||
# TXT
|
||||
if mformat == 'txt':
|
||||
|
||||
@ -10,20 +10,6 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||
from chatmastermind.chat import Chat, ChatDB, ChatError
|
||||
|
||||
|
||||
msg_suffix: str = Message.file_suffix_write
|
||||
|
||||
|
||||
def msg_to_file_force_suffix(msg: Message) -> None:
|
||||
"""
|
||||
Force writing a message file with illegal suffixes.
|
||||
"""
|
||||
def_suffix = Message.file_suffix_write
|
||||
assert msg.file_path
|
||||
Message.file_suffix_write = msg.file_path.suffix
|
||||
msg.to_file()
|
||||
Message.file_suffix_write = def_suffix
|
||||
|
||||
|
||||
class TestChatBase(unittest.TestCase):
|
||||
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
||||
"""
|
||||
@ -41,11 +27,11 @@ class TestChat(TestChatBase):
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path(f'0001{msg_suffix}'))
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path(f'0002{msg_suffix}'))
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
self.maxDiff = None
|
||||
|
||||
def test_unique_id(self) -> None:
|
||||
@ -113,24 +99,24 @@ class TestChat(TestChatBase):
|
||||
|
||||
def test_find_remove_messages(self) -> None:
|
||||
self.chat.msg_add([self.message1, self.message2])
|
||||
msgs = self.chat.msg_find(['0001'])
|
||||
msgs = self.chat.msg_find(['0001.txt'])
|
||||
self.assertListEqual(msgs, [self.message1])
|
||||
msgs = self.chat.msg_find(['0001', '0002'])
|
||||
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
|
||||
self.assertListEqual(msgs, [self.message1, self.message2])
|
||||
# add new Message with full path
|
||||
message3 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
|
||||
file_path=pathlib.Path('/foo/bla/0003.txt'))
|
||||
self.chat.msg_add([message3])
|
||||
# find new Message by full path
|
||||
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
|
||||
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
|
||||
self.assertListEqual(msgs, [message3])
|
||||
# find Message with full path only by filename
|
||||
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
|
||||
msgs = self.chat.msg_find(['0003.txt'])
|
||||
self.assertListEqual(msgs, [message3])
|
||||
# remove last message
|
||||
self.chat.msg_remove(['0003'])
|
||||
self.chat.msg_remove(['0003.txt'])
|
||||
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
||||
|
||||
def test_latest_message(self) -> None:
|
||||
@ -160,13 +146,13 @@ Answer 2
|
||||
self.chat.msg_add([self.message1, self.message2])
|
||||
self.chat.print(paged=False, with_tags=True, with_files=True)
|
||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||
FILE: 0001{msg_suffix}
|
||||
FILE: 0001.txt
|
||||
{Question.txt_header}
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
{TagLine.prefix} btag2
|
||||
FILE: 0002{msg_suffix}
|
||||
FILE: 0002.txt
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
@ -182,27 +168,31 @@ class TestChatDB(TestChatBase):
|
||||
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('tag1')})
|
||||
{Tag('tag1')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('tag2')})
|
||||
{Tag('tag2')},
|
||||
file_path=pathlib.Path('0002.yaml'))
|
||||
self.message3 = Message(Question('Question 3'),
|
||||
Answer('Answer 3'),
|
||||
{Tag('tag3')})
|
||||
{Tag('tag3')},
|
||||
file_path=pathlib.Path('0003.txt'))
|
||||
self.message4 = Message(Question('Question 4'),
|
||||
Answer('Answer 4'),
|
||||
{Tag('tag4')})
|
||||
{Tag('tag4')},
|
||||
file_path=pathlib.Path('0004.yaml'))
|
||||
|
||||
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
|
||||
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
# make the next FID match the current state
|
||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write('4')
|
||||
# add some "trash" in order to test if it's correctly handled / ignored
|
||||
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
|
||||
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
|
||||
for file in self.trash_files:
|
||||
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
|
||||
f.write('test trash')
|
||||
@ -217,7 +207,7 @@ class TestChatDB(TestChatBase):
|
||||
List all Message files in the given TemporaryDirectory.
|
||||
"""
|
||||
# exclude '.next'
|
||||
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
|
||||
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.db_path.cleanup()
|
||||
@ -228,31 +218,13 @@ class TestChatDB(TestChatBase):
|
||||
duplicate_message = Message(Question('Question 4'),
|
||||
Answer('Answer 4'),
|
||||
{Tag('tag4')},
|
||||
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
|
||||
msg_to_file_force_suffix(duplicate_message)
|
||||
file_path=pathlib.Path('0004.txt'))
|
||||
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
|
||||
with self.assertRaises(ChatError) as cm:
|
||||
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(str(cm.exception), "Validation failed")
|
||||
|
||||
def test_file_path_ID_exists(self) -> None:
|
||||
"""
|
||||
Tests if the CacheDB chooses another ID if a file path with
|
||||
the given one exists.
|
||||
"""
|
||||
# create a new and empty CacheDB
|
||||
db_path = tempfile.TemporaryDirectory()
|
||||
cache_path = tempfile.TemporaryDirectory()
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
|
||||
pathlib.Path(db_path.name))
|
||||
# add a message file
|
||||
message = Message(Question('What?'),
|
||||
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
|
||||
message.to_file()
|
||||
message1 = Message(Question('Where?'))
|
||||
chat_db.cache_write([message1])
|
||||
self.assertEqual(message1.msg_id(), '0002')
|
||||
|
||||
def test_from_dir(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -261,23 +233,25 @@ class TestChatDB(TestChatBase):
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
# check that the files are sorted
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_from_dir_glob(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
glob='*1.*')
|
||||
self.assertEqual(len(chat_db.messages), 1)
|
||||
glob='*.txt')
|
||||
self.assertEqual(len(chat_db.messages), 2)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
|
||||
def test_from_dir_filter_tags(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
@ -287,7 +261,7 @@ class TestChatDB(TestChatBase):
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
|
||||
def test_from_dir_filter_tags_empty(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
@ -305,7 +279,7 @@ class TestChatDB(TestChatBase):
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||
|
||||
def test_from_messages(self) -> None:
|
||||
@ -350,25 +324,25 @@ class TestChatDB(TestChatBase):
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# check that Message.file_path is correct
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
# write the messages to the cache directory
|
||||
chat_db.cache_write()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
|
||||
# check that Message.file_path has been correctly updated
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
||||
|
||||
# check the timestamp of the files in the DB directory
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
@ -380,18 +354,18 @@ class TestChatDB(TestChatBase):
|
||||
# check if the written files are in the DB directory
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
||||
# check if all files in the DB dir have actually been overwritten
|
||||
for file in db_dir_files:
|
||||
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
|
||||
# check that Message.file_path has been correctly updated (again)
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_db_read(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
@ -406,65 +380,65 @@ class TestChatDB(TestChatBase):
|
||||
new_message2 = Message(Question('Question 6'),
|
||||
Answer('Answer 6'),
|
||||
{Tag('tag6')})
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
# read and check them
|
||||
chat_db.db_read()
|
||||
self.assertEqual(len(chat_db.messages), 6)
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# create 2 new files in the cache directory
|
||||
new_message3 = Message(Question('Question 7'),
|
||||
Answer('Answer 7'),
|
||||
Answer('Answer 5'),
|
||||
{Tag('tag7')})
|
||||
new_message4 = Message(Question('Question 8'),
|
||||
Answer('Answer 8'),
|
||||
Answer('Answer 6'),
|
||||
{Tag('tag8')})
|
||||
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
|
||||
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
|
||||
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
|
||||
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
|
||||
# read and check them
|
||||
chat_db.cache_read()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
# check that the new message have the cache dir path
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
|
||||
# an the old ones keep their path (since they have not been replaced)
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# now overwrite two messages in the DB directory
|
||||
new_message1.question = Question('New Question 1')
|
||||
new_message2.question = Question('New Question 2')
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
# read from the DB dir and check if the modified messages have been updated
|
||||
chat_db.db_read()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
|
||||
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# now write the messages from the cache to the DB directory
|
||||
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
|
||||
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
|
||||
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
|
||||
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
|
||||
# read and check them
|
||||
chat_db.db_read()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
# check that they now have the DB path
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
|
||||
|
||||
def test_cache_clear(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# check that Message.file_path is correct
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
# write the messages to the cache directory
|
||||
chat_db.cache_write()
|
||||
@ -476,10 +450,10 @@ class TestChatDB(TestChatBase):
|
||||
chat_db.db_write()
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
||||
|
||||
# add a new message with empty file_path
|
||||
message_empty = Message(question=Question("What the hell am I doing here?"),
|
||||
@ -487,7 +461,7 @@ class TestChatDB(TestChatBase):
|
||||
# and one for the cache dir
|
||||
message_cache = Message(question=Question("What the hell am I doing here?"),
|
||||
answer=Answer("You're a creep!"),
|
||||
file_path=pathlib.Path(self.cache_path.name, '0005'))
|
||||
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
|
||||
chat_db.msg_add([message_empty, message_cache])
|
||||
|
||||
# clear the cache and check the cache dir
|
||||
@ -549,11 +523,11 @@ class TestChatDB(TestChatBase):
|
||||
chat_db.msg_write([message])
|
||||
|
||||
# write a message with a valid file_path
|
||||
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
|
||||
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
|
||||
chat_db.msg_write([message])
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 1)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
|
||||
|
||||
def test_msg_update(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
@ -589,21 +563,21 @@ class TestChatDB(TestChatBase):
|
||||
# search for a DB file in memory
|
||||
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
|
||||
# and on disk
|
||||
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
|
||||
# now search the cache -> expect empty result
|
||||
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
|
||||
# search for multiple messages
|
||||
# -> search one twice, expect result to be unique
|
||||
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
|
||||
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
|
||||
expected_result = [self.message1, self.message2, self.message3]
|
||||
result = chat_db.msg_find(search_names, loc='all')
|
||||
self.assert_messages_equal(result, expected_result)
|
||||
|
||||
@ -1,16 +1,11 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
import tempfile
|
||||
import itertools
|
||||
from typing import cast
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
|
||||
MessageFilter, message_in, message_valid_formats
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
||||
from chatmastermind.tags import Tag, TagLine
|
||||
|
||||
|
||||
msg_suffix: str = Message.file_suffix_write
|
||||
|
||||
|
||||
class SourceCodeTestCase(unittest.TestCase):
|
||||
def test_source_code_with_include_delims(self) -> None:
|
||||
text = """
|
||||
@ -106,7 +101,7 @@ class AnswerTestCase(unittest.TestCase):
|
||||
|
||||
class MessageToFileTxtTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
self.message_complete = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
@ -122,7 +117,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
|
||||
self.file_path.unlink()
|
||||
|
||||
def test_to_file_txt_complete(self) -> None:
|
||||
self.message_complete.to_file(self.file_path, mformat='txt')
|
||||
self.message_complete.to_file(self.file_path)
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
@ -137,7 +132,7 @@ This is an answer.
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
def test_to_file_txt_min(self) -> None:
|
||||
self.message_min.to_file(self.file_path, mformat='txt')
|
||||
self.message_min.to_file(self.file_path)
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
@ -146,17 +141,11 @@ This is a question.
|
||||
"""
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
def test_to_file_unsupported_file_suffix(self) -> None:
|
||||
def test_to_file_unsupported_file_type(self) -> None:
|
||||
unsupported_file_path = pathlib.Path("example.doc")
|
||||
with self.assertRaises(MessageError) as cm:
|
||||
self.message_complete.to_file(unsupported_file_path)
|
||||
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
|
||||
|
||||
def test_to_file_unsupported_file_format(self) -> None:
|
||||
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
|
||||
with self.assertRaises(MessageError) as cm:
|
||||
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
|
||||
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
|
||||
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
|
||||
|
||||
def test_to_file_no_file_path(self) -> None:
|
||||
"""
|
||||
@ -170,24 +159,10 @@ This is a question.
|
||||
# reset the internal file_path
|
||||
self.message_complete.file_path = self.file_path
|
||||
|
||||
def test_to_file_txt_auto_suffix(self) -> None:
|
||||
"""
|
||||
Test if suffix is auto-generated if omitted.
|
||||
"""
|
||||
file_path_no_suffix = self.file_path.with_suffix('')
|
||||
# test with file_path member
|
||||
self.message_min.file_path = file_path_no_suffix
|
||||
self.message_min.to_file(mformat='txt')
|
||||
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||
# test with explicit file_path
|
||||
self.message_min.file_path = file_path_no_suffix
|
||||
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
|
||||
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||
|
||||
|
||||
class MessageToFileYamlTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
self.message_complete = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
@ -209,7 +184,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
||||
self.file_path.unlink()
|
||||
|
||||
def test_to_file_yaml_complete(self) -> None:
|
||||
self.message_complete.to_file(self.file_path, mformat='yaml')
|
||||
self.message_complete.to_file(self.file_path)
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
@ -224,7 +199,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
def test_to_file_yaml_multiline(self) -> None:
|
||||
self.message_multiline.to_file(self.file_path, mformat='yaml')
|
||||
self.message_multiline.to_file(self.file_path)
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
@ -243,31 +218,17 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
def test_to_file_yaml_min(self) -> None:
|
||||
self.message_min.to_file(self.file_path, mformat='yaml')
|
||||
self.message_min.to_file(self.file_path)
|
||||
|
||||
with open(self.file_path, "r") as fd:
|
||||
content = fd.read()
|
||||
expected_content = f"{Question.yaml_key}: This is a question.\n"
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
def test_to_file_yaml_auto_suffix(self) -> None:
|
||||
"""
|
||||
Test if suffix is auto-generated if omitted.
|
||||
"""
|
||||
file_path_no_suffix = self.file_path.with_suffix('')
|
||||
# test with file_path member
|
||||
self.message_min.file_path = file_path_no_suffix
|
||||
self.message_min.to_file(mformat='yaml')
|
||||
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||
# test with explicit file_path
|
||||
self.message_min.file_path = file_path_no_suffix
|
||||
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
|
||||
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||
|
||||
|
||||
class MessageFromFileTxtTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
with open(self.file_path, "w") as fd:
|
||||
fd.write(f"""{TagLine.prefix} tag1 tag2
|
||||
@ -278,7 +239,7 @@ This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_min = pathlib.Path(self.file_min.name)
|
||||
with open(self.file_path_min, "w") as fd:
|
||||
fd.write(f"""{Question.txt_header}
|
||||
@ -298,13 +259,13 @@ This is a question.
|
||||
message = Message.from_file(self.file_path)
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.ai, 'ChatGPT')
|
||||
self.assertEqual(message.model, 'gpt-3.5-turbo')
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.ai, 'ChatGPT')
|
||||
self.assertEqual(message.model, 'gpt-3.5-turbo')
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_txt_min(self) -> None:
|
||||
"""
|
||||
@ -313,21 +274,21 @@ This is a question.
|
||||
message = Message.from_file(self.file_path_min)
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
self.assertIsNone(message.answer)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
self.assertIsNone(message.answer)
|
||||
|
||||
def test_from_file_txt_tags_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_txt_tags_dont_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
@ -350,13 +311,13 @@ This is a question.
|
||||
MessageFilter(tags_not={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
|
||||
def test_from_file_not_exists(self) -> None:
|
||||
file_not_exists = pathlib.Path(f"example{msg_suffix}")
|
||||
file_not_exists = pathlib.Path("example.txt")
|
||||
with self.assertRaises(MessageError) as cm:
|
||||
Message.from_file(file_not_exists)
|
||||
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
||||
@ -435,7 +396,7 @@ This is a question.
|
||||
|
||||
class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
with open(self.file_path, "w") as fd:
|
||||
fd.write(f"""
|
||||
@ -449,7 +410,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
- tag1
|
||||
- tag2
|
||||
""")
|
||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_min = pathlib.Path(self.file_min.name)
|
||||
with open(self.file_path_min, "w") as fd:
|
||||
fd.write(f"""
|
||||
@ -470,13 +431,13 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
message = Message.from_file(self.file_path)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertIsNotNone(message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.ai, 'ChatGPT')
|
||||
self.assertEqual(message.model, 'gpt-3.5-turbo')
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.ai, 'ChatGPT')
|
||||
self.assertEqual(message.model, 'gpt-3.5-turbo')
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_yaml_min(self) -> None:
|
||||
"""
|
||||
@ -485,14 +446,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
message = Message.from_file(self.file_path_min)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertIsNotNone(message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
self.assertIsNone(message.answer)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
self.assertIsNone(message.answer)
|
||||
|
||||
def test_from_file_not_exists(self) -> None:
|
||||
file_not_exists = pathlib.Path(f"example{msg_suffix}")
|
||||
file_not_exists = pathlib.Path("example.yaml")
|
||||
with self.assertRaises(MessageError) as cm:
|
||||
Message.from_file(file_not_exists)
|
||||
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
||||
@ -502,11 +463,11 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_yaml_tags_dont_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
@ -523,10 +484,10 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
MessageFilter(tags_not={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
assert message
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
|
||||
def test_from_file_yaml_question_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
@ -602,7 +563,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
|
||||
class TagsFromFileTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
||||
with open(self.file_path_txt, "w") as fd:
|
||||
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
|
||||
@ -611,7 +572,7 @@ This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
|
||||
with open(self.file_path_txt_no_tags, "w") as fd:
|
||||
fd.write(f"""{Question.txt_header}
|
||||
@ -619,7 +580,7 @@ This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
|
||||
with open(self.file_path_txt_tags_empty, "w") as fd:
|
||||
fd.write(f"""TAGS:
|
||||
@ -628,7 +589,7 @@ This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
|
||||
with open(self.file_path_yaml, "w") as fd:
|
||||
fd.write(f"""
|
||||
@ -641,7 +602,7 @@ This is an answer.
|
||||
- tag2
|
||||
- ptag3
|
||||
""")
|
||||
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
|
||||
with open(self.file_path_yaml_no_tags, "w") as fd:
|
||||
fd.write(f"""
|
||||
@ -718,25 +679,24 @@ class TagsFromDirTestCase(unittest.TestCase):
|
||||
{Tag('ctag5'), Tag('ctag6')}
|
||||
]
|
||||
self.files = [
|
||||
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
|
||||
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
|
||||
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
|
||||
pathlib.Path(self.temp_dir.name, 'file1.txt'),
|
||||
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
|
||||
pathlib.Path(self.temp_dir.name, 'file3.txt')
|
||||
]
|
||||
self.files_no_tags = [
|
||||
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
|
||||
]
|
||||
mformats = itertools.cycle(message_valid_formats)
|
||||
for file, tags in zip(self.files, self.tag_sets):
|
||||
message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
tags)
|
||||
message.to_file(file, next(mformats))
|
||||
message.to_file(file)
|
||||
for file in self.files_no_tags:
|
||||
message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'))
|
||||
message.to_file(file, next(mformats))
|
||||
message.to_file(file)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.temp_dir.cleanup()
|
||||
@ -759,7 +719,7 @@ class TagsFromDirTestCase(unittest.TestCase):
|
||||
|
||||
class MessageIDTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
self.message = Message(Question('This is a question.'),
|
||||
file_path=self.file_path)
|
||||
|
||||
@ -14,9 +14,6 @@ from chatmastermind.ai import AIError
|
||||
from .test_common import TestWithFakeAI
|
||||
|
||||
|
||||
msg_suffix = Message.file_suffix_write
|
||||
|
||||
|
||||
class TestMessageCreate(TestWithFakeAI):
|
||||
"""
|
||||
Test if messages created by the 'question' command have
|
||||
@ -86,7 +83,7 @@ Aaaand again some text."""
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||
# exclude '.next'
|
||||
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
|
||||
return list(Path(tmp_dir.name).glob('*.[ty]*'))
|
||||
|
||||
def test_message_file_created(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
@ -234,7 +231,7 @@ class TestQuestionCmd(TestWithFakeAI):
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||
# exclude '.next'
|
||||
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
|
||||
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
||||
|
||||
|
||||
class TestQuestionCmdAsk(TestQuestionCmd):
|
||||
@ -333,16 +330,14 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat a single question.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# create a message
|
||||
message = Message(Question(self.args.ask[0]),
|
||||
Answer('Old Answer'),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
chat.msg_write([message])
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
message.to_file()
|
||||
|
||||
# repeat the last question (without overwriting)
|
||||
# -> expect two identical messages (except for the file_path)
|
||||
@ -358,6 +353,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
# we expect the original message + the one with the new response
|
||||
expected_responses = [message] + [expected_response]
|
||||
question_cmd(self.args, self.config)
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
print(self.message_list(self.cache_dir))
|
||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||
@ -369,16 +366,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat a single question and overwrite the old one.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# create a message
|
||||
message = Message(Question(self.args.ask[0]),
|
||||
Answer('Old Answer'),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
chat.msg_write([message])
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
message.to_file()
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
assert cached_msg[0].file_path
|
||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||
@ -408,16 +405,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat a single question after an error.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# create a question WITHOUT an answer
|
||||
# -> just like after an error, which is tested above
|
||||
message = Message(Question(self.args.ask[0]),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
chat.msg_write([message])
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
message.to_file()
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
assert cached_msg[0].file_path
|
||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||
@ -448,16 +445,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat a single question with new arguments.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# create a message
|
||||
message = Message(Question(self.args.ask[0]),
|
||||
Answer('Old Answer'),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
chat.msg_write([message])
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
message.to_file()
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
assert cached_msg[0].file_path
|
||||
|
||||
@ -486,16 +483,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat a single question with new arguments, overwriting the old one.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# create a message
|
||||
message = Message(Question(self.args.ask[0]),
|
||||
Answer('Old Answer'),
|
||||
tags=set(self.args.output_tags),
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
chat.msg_write([message])
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
message.to_file()
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
assert cached_msg[0].file_path
|
||||
|
||||
@ -523,29 +520,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
Repeat multiple questions.
|
||||
"""
|
||||
mock_create_ai.side_effect = self.mock_create_ai
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
# 1. === create three questions ===
|
||||
# cached message without an answer
|
||||
message1 = Message(Question(self.args.ask[0]),
|
||||
tags=self.args.output_tags,
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
||||
# cached message with an answer
|
||||
message2 = Message(Question(self.args.ask[0]),
|
||||
Answer('Old Answer'),
|
||||
tags=self.args.output_tags,
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
|
||||
file_path=Path(self.cache_dir.name) / '0002.txt')
|
||||
# DB message without an answer
|
||||
message3 = Message(Question(self.args.ask[0]),
|
||||
tags=self.args.output_tags,
|
||||
ai=self.args.AI,
|
||||
model=self.args.model,
|
||||
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
|
||||
chat.msg_write([message1, message2, message3])
|
||||
file_path=Path(self.db_dir.name) / '0003.txt')
|
||||
message1.to_file()
|
||||
message2.to_file()
|
||||
message3.to_file()
|
||||
questions = [message1, message2, message3]
|
||||
expected_responses: list[Message] = []
|
||||
fake_ai = self.mock_create_ai(self.args, self.config)
|
||||
@ -569,6 +566,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
||||
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
|
||||
self.assertEqual(len(self.message_list(self.db_dir)), 1)
|
||||
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
|
||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||
Path(self.db_dir.name))
|
||||
cached_msg = chat.msg_gather(loc='cache')
|
||||
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
|
||||
# check that the DB message has not been modified at all
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user