Compare commits

..

1 Commits

Author SHA1 Message Date
fae835be1f message: introduced '.msg' suffix 2023-09-24 18:20:38 +02:00
5 changed files with 243 additions and 311 deletions

View File

@ -6,9 +6,9 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
from .configuration import default_config_file
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
@ -17,7 +17,6 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
msg_suffix = Message.file_suffix_write
class ChatError(Exception):
@ -53,7 +52,7 @@ def read_dir(dir_path: Path,
for file_path in sorted(file_iter):
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
try:
message = Message.from_file(file_path, mfilter)
if message:
@ -64,20 +63,22 @@ def read_dir(dir_path: Path,
def make_file_path(dir_path: Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
"""
Create a file_path for the given directory using the given ID generator function.
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path,
messages: list[Message],
next_fid: Callable[[], int],
mformat: MessageFormat = Message.default_format) -> None:
file_suffix: str,
next_fid: Callable[[], int]) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
@ -85,17 +86,18 @@ def write_dir(dir_path: Path,
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
file_path = make_file_path(dir_path, next_fid)
file_path = make_file_path(dir_path, file_suffix, next_fid)
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path, mformat=mformat)
message.to_file(file_path)
def clear_dir(dir_path: Path,
@ -107,7 +109,7 @@ def clear_dir(dir_path: Path,
for file_path in file_iter:
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
file_path.unlink(missing_ok=True)
@ -144,7 +146,7 @@ class Chat:
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
@ -279,10 +281,13 @@ class ChatDB(Chat):
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path
db_path: Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
@ -312,7 +317,8 @@ class ChatDB(Chat):
when reading them.
"""
messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter, glob)
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
@classmethod
def from_messages(cls: Type[ChatDBInst],
@ -339,9 +345,7 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f:
f.write(f'{fid}')
def msg_write(self,
messages: Optional[list[Message]] = None,
mformat: MessageFormat = Message.default_format) -> None:
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
@ -352,7 +356,7 @@ class ChatDB(Chat):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file(mformat=mformat)
m.to_file()
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
@ -514,6 +518,7 @@ class ChatDB(Chat):
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
@ -526,10 +531,11 @@ class ChatDB(Chat):
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
@ -579,6 +585,7 @@ class ChatDB(Chat):
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
@ -591,10 +598,11 @@ class ChatDB(Chat):
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.get_next_fid)
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()

View File

@ -5,8 +5,7 @@ import pathlib
import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
from typing import get_args as typing_get_args
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
@ -16,9 +15,6 @@ MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
MessageFormat = Literal['txt', 'yaml']
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
message_default_format: Final[MessageFormat] = 'txt'
class MessageError(Exception):
@ -222,7 +218,8 @@ class Message():
# class variables
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
file_suffix_write: ClassVar[str] = '.msg'
default_format: ClassVar[MessageFormat] = message_default_format
file_formats: ClassVar[list[str]] = ['txt', 'yaml']
default_format: ClassVar[str] = 'txt'
tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai'
@ -372,6 +369,10 @@ class Message():
tags = TagLine(fd.readline()).tags()
except TagError:
fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional)
try:
pos = fd.tell()
@ -386,23 +387,17 @@ class Message():
fd.seek(pos)
# Question and Answer
text = fd.read().strip().split('\n')
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
# match tags AFTER reading the whole file
# -> make sure it's a valid 'txt' file format
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
return cls(question, answer, tags, ai, model, file_path)
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
return cls(question, answer, tags, ai, model, file_path)
@classmethod
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
@ -443,7 +438,7 @@ class Message():
output.append(self.answer)
return '\n'.join(output)
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: str = Message.default_format) -> None: # noqa: 11
"""
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
Suffix is always '.msg'.
@ -452,14 +447,10 @@ class Message():
self.file_path = file_path
if not self.file_path:
raise MessageError("Got no valid path to write message")
if mformat not in message_valid_formats:
if mformat not in self.file_formats:
raise MessageError(f"File format '{mformat}' is not supported")
# check for valid suffix
# -> add one if it's empty
# -> refuse old or otherwise unsupported suffixes
if not self.file_path.suffix:
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
elif self.file_path.suffix != self.file_suffix_write:
# do not write old suffixes
if self.file_path.suffix != self.file_suffix_write:
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
# TXT
if mformat == 'txt':

View File

@ -10,20 +10,6 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, ChatError
msg_suffix: str = Message.file_suffix_write
def msg_to_file_force_suffix(msg: Message) -> None:
"""
Force writing a message file with illegal suffixes.
"""
def_suffix = Message.file_suffix_write
assert msg.file_path
Message.file_suffix_write = msg.file_path.suffix
msg.to_file()
Message.file_suffix_write = def_suffix
class TestChatBase(unittest.TestCase):
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
"""
@ -41,11 +27,11 @@ class TestChat(TestChatBase):
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path(f'0001{msg_suffix}'))
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path(f'0002{msg_suffix}'))
file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
@ -113,24 +99,24 @@ class TestChat(TestChatBase):
def test_find_remove_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001'])
msgs = self.chat.msg_find(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.msg_find(['0001', '0002'])
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.msg_add([message3])
# find new Message by full path
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
msgs = self.chat.msg_find(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.msg_remove(['0003'])
self.chat.msg_remove(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
@ -160,13 +146,13 @@ Answer 2
self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001{msg_suffix}
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{TagLine.prefix} btag2
FILE: 0002{msg_suffix}
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
@ -182,27 +168,31 @@ class TestChatDB(TestChatBase):
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')})
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')})
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')})
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')})
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
@ -217,7 +207,7 @@ class TestChatDB(TestChatBase):
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
def tearDown(self) -> None:
self.db_path.cleanup()
@ -228,31 +218,13 @@ class TestChatDB(TestChatBase):
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
msg_to_file_force_suffix(duplicate_message)
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_file_path_ID_exists(self) -> None:
"""
Tests if the CacheDB chooses another ID if a file path with
the given one exists.
"""
# create a new and empty CacheDB
db_path = tempfile.TemporaryDirectory()
cache_path = tempfile.TemporaryDirectory()
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
pathlib.Path(db_path.name))
# add a message file
message = Message(Question('What?'),
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
message.to_file()
message1 = Message(Question('Where?'))
chat_db.cache_write([message1])
self.assertEqual(message1.msg_id(), '0002')
def test_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -261,23 +233,25 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*1.*')
self.assertEqual(len(chat_db.messages), 1)
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@ -287,7 +261,7 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
pathlib.Path(self.db_path.name, '0001.txt'))
def test_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
@ -305,7 +279,7 @@ class TestChatDB(TestChatBase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_from_messages(self) -> None:
@ -350,25 +324,25 @@ class TestChatDB(TestChatBase):
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path)
@ -380,18 +354,18 @@ class TestChatDB(TestChatBase):
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_db_read(self) -> None:
# create a new ChatDB instance
@ -406,65 +380,65 @@ class TestChatDB(TestChatBase):
new_message2 = Message(Question('Question 6'),
Answer('Answer 6'),
{Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'),
Answer('Answer 7'),
Answer('Answer 5'),
{Tag('tag7')})
new_message4 = Message(Question('Question 8'),
Answer('Answer 8'),
Answer('Answer 6'),
{Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
# an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_cache_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
@ -476,10 +450,10 @@ class TestChatDB(TestChatBase):
chat_db.db_write()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"),
@ -487,7 +461,7 @@ class TestChatDB(TestChatBase):
# and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005'))
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir
@ -549,11 +523,11 @@ class TestChatDB(TestChatBase):
chat_db.msg_write([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_msg_update(self) -> None:
# create a new ChatDB instance
@ -589,21 +563,21 @@ class TestChatDB(TestChatBase):
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, loc='all')
self.assert_messages_equal(result, expected_result)

View File

@ -1,16 +1,11 @@
import unittest
import pathlib
import tempfile
import itertools
from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
MessageFilter, message_in, message_valid_formats
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine
msg_suffix: str = Message.file_suffix_write
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
@ -106,7 +101,7 @@ class AnswerTestCase(unittest.TestCase):
class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@ -122,7 +117,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='txt')
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@ -137,7 +132,7 @@ This is an answer.
self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='txt')
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@ -146,17 +141,11 @@ This is a question.
"""
self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_suffix(self) -> None:
def test_to_file_unsupported_file_type(self) -> None:
unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
def test_to_file_unsupported_file_format(self) -> None:
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
def test_to_file_no_file_path(self) -> None:
"""
@ -170,24 +159,10 @@ This is a question.
# reset the internal file_path
self.message_complete.file_path = self.file_path
def test_to_file_txt_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
@ -209,7 +184,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path, mformat='yaml')
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@ -224,7 +199,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path, mformat='yaml')
self.message_multiline.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
@ -243,31 +218,17 @@ class MessageToFileYamlTestCase(unittest.TestCase):
self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path, mformat='yaml')
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content)
def test_to_file_yaml_auto_suffix(self) -> None:
"""
Test if suffix is auto-generated if omitted.
"""
file_path_no_suffix = self.file_path.with_suffix('')
# test with file_path member
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
# test with explicit file_path
self.message_min.file_path = file_path_no_suffix
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2
@ -278,7 +239,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header}
@ -298,13 +259,13 @@ This is a question.
message = Message.from_file(self.file_path)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_min(self) -> None:
"""
@ -313,21 +274,21 @@ This is a question.
message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_txt_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
@ -350,13 +311,13 @@ This is a question.
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}")
file_not_exists = pathlib.Path("example.txt")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@ -435,7 +396,7 @@ This is a question.
class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""
@ -449,7 +410,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
- tag1
- tag2
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""
@ -470,13 +431,13 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_min(self) -> None:
"""
@ -485,14 +446,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path(f"example{msg_suffix}")
file_not_exists = pathlib.Path("example.yaml")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
@ -502,11 +463,11 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
@ -523,10 +484,10 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
assert message
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_yaml_question_match(self) -> None:
message = Message.from_file(self.file_path,
@ -602,7 +563,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
@ -611,7 +572,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header}
@ -619,7 +580,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS:
@ -628,7 +589,7 @@ This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd:
fd.write(f"""
@ -641,7 +602,7 @@ This is an answer.
- tag2
- ptag3
""")
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f"""
@ -718,25 +679,24 @@ class TagsFromDirTestCase(unittest.TestCase):
{Tag('ctag5'), Tag('ctag6')}
]
self.files = [
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, 'file3.txt')
]
self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
]
mformats = itertools.cycle(message_valid_formats)
for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags)
message.to_file(file, next(mformats))
message.to_file(file)
for file in self.files_no_tags:
message = Message(Question('This is a question.'),
Answer('This is an answer.'))
message.to_file(file, next(mformats))
message.to_file(file)
def tearDown(self) -> None:
self.temp_dir.cleanup()
@ -759,7 +719,7 @@ class TagsFromDirTestCase(unittest.TestCase):
class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)

View File

@ -14,9 +14,6 @@ from chatmastermind.ai import AIError
from .test_common import TestWithFakeAI
msg_suffix = Message.file_suffix_write
class TestMessageCreate(TestWithFakeAI):
"""
Test if messages created by the 'question' command have
@ -86,7 +83,7 @@ Aaaand again some text."""
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
@ -234,7 +231,7 @@ class TestQuestionCmd(TestWithFakeAI):
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
class TestQuestionCmdAsk(TestQuestionCmd):
@ -333,16 +330,14 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
# repeat the last question (without overwriting)
# -> expect two identical messages (except for the file_path)
@ -358,6 +353,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
# we expect the original message + the one with the new response
expected_responses = [message] + [expected_response]
question_cmd(self.args, self.config)
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
print(self.message_list(self.cache_dir))
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
@ -369,16 +366,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question and overwrite the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
@ -408,16 +405,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question after an error.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a question WITHOUT an answer
# -> just like after an error, which is tested above
message = Message(Question(self.args.ask[0]),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
cached_msg_file_id = cached_msg[0].file_path.stem
@ -448,16 +445,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question with new arguments.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
@ -486,16 +483,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat a single question with new arguments, overwriting the old one.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# create a message
message = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=set(self.args.output_tags),
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
chat.msg_write([message])
file_path=Path(self.cache_dir.name) / '0001.txt')
message.to_file()
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
assert cached_msg[0].file_path
@ -523,29 +520,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
Repeat multiple questions.
"""
mock_create_ai.side_effect = self.mock_create_ai
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
# 1. === create three questions ===
# cached message without an answer
message1 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
file_path=Path(self.cache_dir.name) / '0001.txt')
# cached message with an answer
message2 = Message(Question(self.args.ask[0]),
Answer('Old Answer'),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
file_path=Path(self.cache_dir.name) / '0002.txt')
# DB message without an answer
message3 = Message(Question(self.args.ask[0]),
tags=self.args.output_tags,
ai=self.args.AI,
model=self.args.model,
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
chat.msg_write([message1, message2, message3])
file_path=Path(self.db_dir.name) / '0003.txt')
message1.to_file()
message2.to_file()
message3.to_file()
questions = [message1, message2, message3]
expected_responses: list[Message] = []
fake_ai = self.mock_create_ai(self.args, self.config)
@ -569,6 +566,8 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
self.assertEqual(len(self.message_list(self.db_dir)), 1)
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
chat = ChatDB.from_dir(Path(self.cache_dir.name),
Path(self.db_dir.name))
cached_msg = chat.msg_gather(loc='cache')
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
# check that the DB message has not been modified at all