Compare commits
7 Commits
fae835be1f
...
3c1c9860a0
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c1c9860a0 | |||
| 3bc5f7cd63 | |||
| f352d71177 | |||
| 11d50ae551 | |||
| 1743802262 | |||
| bdce69e741 | |||
| 499f6a7be9 |
@ -6,9 +6,9 @@ from pathlib import Path
|
|||||||
from pprint import PrettyPrinter
|
from pprint import PrettyPrinter
|
||||||
from pydoc import pager
|
from pydoc import pager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal, Union
|
from typing import TypeVar, Type, Optional, Any, Callable, Literal, Union
|
||||||
from .configuration import default_config_file
|
from .configuration import default_config_file
|
||||||
from .message import Message, MessageFilter, MessageError, message_in
|
from .message import Message, MessageFilter, MessageError, MessageFormat, message_in
|
||||||
from .tags import Tag
|
from .tags import Tag
|
||||||
|
|
||||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||||
@ -17,6 +17,7 @@ ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
|||||||
db_next_file = '.next'
|
db_next_file = '.next'
|
||||||
ignored_files = [db_next_file, default_config_file]
|
ignored_files = [db_next_file, default_config_file]
|
||||||
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
|
msg_location = Literal['mem', 'disk', 'cache', 'db', 'all']
|
||||||
|
msg_suffix = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
class ChatError(Exception):
|
class ChatError(Exception):
|
||||||
@ -52,7 +53,7 @@ def read_dir(dir_path: Path,
|
|||||||
for file_path in sorted(file_iter):
|
for file_path in sorted(file_iter):
|
||||||
if (file_path.is_file()
|
if (file_path.is_file()
|
||||||
and file_path.name not in ignored_files # noqa: W503
|
and file_path.name not in ignored_files # noqa: W503
|
||||||
and file_path.suffix in Message.file_suffixes): # noqa: W503
|
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
|
||||||
try:
|
try:
|
||||||
message = Message.from_file(file_path, mfilter)
|
message = Message.from_file(file_path, mfilter)
|
||||||
if message:
|
if message:
|
||||||
@ -63,22 +64,20 @@ def read_dir(dir_path: Path,
|
|||||||
|
|
||||||
|
|
||||||
def make_file_path(dir_path: Path,
|
def make_file_path(dir_path: Path,
|
||||||
file_suffix: str,
|
|
||||||
next_fid: Callable[[], int]) -> Path:
|
next_fid: Callable[[], int]) -> Path:
|
||||||
"""
|
"""
|
||||||
Create a file_path for the given directory using the
|
Create a file_path for the given directory using the given ID generator function.
|
||||||
given file_suffix and ID generator function.
|
|
||||||
"""
|
"""
|
||||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
|
||||||
while file_path.exists():
|
while file_path.exists():
|
||||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
file_path = dir_path / f"{next_fid():04d}{msg_suffix}"
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def write_dir(dir_path: Path,
|
def write_dir(dir_path: Path,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
file_suffix: str,
|
next_fid: Callable[[], int],
|
||||||
next_fid: Callable[[], int]) -> None:
|
mformat: MessageFormat = Message.default_format) -> None:
|
||||||
"""
|
"""
|
||||||
Write all messages to the given directory. If a message has no file_path,
|
Write all messages to the given directory. If a message has no file_path,
|
||||||
a new one will be created. If message.file_path exists, it will be modified
|
a new one will be created. If message.file_path exists, it will be modified
|
||||||
@ -86,18 +85,17 @@ def write_dir(dir_path: Path,
|
|||||||
Parameters:
|
Parameters:
|
||||||
* 'dir_path': destination directory
|
* 'dir_path': destination directory
|
||||||
* 'messages': list of messages to write
|
* 'messages': list of messages to write
|
||||||
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
|
|
||||||
* 'next_fid': callable that returns the next file ID
|
* 'next_fid': callable that returns the next file ID
|
||||||
"""
|
"""
|
||||||
for message in messages:
|
for message in messages:
|
||||||
file_path = message.file_path
|
file_path = message.file_path
|
||||||
# message has no file_path: create one
|
# message has no file_path: create one
|
||||||
if not file_path:
|
if not file_path:
|
||||||
file_path = make_file_path(dir_path, file_suffix, next_fid)
|
file_path = make_file_path(dir_path, next_fid)
|
||||||
# file_path does not point to given directory: modify it
|
# file_path does not point to given directory: modify it
|
||||||
elif not file_path.parent.samefile(dir_path):
|
elif not file_path.parent.samefile(dir_path):
|
||||||
file_path = dir_path / file_path.name
|
file_path = dir_path / file_path.name
|
||||||
message.to_file(file_path)
|
message.to_file(file_path, mformat=mformat)
|
||||||
|
|
||||||
|
|
||||||
def clear_dir(dir_path: Path,
|
def clear_dir(dir_path: Path,
|
||||||
@ -109,7 +107,7 @@ def clear_dir(dir_path: Path,
|
|||||||
for file_path in file_iter:
|
for file_path in file_iter:
|
||||||
if (file_path.is_file()
|
if (file_path.is_file()
|
||||||
and file_path.name not in ignored_files # noqa: W503
|
and file_path.name not in ignored_files # noqa: W503
|
||||||
and file_path.suffix in Message.file_suffixes): # noqa: W503
|
and file_path.suffix in Message.file_suffixes_read): # noqa: W503
|
||||||
file_path.unlink(missing_ok=True)
|
file_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
@ -146,7 +144,7 @@ class Chat:
|
|||||||
Matching is True if:
|
Matching is True if:
|
||||||
* 'name' matches the full 'file_path'
|
* 'name' matches the full 'file_path'
|
||||||
* 'name' matches 'file_path.name' (i. e. including the suffix)
|
* 'name' matches 'file_path.name' (i. e. including the suffix)
|
||||||
* 'name' matches 'file_path.stem' (i. e. without a suffix)
|
* 'name' matches 'file_path.stem' (i. e. without the suffix)
|
||||||
"""
|
"""
|
||||||
return Path(name) == file_path or name == file_path.name or name == file_path.stem
|
return Path(name) == file_path or name == file_path.name or name == file_path.stem
|
||||||
|
|
||||||
@ -281,13 +279,10 @@ class ChatDB(Chat):
|
|||||||
persistently.
|
persistently.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default_file_suffix: ClassVar[str] = '.txt'
|
|
||||||
|
|
||||||
cache_path: Path
|
cache_path: Path
|
||||||
db_path: Path
|
db_path: Path
|
||||||
# a MessageFilter that all messages must match (if given)
|
# a MessageFilter that all messages must match (if given)
|
||||||
mfilter: Optional[MessageFilter] = None
|
mfilter: Optional[MessageFilter] = None
|
||||||
file_suffix: str = default_file_suffix
|
|
||||||
# the glob pattern for all messages
|
# the glob pattern for all messages
|
||||||
glob: Optional[str] = None
|
glob: Optional[str] = None
|
||||||
|
|
||||||
@ -317,8 +312,7 @@ class ChatDB(Chat):
|
|||||||
when reading them.
|
when reading them.
|
||||||
"""
|
"""
|
||||||
messages = read_dir(db_path, glob, mfilter)
|
messages = read_dir(db_path, glob, mfilter)
|
||||||
return cls(messages, cache_path, db_path, mfilter,
|
return cls(messages, cache_path, db_path, mfilter, glob)
|
||||||
cls.default_file_suffix, glob)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_messages(cls: Type[ChatDBInst],
|
def from_messages(cls: Type[ChatDBInst],
|
||||||
@ -345,7 +339,9 @@ class ChatDB(Chat):
|
|||||||
with open(self.next_path, 'w') as f:
|
with open(self.next_path, 'w') as f:
|
||||||
f.write(f'{fid}')
|
f.write(f'{fid}')
|
||||||
|
|
||||||
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
|
def msg_write(self,
|
||||||
|
messages: Optional[list[Message]] = None,
|
||||||
|
mformat: MessageFormat = Message.default_format) -> None:
|
||||||
"""
|
"""
|
||||||
Write either the given messages or the internal ones to their CURRENT file_path.
|
Write either the given messages or the internal ones to their CURRENT file_path.
|
||||||
If messages are given, they all must have a valid file_path. When writing the
|
If messages are given, they all must have a valid file_path. When writing the
|
||||||
@ -356,7 +352,7 @@ class ChatDB(Chat):
|
|||||||
raise ChatError("Can't write files without a valid file_path")
|
raise ChatError("Can't write files without a valid file_path")
|
||||||
msgs = iter(messages if messages else self.messages)
|
msgs = iter(messages if messages else self.messages)
|
||||||
while (m := next(msgs, None)):
|
while (m := next(msgs, None)):
|
||||||
m.to_file()
|
m.to_file(mformat=mformat)
|
||||||
|
|
||||||
def msg_update(self, messages: list[Message], write: bool = True) -> None:
|
def msg_update(self, messages: list[Message], write: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
@ -518,7 +514,6 @@ class ChatDB(Chat):
|
|||||||
"""
|
"""
|
||||||
write_dir(self.cache_path,
|
write_dir(self.cache_path,
|
||||||
messages if messages else self.messages,
|
messages if messages else self.messages,
|
||||||
self.file_suffix,
|
|
||||||
self.get_next_fid)
|
self.get_next_fid)
|
||||||
|
|
||||||
def cache_add(self, messages: list[Message], write: bool = True) -> None:
|
def cache_add(self, messages: list[Message], write: bool = True) -> None:
|
||||||
@ -531,11 +526,10 @@ class ChatDB(Chat):
|
|||||||
if write:
|
if write:
|
||||||
write_dir(self.cache_path,
|
write_dir(self.cache_path,
|
||||||
messages,
|
messages,
|
||||||
self.file_suffix,
|
|
||||||
self.get_next_fid)
|
self.get_next_fid)
|
||||||
else:
|
else:
|
||||||
for m in messages:
|
for m in messages:
|
||||||
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
|
m.file_path = make_file_path(self.cache_path, self.get_next_fid)
|
||||||
self.messages += messages
|
self.messages += messages
|
||||||
self.msg_sort()
|
self.msg_sort()
|
||||||
|
|
||||||
@ -585,7 +579,6 @@ class ChatDB(Chat):
|
|||||||
"""
|
"""
|
||||||
write_dir(self.db_path,
|
write_dir(self.db_path,
|
||||||
messages if messages else self.messages,
|
messages if messages else self.messages,
|
||||||
self.file_suffix,
|
|
||||||
self.get_next_fid)
|
self.get_next_fid)
|
||||||
|
|
||||||
def db_add(self, messages: list[Message], write: bool = True) -> None:
|
def db_add(self, messages: list[Message], write: bool = True) -> None:
|
||||||
@ -598,11 +591,10 @@ class ChatDB(Chat):
|
|||||||
if write:
|
if write:
|
||||||
write_dir(self.db_path,
|
write_dir(self.db_path,
|
||||||
messages,
|
messages,
|
||||||
self.file_suffix,
|
|
||||||
self.get_next_fid)
|
self.get_next_fid)
|
||||||
else:
|
else:
|
||||||
for m in messages:
|
for m in messages:
|
||||||
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
|
m.file_path = make_file_path(self.db_path, self.get_next_fid)
|
||||||
self.messages += messages
|
self.messages += messages
|
||||||
self.msg_sort()
|
self.msg_sort()
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,8 @@ import pathlib
|
|||||||
import yaml
|
import yaml
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
|
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable, Tuple
|
||||||
|
from typing import get_args as typing_get_args
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
||||||
|
|
||||||
@ -15,6 +16,9 @@ MessageInst = TypeVar('MessageInst', bound='Message')
|
|||||||
AILineInst = TypeVar('AILineInst', bound='AILine')
|
AILineInst = TypeVar('AILineInst', bound='AILine')
|
||||||
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
|
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
|
||||||
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
|
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
|
||||||
|
MessageFormat = Literal['txt', 'yaml']
|
||||||
|
message_valid_formats: Final[Tuple[MessageFormat, ...]] = typing_get_args(MessageFormat)
|
||||||
|
message_default_format: Final[MessageFormat] = 'txt'
|
||||||
|
|
||||||
|
|
||||||
class MessageError(Exception):
|
class MessageError(Exception):
|
||||||
@ -92,7 +96,7 @@ class MessageFilter:
|
|||||||
|
|
||||||
class AILine(str):
|
class AILine(str):
|
||||||
"""
|
"""
|
||||||
A line that represents the AI name in a '.txt' file..
|
A line that represents the AI name in the 'txt' format.
|
||||||
"""
|
"""
|
||||||
prefix: Final[str] = 'AI:'
|
prefix: Final[str] = 'AI:'
|
||||||
|
|
||||||
@ -112,7 +116,7 @@ class AILine(str):
|
|||||||
|
|
||||||
class ModelLine(str):
|
class ModelLine(str):
|
||||||
"""
|
"""
|
||||||
A line that represents the model name in a '.txt' file..
|
A line that represents the model name in the 'txt' format.
|
||||||
"""
|
"""
|
||||||
prefix: Final[str] = 'MODEL:'
|
prefix: Final[str] = 'MODEL:'
|
||||||
|
|
||||||
@ -216,7 +220,9 @@ class Message():
|
|||||||
model: Optional[str] = field(default=None, compare=False)
|
model: Optional[str] = field(default=None, compare=False)
|
||||||
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
|
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
|
||||||
# class variables
|
# class variables
|
||||||
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
|
file_suffixes_read: ClassVar[list[str]] = ['.msg', '.txt', '.yaml']
|
||||||
|
file_suffix_write: ClassVar[str] = '.msg'
|
||||||
|
default_format: ClassVar[MessageFormat] = message_default_format
|
||||||
tags_yaml_key: ClassVar[str] = 'tags'
|
tags_yaml_key: ClassVar[str] = 'tags'
|
||||||
file_yaml_key: ClassVar[str] = 'file_path'
|
file_yaml_key: ClassVar[str] = 'file_path'
|
||||||
ai_yaml_key: ClassVar[str] = 'ai'
|
ai_yaml_key: ClassVar[str] = 'ai'
|
||||||
@ -276,16 +282,8 @@ class Message():
|
|||||||
tags: set[Tag] = set()
|
tags: set[Tag] = set()
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
raise MessageError(f"Message file '{file_path}' does not exist")
|
raise MessageError(f"Message file '{file_path}' does not exist")
|
||||||
if file_path.suffix not in cls.file_suffixes:
|
if file_path.suffix not in cls.file_suffixes_read:
|
||||||
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
||||||
# for TXT, it's enough to read the TagLine
|
|
||||||
if file_path.suffix == '.txt':
|
|
||||||
with open(file_path, "r") as fd:
|
|
||||||
try:
|
|
||||||
tags = TagLine(fd.readline()).tags(prefix, contain)
|
|
||||||
except TagError:
|
|
||||||
pass # message without tags
|
|
||||||
else: # '.yaml'
|
|
||||||
try:
|
try:
|
||||||
message = cls.from_file(file_path)
|
message = cls.from_file(file_path)
|
||||||
if message:
|
if message:
|
||||||
@ -328,15 +326,16 @@ class Message():
|
|||||||
"""
|
"""
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
raise MessageError(f"Message file '{file_path}' does not exist")
|
raise MessageError(f"Message file '{file_path}' does not exist")
|
||||||
if file_path.suffix not in cls.file_suffixes:
|
if file_path.suffix not in cls.file_suffixes_read:
|
||||||
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
||||||
|
# try TXT first
|
||||||
if file_path.suffix == '.txt':
|
try:
|
||||||
message = cls.__from_file_txt(file_path,
|
message = cls.__from_file_txt(file_path,
|
||||||
mfilter.tags_or if mfilter else None,
|
mfilter.tags_or if mfilter else None,
|
||||||
mfilter.tags_and if mfilter else None,
|
mfilter.tags_and if mfilter else None,
|
||||||
mfilter.tags_not if mfilter else None)
|
mfilter.tags_not if mfilter else None)
|
||||||
else:
|
# then YAML
|
||||||
|
except MessageError:
|
||||||
message = cls.__from_file_yaml(file_path)
|
message = cls.__from_file_yaml(file_path)
|
||||||
if message and (mfilter is None or message.match(mfilter)):
|
if message and (mfilter is None or message.match(mfilter)):
|
||||||
return message
|
return message
|
||||||
@ -373,10 +372,6 @@ class Message():
|
|||||||
tags = TagLine(fd.readline()).tags()
|
tags = TagLine(fd.readline()).tags()
|
||||||
except TagError:
|
except TagError:
|
||||||
fd.seek(pos)
|
fd.seek(pos)
|
||||||
if tags_or or tags_and or tags_not:
|
|
||||||
# match with an empty set if the file has no tags
|
|
||||||
if not match_tags(tags, tags_or, tags_and, tags_not):
|
|
||||||
return None
|
|
||||||
# AILine (Optional)
|
# AILine (Optional)
|
||||||
try:
|
try:
|
||||||
pos = fd.tell()
|
pos = fd.tell()
|
||||||
@ -401,6 +396,12 @@ class Message():
|
|||||||
answer = Answer.from_list(text[answer_idx + 1:])
|
answer = Answer.from_list(text[answer_idx + 1:])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
question = Question.from_list(text[question_idx:])
|
question = Question.from_list(text[question_idx:])
|
||||||
|
# match tags AFTER reading the whole file
|
||||||
|
# -> make sure it's a valid 'txt' file format
|
||||||
|
if tags_or or tags_and or tags_not:
|
||||||
|
# match with an empty set if the file has no tags
|
||||||
|
if not match_tags(tags, tags_or, tags_and, tags_not):
|
||||||
|
return None
|
||||||
return cls(question, answer, tags, ai, model, file_path)
|
return cls(question, answer, tags, ai, model, file_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -442,21 +443,29 @@ class Message():
|
|||||||
output.append(self.answer)
|
output.append(self.answer)
|
||||||
return '\n'.join(output)
|
return '\n'.join(output)
|
||||||
|
|
||||||
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
|
def to_file(self, file_path: Optional[pathlib.Path]=None, mformat: MessageFormat = message_default_format) -> None: # noqa: 11
|
||||||
"""
|
"""
|
||||||
Write a Message to the given file. Type is determined based on the suffix.
|
Write a Message to the given file. Supported message file formats are 'txt' and 'yaml'.
|
||||||
Currently supported suffixes: ['.txt', '.yaml']
|
Suffix is always '.msg'.
|
||||||
"""
|
"""
|
||||||
if file_path:
|
if file_path:
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
if not self.file_path:
|
if not self.file_path:
|
||||||
raise MessageError("Got no valid path to write message")
|
raise MessageError("Got no valid path to write message")
|
||||||
if self.file_path.suffix not in self.file_suffixes:
|
if mformat not in message_valid_formats:
|
||||||
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
|
raise MessageError(f"File format '{mformat}' is not supported")
|
||||||
|
# check for valid suffix
|
||||||
|
# -> add one if it's empty
|
||||||
|
# -> refuse old or otherwise unsupported suffixes
|
||||||
|
if not self.file_path.suffix:
|
||||||
|
self.file_path = self.file_path.with_suffix(self.file_suffix_write)
|
||||||
|
elif self.file_path.suffix != self.file_suffix_write:
|
||||||
|
raise MessageError(f"File suffix '{self.file_path.suffix}' is not supported")
|
||||||
# TXT
|
# TXT
|
||||||
if self.file_path.suffix == '.txt':
|
if mformat == 'txt':
|
||||||
return self.__to_file_txt(self.file_path)
|
return self.__to_file_txt(self.file_path)
|
||||||
elif self.file_path.suffix == '.yaml':
|
# YAML
|
||||||
|
elif mformat == 'yaml':
|
||||||
return self.__to_file_yaml(self.file_path)
|
return self.__to_file_yaml(self.file_path)
|
||||||
|
|
||||||
def __to_file_txt(self, file_path: pathlib.Path) -> None:
|
def __to_file_txt(self, file_path: pathlib.Path) -> None:
|
||||||
@ -468,8 +477,8 @@ class Message():
|
|||||||
* Model [Optional]
|
* Model [Optional]
|
||||||
* Question.txt_header
|
* Question.txt_header
|
||||||
* Question
|
* Question
|
||||||
* Answer.txt_header
|
* Answer.txt_header [Optional]
|
||||||
* Answer
|
* Answer [Optional]
|
||||||
"""
|
"""
|
||||||
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
|
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
|
||||||
temp_file_path = pathlib.Path(temp_fd.name)
|
temp_file_path = pathlib.Path(temp_fd.name)
|
||||||
|
|||||||
@ -10,6 +10,20 @@ from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
|||||||
from chatmastermind.chat import Chat, ChatDB, ChatError
|
from chatmastermind.chat import Chat, ChatDB, ChatError
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix: str = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
|
def msg_to_file_force_suffix(msg: Message) -> None:
|
||||||
|
"""
|
||||||
|
Force writing a message file with illegal suffixes.
|
||||||
|
"""
|
||||||
|
def_suffix = Message.file_suffix_write
|
||||||
|
assert msg.file_path
|
||||||
|
Message.file_suffix_write = msg.file_path.suffix
|
||||||
|
msg.to_file()
|
||||||
|
Message.file_suffix_write = def_suffix
|
||||||
|
|
||||||
|
|
||||||
class TestChatBase(unittest.TestCase):
|
class TestChatBase(unittest.TestCase):
|
||||||
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
def assert_messages_equal(self, msg1: list[Message], msg2: list[Message]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -27,11 +41,11 @@ class TestChat(TestChatBase):
|
|||||||
self.message1 = Message(Question('Question 1'),
|
self.message1 = Message(Question('Question 1'),
|
||||||
Answer('Answer 1'),
|
Answer('Answer 1'),
|
||||||
{Tag('atag1'), Tag('btag2')},
|
{Tag('atag1'), Tag('btag2')},
|
||||||
file_path=pathlib.Path('0001.txt'))
|
file_path=pathlib.Path(f'0001{msg_suffix}'))
|
||||||
self.message2 = Message(Question('Question 2'),
|
self.message2 = Message(Question('Question 2'),
|
||||||
Answer('Answer 2'),
|
Answer('Answer 2'),
|
||||||
{Tag('btag2')},
|
{Tag('btag2')},
|
||||||
file_path=pathlib.Path('0002.txt'))
|
file_path=pathlib.Path(f'0002{msg_suffix}'))
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
|
||||||
def test_unique_id(self) -> None:
|
def test_unique_id(self) -> None:
|
||||||
@ -99,24 +113,24 @@ class TestChat(TestChatBase):
|
|||||||
|
|
||||||
def test_find_remove_messages(self) -> None:
|
def test_find_remove_messages(self) -> None:
|
||||||
self.chat.msg_add([self.message1, self.message2])
|
self.chat.msg_add([self.message1, self.message2])
|
||||||
msgs = self.chat.msg_find(['0001.txt'])
|
msgs = self.chat.msg_find(['0001'])
|
||||||
self.assertListEqual(msgs, [self.message1])
|
self.assertListEqual(msgs, [self.message1])
|
||||||
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
|
msgs = self.chat.msg_find(['0001', '0002'])
|
||||||
self.assertListEqual(msgs, [self.message1, self.message2])
|
self.assertListEqual(msgs, [self.message1, self.message2])
|
||||||
# add new Message with full path
|
# add new Message with full path
|
||||||
message3 = Message(Question('Question 2'),
|
message3 = Message(Question('Question 2'),
|
||||||
Answer('Answer 2'),
|
Answer('Answer 2'),
|
||||||
{Tag('btag2')},
|
{Tag('btag2')},
|
||||||
file_path=pathlib.Path('/foo/bla/0003.txt'))
|
file_path=pathlib.Path(f'/foo/bla/0003{msg_suffix}'))
|
||||||
self.chat.msg_add([message3])
|
self.chat.msg_add([message3])
|
||||||
# find new Message by full path
|
# find new Message by full path
|
||||||
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
|
msgs = self.chat.msg_find([f'/foo/bla/0003{msg_suffix}'])
|
||||||
self.assertListEqual(msgs, [message3])
|
self.assertListEqual(msgs, [message3])
|
||||||
# find Message with full path only by filename
|
# find Message with full path only by filename
|
||||||
msgs = self.chat.msg_find(['0003.txt'])
|
msgs = self.chat.msg_find([f'0003{msg_suffix}'])
|
||||||
self.assertListEqual(msgs, [message3])
|
self.assertListEqual(msgs, [message3])
|
||||||
# remove last message
|
# remove last message
|
||||||
self.chat.msg_remove(['0003.txt'])
|
self.chat.msg_remove(['0003'])
|
||||||
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
||||||
|
|
||||||
def test_latest_message(self) -> None:
|
def test_latest_message(self) -> None:
|
||||||
@ -146,13 +160,13 @@ Answer 2
|
|||||||
self.chat.msg_add([self.message1, self.message2])
|
self.chat.msg_add([self.message1, self.message2])
|
||||||
self.chat.print(paged=False, with_tags=True, with_files=True)
|
self.chat.print(paged=False, with_tags=True, with_files=True)
|
||||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||||
FILE: 0001.txt
|
FILE: 0001{msg_suffix}
|
||||||
{Question.txt_header}
|
{Question.txt_header}
|
||||||
Question 1
|
Question 1
|
||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
Answer 1
|
Answer 1
|
||||||
{TagLine.prefix} btag2
|
{TagLine.prefix} btag2
|
||||||
FILE: 0002.txt
|
FILE: 0002{msg_suffix}
|
||||||
{Question.txt_header}
|
{Question.txt_header}
|
||||||
Question 2
|
Question 2
|
||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
@ -168,31 +182,27 @@ class TestChatDB(TestChatBase):
|
|||||||
|
|
||||||
self.message1 = Message(Question('Question 1'),
|
self.message1 = Message(Question('Question 1'),
|
||||||
Answer('Answer 1'),
|
Answer('Answer 1'),
|
||||||
{Tag('tag1')},
|
{Tag('tag1')})
|
||||||
file_path=pathlib.Path('0001.txt'))
|
|
||||||
self.message2 = Message(Question('Question 2'),
|
self.message2 = Message(Question('Question 2'),
|
||||||
Answer('Answer 2'),
|
Answer('Answer 2'),
|
||||||
{Tag('tag2')},
|
{Tag('tag2')})
|
||||||
file_path=pathlib.Path('0002.yaml'))
|
|
||||||
self.message3 = Message(Question('Question 3'),
|
self.message3 = Message(Question('Question 3'),
|
||||||
Answer('Answer 3'),
|
Answer('Answer 3'),
|
||||||
{Tag('tag3')},
|
{Tag('tag3')})
|
||||||
file_path=pathlib.Path('0003.txt'))
|
|
||||||
self.message4 = Message(Question('Question 4'),
|
self.message4 = Message(Question('Question 4'),
|
||||||
Answer('Answer 4'),
|
Answer('Answer 4'),
|
||||||
{Tag('tag4')},
|
{Tag('tag4')})
|
||||||
file_path=pathlib.Path('0004.yaml'))
|
|
||||||
|
|
||||||
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
|
self.message1.to_file(pathlib.Path(self.db_path.name, '0001'), mformat='txt')
|
||||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
self.message2.to_file(pathlib.Path(self.db_path.name, '0002'), mformat='yaml')
|
||||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
self.message3.to_file(pathlib.Path(self.db_path.name, '0003'), mformat='txt')
|
||||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
self.message4.to_file(pathlib.Path(self.db_path.name, '0004'), mformat='yaml')
|
||||||
# make the next FID match the current state
|
# make the next FID match the current state
|
||||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||||
with open(next_fname, 'w') as f:
|
with open(next_fname, 'w') as f:
|
||||||
f.write('4')
|
f.write('4')
|
||||||
# add some "trash" in order to test if it's correctly handled / ignored
|
# add some "trash" in order to test if it's correctly handled / ignored
|
||||||
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
|
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt', 'fubar.msg']
|
||||||
for file in self.trash_files:
|
for file in self.trash_files:
|
||||||
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
|
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
|
||||||
f.write('test trash')
|
f.write('test trash')
|
||||||
@ -207,7 +217,7 @@ class TestChatDB(TestChatBase):
|
|||||||
List all Message files in the given TemporaryDirectory.
|
List all Message files in the given TemporaryDirectory.
|
||||||
"""
|
"""
|
||||||
# exclude '.next'
|
# exclude '.next'
|
||||||
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
|
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[tym]*') if f.name not in self.trash_files]
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
self.db_path.cleanup()
|
self.db_path.cleanup()
|
||||||
@ -218,13 +228,31 @@ class TestChatDB(TestChatBase):
|
|||||||
duplicate_message = Message(Question('Question 4'),
|
duplicate_message = Message(Question('Question 4'),
|
||||||
Answer('Answer 4'),
|
Answer('Answer 4'),
|
||||||
{Tag('tag4')},
|
{Tag('tag4')},
|
||||||
file_path=pathlib.Path('0004.txt'))
|
file_path=pathlib.Path(self.db_path.name, '0004.txt'))
|
||||||
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
|
msg_to_file_force_suffix(duplicate_message)
|
||||||
with self.assertRaises(ChatError) as cm:
|
with self.assertRaises(ChatError) as cm:
|
||||||
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
pathlib.Path(self.db_path.name))
|
pathlib.Path(self.db_path.name))
|
||||||
self.assertEqual(str(cm.exception), "Validation failed")
|
self.assertEqual(str(cm.exception), "Validation failed")
|
||||||
|
|
||||||
|
def test_file_path_ID_exists(self) -> None:
|
||||||
|
"""
|
||||||
|
Tests if the CacheDB chooses another ID if a file path with
|
||||||
|
the given one exists.
|
||||||
|
"""
|
||||||
|
# create a new and empty CacheDB
|
||||||
|
db_path = tempfile.TemporaryDirectory()
|
||||||
|
cache_path = tempfile.TemporaryDirectory()
|
||||||
|
chat_db = ChatDB.from_dir(pathlib.Path(cache_path.name),
|
||||||
|
pathlib.Path(db_path.name))
|
||||||
|
# add a message file
|
||||||
|
message = Message(Question('What?'),
|
||||||
|
file_path=pathlib.Path(cache_path.name) / f'0001{msg_suffix}')
|
||||||
|
message.to_file()
|
||||||
|
message1 = Message(Question('Where?'))
|
||||||
|
chat_db.cache_write([message1])
|
||||||
|
self.assertEqual(message1.msg_id(), '0002')
|
||||||
|
|
||||||
def test_from_dir(self) -> None:
|
def test_from_dir(self) -> None:
|
||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
pathlib.Path(self.db_path.name))
|
pathlib.Path(self.db_path.name))
|
||||||
@ -233,25 +261,23 @@ class TestChatDB(TestChatBase):
|
|||||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||||
# check that the files are sorted
|
# check that the files are sorted
|
||||||
self.assertEqual(chat_db.messages[0].file_path,
|
self.assertEqual(chat_db.messages[0].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path,
|
self.assertEqual(chat_db.messages[1].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[2].file_path,
|
self.assertEqual(chat_db.messages[2].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[3].file_path,
|
self.assertEqual(chat_db.messages[3].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0004.yaml'))
|
pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||||
|
|
||||||
def test_from_dir_glob(self) -> None:
|
def test_from_dir_glob(self) -> None:
|
||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
pathlib.Path(self.db_path.name),
|
pathlib.Path(self.db_path.name),
|
||||||
glob='*.txt')
|
glob='*1.*')
|
||||||
self.assertEqual(len(chat_db.messages), 2)
|
self.assertEqual(len(chat_db.messages), 1)
|
||||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||||
self.assertEqual(chat_db.messages[0].file_path,
|
self.assertEqual(chat_db.messages[0].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path,
|
|
||||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
|
||||||
|
|
||||||
def test_from_dir_filter_tags(self) -> None:
|
def test_from_dir_filter_tags(self) -> None:
|
||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
@ -261,7 +287,7 @@ class TestChatDB(TestChatBase):
|
|||||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||||
self.assertEqual(chat_db.messages[0].file_path,
|
self.assertEqual(chat_db.messages[0].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
|
|
||||||
def test_from_dir_filter_tags_empty(self) -> None:
|
def test_from_dir_filter_tags_empty(self) -> None:
|
||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
@ -279,7 +305,7 @@ class TestChatDB(TestChatBase):
|
|||||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||||
self.assertEqual(chat_db.messages[0].file_path,
|
self.assertEqual(chat_db.messages[0].file_path,
|
||||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||||
|
|
||||||
def test_from_messages(self) -> None:
|
def test_from_messages(self) -> None:
|
||||||
@ -324,25 +350,25 @@ class TestChatDB(TestChatBase):
|
|||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
pathlib.Path(self.db_path.name))
|
pathlib.Path(self.db_path.name))
|
||||||
# check that Message.file_path is correct
|
# check that Message.file_path is correct
|
||||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||||
|
|
||||||
# write the messages to the cache directory
|
# write the messages to the cache directory
|
||||||
chat_db.cache_write()
|
chat_db.cache_write()
|
||||||
# check if the written files are in the cache directory
|
# check if the written files are in the cache directory
|
||||||
cache_dir_files = self.message_list(self.cache_path)
|
cache_dir_files = self.message_list(self.cache_path)
|
||||||
self.assertEqual(len(cache_dir_files), 4)
|
self.assertEqual(len(cache_dir_files), 4)
|
||||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
self.assertIn(pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'), cache_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
self.assertIn(pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'), cache_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
|
self.assertIn(pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'), cache_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
|
self.assertIn(pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'), cache_dir_files)
|
||||||
# check that Message.file_path has been correctly updated
|
# check that Message.file_path has been correctly updated
|
||||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
|
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
|
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
|
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, f'0003{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, f'0004{msg_suffix}'))
|
||||||
|
|
||||||
# check the timestamp of the files in the DB directory
|
# check the timestamp of the files in the DB directory
|
||||||
db_dir_files = self.message_list(self.db_path)
|
db_dir_files = self.message_list(self.db_path)
|
||||||
@ -354,18 +380,18 @@ class TestChatDB(TestChatBase):
|
|||||||
# check if the written files are in the DB directory
|
# check if the written files are in the DB directory
|
||||||
db_dir_files = self.message_list(self.db_path)
|
db_dir_files = self.message_list(self.db_path)
|
||||||
self.assertEqual(len(db_dir_files), 4)
|
self.assertEqual(len(db_dir_files), 4)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
|
||||||
# check if all files in the DB dir have actually been overwritten
|
# check if all files in the DB dir have actually been overwritten
|
||||||
for file in db_dir_files:
|
for file in db_dir_files:
|
||||||
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
|
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
|
||||||
# check that Message.file_path has been correctly updated (again)
|
# check that Message.file_path has been correctly updated (again)
|
||||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||||
|
|
||||||
def test_db_read(self) -> None:
|
def test_db_read(self) -> None:
|
||||||
# create a new ChatDB instance
|
# create a new ChatDB instance
|
||||||
@ -380,65 +406,65 @@ class TestChatDB(TestChatBase):
|
|||||||
new_message2 = Message(Question('Question 6'),
|
new_message2 = Message(Question('Question 6'),
|
||||||
Answer('Answer 6'),
|
Answer('Answer 6'),
|
||||||
{Tag('tag6')})
|
{Tag('tag6')})
|
||||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
|
||||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
|
||||||
# read and check them
|
# read and check them
|
||||||
chat_db.db_read()
|
chat_db.db_read()
|
||||||
self.assertEqual(len(chat_db.messages), 6)
|
self.assertEqual(len(chat_db.messages), 6)
|
||||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||||
|
|
||||||
# create 2 new files in the cache directory
|
# create 2 new files in the cache directory
|
||||||
new_message3 = Message(Question('Question 7'),
|
new_message3 = Message(Question('Question 7'),
|
||||||
Answer('Answer 5'),
|
Answer('Answer 7'),
|
||||||
{Tag('tag7')})
|
{Tag('tag7')})
|
||||||
new_message4 = Message(Question('Question 8'),
|
new_message4 = Message(Question('Question 8'),
|
||||||
Answer('Answer 6'),
|
Answer('Answer 8'),
|
||||||
{Tag('tag8')})
|
{Tag('tag8')})
|
||||||
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
|
new_message3.to_file(pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'), mformat='txt')
|
||||||
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
|
new_message4.to_file(pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'), mformat='yaml')
|
||||||
# read and check them
|
# read and check them
|
||||||
chat_db.cache_read()
|
chat_db.cache_read()
|
||||||
self.assertEqual(len(chat_db.messages), 8)
|
self.assertEqual(len(chat_db.messages), 8)
|
||||||
# check that the new message have the cache dir path
|
# check that the new message have the cache dir path
|
||||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
|
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, f'0007{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
|
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, f'0008{msg_suffix}'))
|
||||||
# an the old ones keep their path (since they have not been replaced)
|
# an the old ones keep their path (since they have not been replaced)
|
||||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||||
|
|
||||||
# now overwrite two messages in the DB directory
|
# now overwrite two messages in the DB directory
|
||||||
new_message1.question = Question('New Question 1')
|
new_message1.question = Question('New Question 1')
|
||||||
new_message2.question = Question('New Question 2')
|
new_message2.question = Question('New Question 2')
|
||||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
new_message1.to_file(pathlib.Path(self.db_path.name, f'0005{msg_suffix}'), mformat='txt')
|
||||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
new_message2.to_file(pathlib.Path(self.db_path.name, f'0006{msg_suffix}'), mformat='yaml')
|
||||||
# read from the DB dir and check if the modified messages have been updated
|
# read from the DB dir and check if the modified messages have been updated
|
||||||
chat_db.db_read()
|
chat_db.db_read()
|
||||||
self.assertEqual(len(chat_db.messages), 8)
|
self.assertEqual(len(chat_db.messages), 8)
|
||||||
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
|
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
|
||||||
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
|
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
|
||||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, f'0005{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, f'0006{msg_suffix}'))
|
||||||
|
|
||||||
# now write the messages from the cache to the DB directory
|
# now write the messages from the cache to the DB directory
|
||||||
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
|
new_message3.to_file(pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
|
||||||
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
|
new_message4.to_file(pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
|
||||||
# read and check them
|
# read and check them
|
||||||
chat_db.db_read()
|
chat_db.db_read()
|
||||||
self.assertEqual(len(chat_db.messages), 8)
|
self.assertEqual(len(chat_db.messages), 8)
|
||||||
# check that they now have the DB path
|
# check that they now have the DB path
|
||||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
|
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, f'0007{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
|
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, f'0008{msg_suffix}'))
|
||||||
|
|
||||||
def test_cache_clear(self) -> None:
|
def test_cache_clear(self) -> None:
|
||||||
# create a new ChatDB instance
|
# create a new ChatDB instance
|
||||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||||
pathlib.Path(self.db_path.name))
|
pathlib.Path(self.db_path.name))
|
||||||
# check that Message.file_path is correct
|
# check that Message.file_path is correct
|
||||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, f'0001{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, f'0002{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, f'0003{msg_suffix}'))
|
||||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, f'0004{msg_suffix}'))
|
||||||
|
|
||||||
# write the messages to the cache directory
|
# write the messages to the cache directory
|
||||||
chat_db.cache_write()
|
chat_db.cache_write()
|
||||||
@ -450,10 +476,10 @@ class TestChatDB(TestChatBase):
|
|||||||
chat_db.db_write()
|
chat_db.db_write()
|
||||||
db_dir_files = self.message_list(self.db_path)
|
db_dir_files = self.message_list(self.db_path)
|
||||||
self.assertEqual(len(db_dir_files), 4)
|
self.assertEqual(len(db_dir_files), 4)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0001{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0002{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0003{msg_suffix}'), db_dir_files)
|
||||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
self.assertIn(pathlib.Path(self.db_path.name, f'0004{msg_suffix}'), db_dir_files)
|
||||||
|
|
||||||
# add a new message with empty file_path
|
# add a new message with empty file_path
|
||||||
message_empty = Message(question=Question("What the hell am I doing here?"),
|
message_empty = Message(question=Question("What the hell am I doing here?"),
|
||||||
@ -461,7 +487,7 @@ class TestChatDB(TestChatBase):
|
|||||||
# and one for the cache dir
|
# and one for the cache dir
|
||||||
message_cache = Message(question=Question("What the hell am I doing here?"),
|
message_cache = Message(question=Question("What the hell am I doing here?"),
|
||||||
answer=Answer("You're a creep!"),
|
answer=Answer("You're a creep!"),
|
||||||
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
|
file_path=pathlib.Path(self.cache_path.name, '0005'))
|
||||||
chat_db.msg_add([message_empty, message_cache])
|
chat_db.msg_add([message_empty, message_cache])
|
||||||
|
|
||||||
# clear the cache and check the cache dir
|
# clear the cache and check the cache dir
|
||||||
@ -523,11 +549,11 @@ class TestChatDB(TestChatBase):
|
|||||||
chat_db.msg_write([message])
|
chat_db.msg_write([message])
|
||||||
|
|
||||||
# write a message with a valid file_path
|
# write a message with a valid file_path
|
||||||
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
|
message.file_path = pathlib.Path(self.cache_path.name) / '123456'
|
||||||
chat_db.msg_write([message])
|
chat_db.msg_write([message])
|
||||||
cache_dir_files = self.message_list(self.cache_path)
|
cache_dir_files = self.message_list(self.cache_path)
|
||||||
self.assertEqual(len(cache_dir_files), 1)
|
self.assertEqual(len(cache_dir_files), 1)
|
||||||
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
|
self.assertIn(pathlib.Path(self.cache_path.name, f'123456{msg_suffix}'), cache_dir_files)
|
||||||
|
|
||||||
def test_msg_update(self) -> None:
|
def test_msg_update(self) -> None:
|
||||||
# create a new ChatDB instance
|
# create a new ChatDB instance
|
||||||
@ -563,21 +589,21 @@ class TestChatDB(TestChatBase):
|
|||||||
# search for a DB file in memory
|
# search for a DB file in memory
|
||||||
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
|
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], loc='mem'), [self.message1])
|
||||||
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
|
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], loc='mem'), [self.message1])
|
||||||
self.assertEqual(chat_db.msg_find(['0001.txt'], loc='mem'), [self.message1])
|
self.assertEqual(chat_db.msg_find(['0001.msg'], loc='mem'), [self.message1])
|
||||||
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
|
self.assertEqual(chat_db.msg_find(['0001'], loc='mem'), [self.message1])
|
||||||
# and on disk
|
# and on disk
|
||||||
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
|
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], loc='db'), [self.message2])
|
||||||
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
|
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], loc='db'), [self.message2])
|
||||||
self.assertEqual(chat_db.msg_find(['0002.yaml'], loc='db'), [self.message2])
|
self.assertEqual(chat_db.msg_find(['0002.msg'], loc='db'), [self.message2])
|
||||||
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
|
self.assertEqual(chat_db.msg_find(['0002'], loc='db'), [self.message2])
|
||||||
# now search the cache -> expect empty result
|
# now search the cache -> expect empty result
|
||||||
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
|
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], loc='cache'), [])
|
||||||
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
|
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], loc='cache'), [])
|
||||||
self.assertEqual(chat_db.msg_find(['0003.txt'], loc='cache'), [])
|
self.assertEqual(chat_db.msg_find(['0003.msg'], loc='cache'), [])
|
||||||
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
|
self.assertEqual(chat_db.msg_find(['0003'], loc='cache'), [])
|
||||||
# search for multiple messages
|
# search for multiple messages
|
||||||
# -> search one twice, expect result to be unique
|
# -> search one twice, expect result to be unique
|
||||||
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
|
search_names = ['0001', '0002.msg', self.message3.msg_id(), str(self.message3.file_path)]
|
||||||
expected_result = [self.message1, self.message2, self.message3]
|
expected_result = [self.message1, self.message2, self.message3]
|
||||||
result = chat_db.msg_find(search_names, loc='all')
|
result = chat_db.msg_find(search_names, loc='all')
|
||||||
self.assert_messages_equal(result, expected_result)
|
self.assert_messages_equal(result, expected_result)
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import itertools
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine,\
|
||||||
|
MessageFilter, message_in, message_valid_formats
|
||||||
from chatmastermind.tags import Tag, TagLine
|
from chatmastermind.tags import Tag, TagLine
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix: str = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
class SourceCodeTestCase(unittest.TestCase):
|
class SourceCodeTestCase(unittest.TestCase):
|
||||||
def test_source_code_with_include_delims(self) -> None:
|
def test_source_code_with_include_delims(self) -> None:
|
||||||
text = """
|
text = """
|
||||||
@ -101,7 +106,7 @@ class AnswerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
class MessageToFileTxtTestCase(unittest.TestCase):
|
class MessageToFileTxtTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
self.message_complete = Message(Question('This is a question.'),
|
self.message_complete = Message(Question('This is a question.'),
|
||||||
Answer('This is an answer.'),
|
Answer('This is an answer.'),
|
||||||
@ -117,7 +122,7 @@ class MessageToFileTxtTestCase(unittest.TestCase):
|
|||||||
self.file_path.unlink()
|
self.file_path.unlink()
|
||||||
|
|
||||||
def test_to_file_txt_complete(self) -> None:
|
def test_to_file_txt_complete(self) -> None:
|
||||||
self.message_complete.to_file(self.file_path)
|
self.message_complete.to_file(self.file_path, mformat='txt')
|
||||||
|
|
||||||
with open(self.file_path, "r") as fd:
|
with open(self.file_path, "r") as fd:
|
||||||
content = fd.read()
|
content = fd.read()
|
||||||
@ -132,7 +137,7 @@ This is an answer.
|
|||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
def test_to_file_txt_min(self) -> None:
|
def test_to_file_txt_min(self) -> None:
|
||||||
self.message_min.to_file(self.file_path)
|
self.message_min.to_file(self.file_path, mformat='txt')
|
||||||
|
|
||||||
with open(self.file_path, "r") as fd:
|
with open(self.file_path, "r") as fd:
|
||||||
content = fd.read()
|
content = fd.read()
|
||||||
@ -141,11 +146,17 @@ This is a question.
|
|||||||
"""
|
"""
|
||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
def test_to_file_unsupported_file_type(self) -> None:
|
def test_to_file_unsupported_file_suffix(self) -> None:
|
||||||
unsupported_file_path = pathlib.Path("example.doc")
|
unsupported_file_path = pathlib.Path("example.doc")
|
||||||
with self.assertRaises(MessageError) as cm:
|
with self.assertRaises(MessageError) as cm:
|
||||||
self.message_complete.to_file(unsupported_file_path)
|
self.message_complete.to_file(unsupported_file_path)
|
||||||
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
|
self.assertEqual(str(cm.exception), "File suffix '.doc' is not supported")
|
||||||
|
|
||||||
|
def test_to_file_unsupported_file_format(self) -> None:
|
||||||
|
unsupported_file_format = pathlib.Path(f"example{msg_suffix}")
|
||||||
|
with self.assertRaises(MessageError) as cm:
|
||||||
|
self.message_complete.to_file(unsupported_file_format, mformat='doc') # type: ignore [arg-type]
|
||||||
|
self.assertEqual(str(cm.exception), "File format 'doc' is not supported")
|
||||||
|
|
||||||
def test_to_file_no_file_path(self) -> None:
|
def test_to_file_no_file_path(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -159,10 +170,24 @@ This is a question.
|
|||||||
# reset the internal file_path
|
# reset the internal file_path
|
||||||
self.message_complete.file_path = self.file_path
|
self.message_complete.file_path = self.file_path
|
||||||
|
|
||||||
|
def test_to_file_txt_auto_suffix(self) -> None:
|
||||||
|
"""
|
||||||
|
Test if suffix is auto-generated if omitted.
|
||||||
|
"""
|
||||||
|
file_path_no_suffix = self.file_path.with_suffix('')
|
||||||
|
# test with file_path member
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(mformat='txt')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
# test with explicit file_path
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(file_path=file_path_no_suffix, mformat='txt')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
|
||||||
|
|
||||||
class MessageToFileYamlTestCase(unittest.TestCase):
|
class MessageToFileYamlTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
self.message_complete = Message(Question('This is a question.'),
|
self.message_complete = Message(Question('This is a question.'),
|
||||||
Answer('This is an answer.'),
|
Answer('This is an answer.'),
|
||||||
@ -184,7 +209,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
|||||||
self.file_path.unlink()
|
self.file_path.unlink()
|
||||||
|
|
||||||
def test_to_file_yaml_complete(self) -> None:
|
def test_to_file_yaml_complete(self) -> None:
|
||||||
self.message_complete.to_file(self.file_path)
|
self.message_complete.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
with open(self.file_path, "r") as fd:
|
with open(self.file_path, "r") as fd:
|
||||||
content = fd.read()
|
content = fd.read()
|
||||||
@ -199,7 +224,7 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
def test_to_file_yaml_multiline(self) -> None:
|
def test_to_file_yaml_multiline(self) -> None:
|
||||||
self.message_multiline.to_file(self.file_path)
|
self.message_multiline.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
with open(self.file_path, "r") as fd:
|
with open(self.file_path, "r") as fd:
|
||||||
content = fd.read()
|
content = fd.read()
|
||||||
@ -218,17 +243,31 @@ class MessageToFileYamlTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
def test_to_file_yaml_min(self) -> None:
|
def test_to_file_yaml_min(self) -> None:
|
||||||
self.message_min.to_file(self.file_path)
|
self.message_min.to_file(self.file_path, mformat='yaml')
|
||||||
|
|
||||||
with open(self.file_path, "r") as fd:
|
with open(self.file_path, "r") as fd:
|
||||||
content = fd.read()
|
content = fd.read()
|
||||||
expected_content = f"{Question.yaml_key}: This is a question.\n"
|
expected_content = f"{Question.yaml_key}: This is a question.\n"
|
||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
|
def test_to_file_yaml_auto_suffix(self) -> None:
|
||||||
|
"""
|
||||||
|
Test if suffix is auto-generated if omitted.
|
||||||
|
"""
|
||||||
|
file_path_no_suffix = self.file_path.with_suffix('')
|
||||||
|
# test with file_path member
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(mformat='yaml')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
# test with explicit file_path
|
||||||
|
self.message_min.file_path = file_path_no_suffix
|
||||||
|
self.message_min.to_file(file_path=file_path_no_suffix, mformat='yaml')
|
||||||
|
self.assertEqual(self.message_min.file_path.suffix, msg_suffix)
|
||||||
|
|
||||||
|
|
||||||
class MessageFromFileTxtTestCase(unittest.TestCase):
|
class MessageFromFileTxtTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
with open(self.file_path, "w") as fd:
|
with open(self.file_path, "w") as fd:
|
||||||
fd.write(f"""{TagLine.prefix} tag1 tag2
|
fd.write(f"""{TagLine.prefix} tag1 tag2
|
||||||
@ -239,7 +278,7 @@ This is a question.
|
|||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
This is an answer.
|
This is an answer.
|
||||||
""")
|
""")
|
||||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_min = pathlib.Path(self.file_min.name)
|
self.file_path_min = pathlib.Path(self.file_min.name)
|
||||||
with open(self.file_path_min, "w") as fd:
|
with open(self.file_path_min, "w") as fd:
|
||||||
fd.write(f"""{Question.txt_header}
|
fd.write(f"""{Question.txt_header}
|
||||||
@ -259,7 +298,7 @@ This is a question.
|
|||||||
message = Message.from_file(self.file_path)
|
message = Message.from_file(self.file_path)
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertEqual(message.answer, 'This is an answer.')
|
self.assertEqual(message.answer, 'This is an answer.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||||
@ -274,7 +313,7 @@ This is a question.
|
|||||||
message = Message.from_file(self.file_path_min)
|
message = Message.from_file(self.file_path_min)
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertEqual(message.file_path, self.file_path_min)
|
self.assertEqual(message.file_path, self.file_path_min)
|
||||||
self.assertIsNone(message.answer)
|
self.assertIsNone(message.answer)
|
||||||
@ -284,7 +323,7 @@ This is a question.
|
|||||||
MessageFilter(tags_or={Tag('tag1')}))
|
MessageFilter(tags_or={Tag('tag1')}))
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertEqual(message.answer, 'This is an answer.')
|
self.assertEqual(message.answer, 'This is an answer.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||||
@ -311,13 +350,13 @@ This is a question.
|
|||||||
MessageFilter(tags_not={Tag('tag1')}))
|
MessageFilter(tags_not={Tag('tag1')}))
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||||
self.assertEqual(message.file_path, self.file_path_min)
|
self.assertEqual(message.file_path, self.file_path_min)
|
||||||
|
|
||||||
def test_from_file_not_exists(self) -> None:
|
def test_from_file_not_exists(self) -> None:
|
||||||
file_not_exists = pathlib.Path("example.txt")
|
file_not_exists = pathlib.Path(f"example{msg_suffix}")
|
||||||
with self.assertRaises(MessageError) as cm:
|
with self.assertRaises(MessageError) as cm:
|
||||||
Message.from_file(file_not_exists)
|
Message.from_file(file_not_exists)
|
||||||
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
||||||
@ -396,7 +435,7 @@ This is a question.
|
|||||||
|
|
||||||
class MessageFromFileYamlTestCase(unittest.TestCase):
|
class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
with open(self.file_path, "w") as fd:
|
with open(self.file_path, "w") as fd:
|
||||||
fd.write(f"""
|
fd.write(f"""
|
||||||
@ -410,7 +449,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
- tag1
|
- tag1
|
||||||
- tag2
|
- tag2
|
||||||
""")
|
""")
|
||||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_min = pathlib.Path(self.file_min.name)
|
self.file_path_min = pathlib.Path(self.file_min.name)
|
||||||
with open(self.file_path_min, "w") as fd:
|
with open(self.file_path_min, "w") as fd:
|
||||||
fd.write(f"""
|
fd.write(f"""
|
||||||
@ -431,7 +470,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
message = Message.from_file(self.file_path)
|
message = Message.from_file(self.file_path)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertEqual(message.answer, 'This is an answer.')
|
self.assertEqual(message.answer, 'This is an answer.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||||
@ -446,14 +485,14 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
message = Message.from_file(self.file_path_min)
|
message = Message.from_file(self.file_path_min)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||||
self.assertEqual(message.file_path, self.file_path_min)
|
self.assertEqual(message.file_path, self.file_path_min)
|
||||||
self.assertIsNone(message.answer)
|
self.assertIsNone(message.answer)
|
||||||
|
|
||||||
def test_from_file_not_exists(self) -> None:
|
def test_from_file_not_exists(self) -> None:
|
||||||
file_not_exists = pathlib.Path("example.yaml")
|
file_not_exists = pathlib.Path(f"example{msg_suffix}")
|
||||||
with self.assertRaises(MessageError) as cm:
|
with self.assertRaises(MessageError) as cm:
|
||||||
Message.from_file(file_not_exists)
|
Message.from_file(file_not_exists)
|
||||||
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
||||||
@ -463,7 +502,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
MessageFilter(tags_or={Tag('tag1')}))
|
MessageFilter(tags_or={Tag('tag1')}))
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertEqual(message.answer, 'This is an answer.')
|
self.assertEqual(message.answer, 'This is an answer.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||||
@ -484,7 +523,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
MessageFilter(tags_not={Tag('tag1')}))
|
MessageFilter(tags_not={Tag('tag1')}))
|
||||||
self.assertIsNotNone(message)
|
self.assertIsNotNone(message)
|
||||||
self.assertIsInstance(message, Message)
|
self.assertIsInstance(message, Message)
|
||||||
if message: # mypy bug
|
assert message
|
||||||
self.assertEqual(message.question, 'This is a question.')
|
self.assertEqual(message.question, 'This is a question.')
|
||||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||||
self.assertEqual(message.file_path, self.file_path_min)
|
self.assertEqual(message.file_path, self.file_path_min)
|
||||||
@ -563,7 +602,7 @@ class MessageFromFileYamlTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
class TagsFromFileTestCase(unittest.TestCase):
|
class TagsFromFileTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
||||||
with open(self.file_path_txt, "w") as fd:
|
with open(self.file_path_txt, "w") as fd:
|
||||||
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
|
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
|
||||||
@ -572,7 +611,7 @@ This is a question.
|
|||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
This is an answer.
|
This is an answer.
|
||||||
""")
|
""")
|
||||||
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
|
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
|
||||||
with open(self.file_path_txt_no_tags, "w") as fd:
|
with open(self.file_path_txt_no_tags, "w") as fd:
|
||||||
fd.write(f"""{Question.txt_header}
|
fd.write(f"""{Question.txt_header}
|
||||||
@ -580,7 +619,7 @@ This is a question.
|
|||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
This is an answer.
|
This is an answer.
|
||||||
""")
|
""")
|
||||||
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
|
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
|
||||||
with open(self.file_path_txt_tags_empty, "w") as fd:
|
with open(self.file_path_txt_tags_empty, "w") as fd:
|
||||||
fd.write(f"""TAGS:
|
fd.write(f"""TAGS:
|
||||||
@ -589,7 +628,7 @@ This is a question.
|
|||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
This is an answer.
|
This is an answer.
|
||||||
""")
|
""")
|
||||||
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
|
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
|
||||||
with open(self.file_path_yaml, "w") as fd:
|
with open(self.file_path_yaml, "w") as fd:
|
||||||
fd.write(f"""
|
fd.write(f"""
|
||||||
@ -602,7 +641,7 @@ This is an answer.
|
|||||||
- tag2
|
- tag2
|
||||||
- ptag3
|
- ptag3
|
||||||
""")
|
""")
|
||||||
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
|
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
|
||||||
with open(self.file_path_yaml_no_tags, "w") as fd:
|
with open(self.file_path_yaml_no_tags, "w") as fd:
|
||||||
fd.write(f"""
|
fd.write(f"""
|
||||||
@ -679,24 +718,25 @@ class TagsFromDirTestCase(unittest.TestCase):
|
|||||||
{Tag('ctag5'), Tag('ctag6')}
|
{Tag('ctag5'), Tag('ctag6')}
|
||||||
]
|
]
|
||||||
self.files = [
|
self.files = [
|
||||||
pathlib.Path(self.temp_dir.name, 'file1.txt'),
|
pathlib.Path(self.temp_dir.name, f'file1{msg_suffix}'),
|
||||||
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
|
pathlib.Path(self.temp_dir.name, f'file2{msg_suffix}'),
|
||||||
pathlib.Path(self.temp_dir.name, 'file3.txt')
|
pathlib.Path(self.temp_dir.name, f'file3{msg_suffix}')
|
||||||
]
|
]
|
||||||
self.files_no_tags = [
|
self.files_no_tags = [
|
||||||
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
|
pathlib.Path(self.temp_dir_no_tags.name, f'file4{msg_suffix}'),
|
||||||
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
|
pathlib.Path(self.temp_dir_no_tags.name, f'file5{msg_suffix}'),
|
||||||
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
|
pathlib.Path(self.temp_dir_no_tags.name, f'file6{msg_suffix}')
|
||||||
]
|
]
|
||||||
|
mformats = itertools.cycle(message_valid_formats)
|
||||||
for file, tags in zip(self.files, self.tag_sets):
|
for file, tags in zip(self.files, self.tag_sets):
|
||||||
message = Message(Question('This is a question.'),
|
message = Message(Question('This is a question.'),
|
||||||
Answer('This is an answer.'),
|
Answer('This is an answer.'),
|
||||||
tags)
|
tags)
|
||||||
message.to_file(file)
|
message.to_file(file, next(mformats))
|
||||||
for file in self.files_no_tags:
|
for file in self.files_no_tags:
|
||||||
message = Message(Question('This is a question.'),
|
message = Message(Question('This is a question.'),
|
||||||
Answer('This is an answer.'))
|
Answer('This is an answer.'))
|
||||||
message.to_file(file)
|
message.to_file(file, next(mformats))
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
self.temp_dir.cleanup()
|
self.temp_dir.cleanup()
|
||||||
@ -719,7 +759,7 @@ class TagsFromDirTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
class MessageIDTestCase(unittest.TestCase):
|
class MessageIDTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix=msg_suffix)
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
self.message = Message(Question('This is a question.'),
|
self.message = Message(Question('This is a question.'),
|
||||||
file_path=self.file_path)
|
file_path=self.file_path)
|
||||||
|
|||||||
@ -14,6 +14,9 @@ from chatmastermind.ai import AIError
|
|||||||
from .test_common import TestWithFakeAI
|
from .test_common import TestWithFakeAI
|
||||||
|
|
||||||
|
|
||||||
|
msg_suffix = Message.file_suffix_write
|
||||||
|
|
||||||
|
|
||||||
class TestMessageCreate(TestWithFakeAI):
|
class TestMessageCreate(TestWithFakeAI):
|
||||||
"""
|
"""
|
||||||
Test if messages created by the 'question' command have
|
Test if messages created by the 'question' command have
|
||||||
@ -83,7 +86,7 @@ Aaaand again some text."""
|
|||||||
|
|
||||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||||
# exclude '.next'
|
# exclude '.next'
|
||||||
return list(Path(tmp_dir.name).glob('*.[ty]*'))
|
return list(Path(tmp_dir.name).glob(f'*{msg_suffix}'))
|
||||||
|
|
||||||
def test_message_file_created(self) -> None:
|
def test_message_file_created(self) -> None:
|
||||||
self.args.ask = ["What is this?"]
|
self.args.ask = ["What is this?"]
|
||||||
@ -231,7 +234,7 @@ class TestQuestionCmd(TestWithFakeAI):
|
|||||||
|
|
||||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||||
# exclude '.next'
|
# exclude '.next'
|
||||||
return sorted([f for f in Path(tmp_dir.name).glob('*.[ty]*')])
|
return sorted([f for f in Path(tmp_dir.name).glob(f'*{msg_suffix}')])
|
||||||
|
|
||||||
|
|
||||||
class TestQuestionCmdAsk(TestQuestionCmd):
|
class TestQuestionCmdAsk(TestQuestionCmd):
|
||||||
@ -330,14 +333,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat a single question.
|
Repeat a single question.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# create a message
|
# create a message
|
||||||
message = Message(Question(self.args.ask[0]),
|
message = Message(Question(self.args.ask[0]),
|
||||||
Answer('Old Answer'),
|
Answer('Old Answer'),
|
||||||
tags=set(self.args.output_tags),
|
tags=set(self.args.output_tags),
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
message.to_file()
|
chat.msg_write([message])
|
||||||
|
|
||||||
# repeat the last question (without overwriting)
|
# repeat the last question (without overwriting)
|
||||||
# -> expect two identical messages (except for the file_path)
|
# -> expect two identical messages (except for the file_path)
|
||||||
@ -353,8 +358,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
# we expect the original message + the one with the new response
|
# we expect the original message + the one with the new response
|
||||||
expected_responses = [message] + [expected_response]
|
expected_responses = [message] + [expected_response]
|
||||||
question_cmd(self.args, self.config)
|
question_cmd(self.args, self.config)
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
print(self.message_list(self.cache_dir))
|
print(self.message_list(self.cache_dir))
|
||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 2)
|
||||||
@ -366,16 +369,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat a single question and overwrite the old one.
|
Repeat a single question and overwrite the old one.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# create a message
|
# create a message
|
||||||
message = Message(Question(self.args.ask[0]),
|
message = Message(Question(self.args.ask[0]),
|
||||||
Answer('Old Answer'),
|
Answer('Old Answer'),
|
||||||
tags=set(self.args.output_tags),
|
tags=set(self.args.output_tags),
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
message.to_file()
|
chat.msg_write([message])
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||||
@ -405,16 +408,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat a single question after an error.
|
Repeat a single question after an error.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# create a question WITHOUT an answer
|
# create a question WITHOUT an answer
|
||||||
# -> just like after an error, which is tested above
|
# -> just like after an error, which is tested above
|
||||||
message = Message(Question(self.args.ask[0]),
|
message = Message(Question(self.args.ask[0]),
|
||||||
tags=set(self.args.output_tags),
|
tags=set(self.args.output_tags),
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
message.to_file()
|
chat.msg_write([message])
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
cached_msg_file_id = cached_msg[0].file_path.stem
|
cached_msg_file_id = cached_msg[0].file_path.stem
|
||||||
@ -445,16 +448,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat a single question with new arguments.
|
Repeat a single question with new arguments.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# create a message
|
# create a message
|
||||||
message = Message(Question(self.args.ask[0]),
|
message = Message(Question(self.args.ask[0]),
|
||||||
Answer('Old Answer'),
|
Answer('Old Answer'),
|
||||||
tags=set(self.args.output_tags),
|
tags=set(self.args.output_tags),
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
message.to_file()
|
chat.msg_write([message])
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
|
|
||||||
@ -483,16 +486,16 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat a single question with new arguments, overwriting the old one.
|
Repeat a single question with new arguments, overwriting the old one.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# create a message
|
# create a message
|
||||||
message = Message(Question(self.args.ask[0]),
|
message = Message(Question(self.args.ask[0]),
|
||||||
Answer('Old Answer'),
|
Answer('Old Answer'),
|
||||||
tags=set(self.args.output_tags),
|
tags=set(self.args.output_tags),
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
message.to_file()
|
chat.msg_write([message])
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
assert cached_msg[0].file_path
|
assert cached_msg[0].file_path
|
||||||
|
|
||||||
@ -520,29 +523,29 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
Repeat multiple questions.
|
Repeat multiple questions.
|
||||||
"""
|
"""
|
||||||
mock_create_ai.side_effect = self.mock_create_ai
|
mock_create_ai.side_effect = self.mock_create_ai
|
||||||
|
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
||||||
|
Path(self.db_dir.name))
|
||||||
# 1. === create three questions ===
|
# 1. === create three questions ===
|
||||||
# cached message without an answer
|
# cached message without an answer
|
||||||
message1 = Message(Question(self.args.ask[0]),
|
message1 = Message(Question(self.args.ask[0]),
|
||||||
tags=self.args.output_tags,
|
tags=self.args.output_tags,
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0001.txt')
|
file_path=Path(self.cache_dir.name) / f'0001{msg_suffix}')
|
||||||
# cached message with an answer
|
# cached message with an answer
|
||||||
message2 = Message(Question(self.args.ask[0]),
|
message2 = Message(Question(self.args.ask[0]),
|
||||||
Answer('Old Answer'),
|
Answer('Old Answer'),
|
||||||
tags=self.args.output_tags,
|
tags=self.args.output_tags,
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.cache_dir.name) / '0002.txt')
|
file_path=Path(self.cache_dir.name) / f'0002{msg_suffix}')
|
||||||
# DB message without an answer
|
# DB message without an answer
|
||||||
message3 = Message(Question(self.args.ask[0]),
|
message3 = Message(Question(self.args.ask[0]),
|
||||||
tags=self.args.output_tags,
|
tags=self.args.output_tags,
|
||||||
ai=self.args.AI,
|
ai=self.args.AI,
|
||||||
model=self.args.model,
|
model=self.args.model,
|
||||||
file_path=Path(self.db_dir.name) / '0003.txt')
|
file_path=Path(self.db_dir.name) / f'0003{msg_suffix}')
|
||||||
message1.to_file()
|
chat.msg_write([message1, message2, message3])
|
||||||
message2.to_file()
|
|
||||||
message3.to_file()
|
|
||||||
questions = [message1, message2, message3]
|
questions = [message1, message2, message3]
|
||||||
expected_responses: list[Message] = []
|
expected_responses: list[Message] = []
|
||||||
fake_ai = self.mock_create_ai(self.args, self.config)
|
fake_ai = self.mock_create_ai(self.args, self.config)
|
||||||
@ -566,8 +569,6 @@ class TestQuestionCmdRepeat(TestQuestionCmd):
|
|||||||
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
|
self.assertEqual(len(self.message_list(self.cache_dir)), 4)
|
||||||
self.assertEqual(len(self.message_list(self.db_dir)), 1)
|
self.assertEqual(len(self.message_list(self.db_dir)), 1)
|
||||||
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
|
expected_cache_messages = [expected_responses[0], message2, expected_responses[1], expected_responses[2]]
|
||||||
chat = ChatDB.from_dir(Path(self.cache_dir.name),
|
|
||||||
Path(self.db_dir.name))
|
|
||||||
cached_msg = chat.msg_gather(loc='cache')
|
cached_msg = chat.msg_gather(loc='cache')
|
||||||
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
|
self.assert_msgs_equal_except_file_path(cached_msg, expected_cache_messages)
|
||||||
# check that the DB message has not been modified at all
|
# check that the DB message has not been modified at all
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user