Compare commits

..

No commits in common. "487898e6400d312950105f924cab201a0f4c2fdd" and "17a0264025489b978a6fc450cbb29b1b77467f4b" have entirely different histories.

9 changed files with 206 additions and 480 deletions

View File

@ -6,18 +6,13 @@ from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable, Literal
from .configuration import default_config_file
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
db_next_file = '.next'
ignored_files = [db_next_file, default_config_file]
msg_place = Literal['mem', 'disk', 'cache', 'db', 'all']
class ChatError(Exception):
pass
@ -50,15 +45,13 @@ def read_dir(dir_path: Path,
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if (file_path.is_file()
and file_path.name not in ignored_files # noqa: W503
and file_path.suffix in Message.file_suffixes): # noqa: W503
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
except MessageError as e:
print(f"WARNING: Skipping message in '{file_path}': {str(e)}")
print(f"Error processing message in '{file_path}': {str(e)}")
return messages
@ -119,43 +112,14 @@ class Chat:
messages: list[Message]
def __post_init__(self) -> None:
self.validate()
def validate(self) -> None:
"""
Validate this Chat instance.
"""
def msg_paths(stem: str) -> list[str]:
return [str(fp) for fp in file_paths if fp.stem == stem]
file_paths: set[Path] = {m.file_path for m in self.messages if m.file_path is not None}
file_stems = [m.file_path.stem for m in self.messages if m.file_path is not None]
error = False
for fp in file_paths:
if file_stems.count(fp.stem) > 1:
print(f"ERROR: Found multiple copies of message '{fp.stem}': {msg_paths(fp.stem)}")
error = True
if error:
raise ChatError("Validation failed")
def msg_name_matches(self, file_path: Path, name: str) -> bool:
"""
Return True if the given name matches the given file_path.
Matching is True if:
* 'name' matches the full 'file_path'
* 'name' matches 'file_path.name' (i. e. including the suffix)
* 'name' matches 'file_path.stem' (i. e. without a suffix)
"""
return Path(name) == file_path or name == file_path.name or name == file_path.stem
def msg_filter(self, mfilter: MessageFilter) -> None:
def filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def msg_sort(self, reverse: bool = False) -> None:
def sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
@ -165,71 +129,48 @@ class Chat:
except MessageError:
pass
def msg_unique_id(self) -> None:
"""
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
Messages without a file_path are kept.
"""
old_msgs = self.messages.copy()
self.messages = []
for m in old_msgs:
if not message_in(m, self.messages):
self.messages.append(m)
self.msg_sort()
def msg_unique_content(self) -> None:
"""
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
"""
self.messages = list(set(self.messages))
self.msg_sort()
def msg_clear(self) -> None:
def clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def msg_add(self, messages: list[Message]) -> None:
def add_messages(self, messages: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += messages
self.msg_sort()
self.sort()
def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
def latest_message(self) -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in
the internal list.
Returns the last added message (according to the file ID).
"""
if len(self.messages) > 0:
self.msg_sort()
for m in reversed(self.messages):
if mfilter is None or m.match(mfilter):
return m
return None
self.sort()
return self.messages[-1]
else:
return None
def msg_find(self, msg_names: list[str]) -> list[Message]:
def find_messages(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
caller should check the result if he requires all messages).
"""
return [m for m in self.messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None:
def remove_messages(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id().
(incl. the suffix) or full paths.
"""
self.messages = [m for m in self.messages
if not any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
self.msg_sort()
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort()
def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
"""
@ -238,7 +179,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain)
return set(sorted(tags))
def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
"""
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
"""
@ -291,11 +232,10 @@ class ChatDB(Chat):
def __post_init__(self) -> None:
# contains the latest message ID
self.next_path = self.db_path / db_next_file
self.next_fname = self.db_path / '.next'
# make all paths absolute
self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute()
self.validate()
@classmethod
def from_dir(cls: Type[ChatDBInst],
@ -331,7 +271,7 @@ class ChatDB(Chat):
def get_next_fid(self) -> int:
try:
with open(self.next_path, 'r') as f:
with open(self.next_fname, 'r') as f:
next_fid = int(f.read()) + 1
self.set_next_fid(next_fid)
return next_fid
@ -340,204 +280,69 @@ class ChatDB(Chat):
return 1
def set_next_fid(self, fid: int) -> None:
with open(self.next_path, 'w') as f:
with open(self.next_fname, 'w') as f:
f.write(f'{fid}')
def msg_write(self, messages: Optional[list[Message]] = None) -> None:
def read_db(self) -> None:
"""
Write either the given messages or the internal ones to their CURRENT file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def msg_update(self, messages: list[Message], write: bool = True) -> None:
"""
Update EXISTING messages. A message is determined as 'existing' if a message with
Reads new messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
Only accepts existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.msg_sort()
# write the UPDATED messages if requested
if write:
self.msg_write(messages)
def msg_gather(self,
source: msg_place,
require_file_path: bool = False,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Gather and return messages from the given source:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
If 'require_file_path' is True, return only files with a valid file_path.
"""
source_messages: list[Message] = []
if source in ['mem', 'all']:
if require_file_path:
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
else:
source_messages += [m for m in self.messages if (mfilter is None or m.match(mfilter))]
if source in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter)
# remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in source_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
unique_messages.sort(key=lambda m: m.msg_id())
return unique_messages
def msg_find(self,
msg_names: list[str],
source: msg_place = 'mem',
) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Messages that can't be
found are ignored (i. e. the caller should check the result if they require all
messages).
Searches one of the following places:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
source_messages = self.msg_gather(source, require_file_path=True)
return [m for m in source_messages
if any((m.file_path and self.msg_name_matches(m.file_path, mn)) for mn in msg_names)]
def msg_remove(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(with or without suffix), full paths or Message.msg_id(). Also deletes the
files of all given messages with a valid file_path.
"""
# delete the message files first
rm_messages = self.msg_find(msg_names, source='all')
for m in rm_messages:
if (m.file_path):
m.file_path.unlink()
# then remove them from the internal list
super().msg_remove(msg_names)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
source: msg_place = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following places:
* 'mem' : messages currently in memory
* 'disk' : messages on disk (cache + DB directory), but not in memory
* 'cache': messages in the cache directory
* 'db' : messages in the DB directory
* 'all' : all messages ('mem' + 'disk')
"""
# only consider messages with a valid file_path so they can be sorted
source_messages = self.msg_gather(source, require_file_path=True)
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def cache_read(self) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def cache_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
def cache_clear(self) -> None:
"""
Delete all message files from the cache dir and remove them from the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def db_read(self) -> None:
"""
Read messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.msg_sort()
self.sort()
def db_write(self, messages: Optional[list[Message]] = None) -> None:
def read_cache(self) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
"""
Add NEW messages and set the file_path to the DB directory.
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
@ -551,4 +356,51 @@ class ChatDB(Chat):
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.msg_sort()
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
if write:
self.write_messages(messages)

View File

@ -1,4 +1,3 @@
import sys
import argparse
from pathlib import Path
from itertools import zip_longest
@ -52,8 +51,7 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates a new message from the given arguments and writes it
to the cache directory.
Creates (and writes) a new message from the given arguments.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
@ -74,34 +72,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
# only write the message (as a backup), don't add it
# to the current chat history
chat.cache_write([message])
chat.add_to_cache([message])
return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the give AI, chat history, message and CLI arguments.
Print all answers.
"""
ai.print()
chat.print(paged=False)
print(message.to_str() + '\n')
response: AIResponse = ai.request(message,
chat,
args.num_answers,
args.output_tags)
# write all answers to the cache, don't add them to the chat history
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
@ -120,29 +94,28 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
# create the correct AI instance
ai: AI = create_ai(args, config)
# === ASK ===
if args.ask:
make_request(ai, chat, message, args)
# === REPEAT ===
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.msg_latest(source='cache')
if lmessage is None:
print("No message found to repeat!")
sys.exit(1)
else:
print(f"Repeating message '{lmessage.msg_id()}':")
# overwrite the latest message if requested or empty
if lmessage.answer is None or args.overwrite is True:
lmessage.clear_answer()
make_request(ai, chat, lmessage, args)
# otherwise create a new one
else:
args.ask = [lmessage.question]
message = create_message(chat, args)
make_request(ai, chat, message, args)
# === PROCESS ===
lmessage = chat.latest_message()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
elif args.process is not None:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'

View File

@ -11,7 +11,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming

View File

@ -9,7 +9,7 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_config_file = '.config.yaml'
default_config_path = '.config.yaml'
class ConfigError(Exception):

View File

@ -7,7 +7,7 @@ import argcomplete
import argparse
from pathlib import Path
from typing import Any
from .configuration import Config, default_config_file
from .configuration import Config, default_config_path
from .message import Message
from .commands.question import question_cmd
from .commands.tags import tags_cmd
@ -24,7 +24,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_file)
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
# subcommand-parser
cmdparser = parser.add_subparsers(dest='command',

View File

@ -370,7 +370,7 @@ class Message():
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"'{file_path}' does not contain a valid message")
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
@ -390,11 +390,8 @@ class Message():
* Message.model_yaml_key: str [Optional]
"""
with open(file_path, "r") as fd:
try:
data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path
except Exception:
raise MessageError(f"'{file_path}' does not contain a valid message")
data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
@ -540,17 +537,13 @@ class Message():
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def clear_answer(self) -> None:
self.answer = None
def msg_id(self) -> str:
"""
Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name without suffix. The ID is also used for sorting
messages.
Currently this is the file name. The ID is also used for sorting messages.
"""
if self.file_path:
return self.file_path.stem
return self.file_path.name
else:
raise MessageError("Can't create file ID without a file path")

View File

@ -20,103 +20,73 @@ class TestChat(unittest.TestCase):
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_id()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_unique_content(self) -> None:
# test with two identical messages
self.chat.msg_add([self.message1, self.message1])
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1])
# test with two different messages
self.chat.msg_add([self.message2])
self.chat.msg_unique_content()
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
self.chat.add_messages([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.msg_add([self.message2, self.message1])
self.chat.msg_sort()
self.chat.add_messages([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.msg_sort(reverse=True)
self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.msg_add([self.message1])
self.chat.msg_clear()
self.chat.add_messages([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.msg_add([self.message1, self.message2])
tags_all = self.chat.msg_tags()
self.chat.add_messages([self.message1, self.message2])
tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.msg_tags(prefix='a')
tags_pref = self.chat.tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.msg_tags(contain='2')
tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.msg_add([self.message1, self.message2])
tags_freq = self.chat.msg_tags_frequency()
self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.msg_find(['0001.txt'])
self.chat.add_messages([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.msg_add([message3])
self.chat.add_messages([message3])
# find new Message by full path
msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.msg_find(['0003.txt'])
msgs = self.chat.find_messages(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.msg_remove(['0003.txt'])
self.chat.remove_messages(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None:
self.assertIsNone(self.chat.msg_latest())
self.chat.msg_add([self.message1])
self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.msg_add([self.message2])
self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"""{Question.txt_header}
Question 1
@ -131,7 +101,7 @@ Answer 2
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.msg_add([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
@ -179,37 +149,20 @@ class TestChatDB(unittest.TestCase):
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
# add some "trash" in order to test if it's correctly handled / ignored
self.trash_files = ['.config.yaml', 'foo.yaml', 'bla.txt']
for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return [f for f in pathlib.Path(tmp_dir.name).glob('*.[ty]*') if f.name not in self.trash_files]
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None:
self.db_path.cleanup()
self.cache_path.cleanup()
pass
def test_validate(self) -> None:
duplicate_message = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.txt'))
duplicate_message.to_file(pathlib.Path(self.db_path.name, '0004.txt'))
with self.assertRaises(ChatError) as cm:
ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(str(cm.exception), "Validation failed")
def test_from_dir(self) -> None:
def test_chat_db_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
@ -225,7 +178,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_from_dir_glob(self) -> None:
def test_chat_db_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
@ -237,7 +190,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_from_dir_filter_tags(self) -> None:
def test_chat_db_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
@ -247,7 +200,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_from_dir_filter_tags_empty(self) -> None:
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
@ -255,7 +208,7 @@ class TestChatDB(unittest.TestCase):
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_from_dir_filter_answer(self) -> None:
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
@ -266,7 +219,7 @@ class TestChatDB(unittest.TestCase):
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_from_messages(self) -> None:
def test_chat_db_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
@ -275,16 +228,16 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_fids(self) -> None:
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_path, 'r') as f:
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '7')
def test_db_write(self) -> None:
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -295,7 +248,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
@ -315,7 +268,7 @@ class TestChatDB(unittest.TestCase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.db_write()
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
@ -332,7 +285,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_db_read(self) -> None:
def test_chat_db_read(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -348,7 +301,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.db_read()
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
@ -363,7 +316,7 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.cache_read()
chat_db.read_cache()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
@ -378,7 +331,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.db_read()
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
@ -389,13 +342,13 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.db_read()
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_cache_clear(self) -> None:
def test_chat_db_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -406,13 +359,13 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.cache_write()
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.db_write()
chat_db.write_db()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
@ -427,10 +380,10 @@ class TestChatDB(unittest.TestCase):
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.msg_add([message_empty, message_cache])
chat_db.add_messages([message_empty, message_cache])
# clear the cache and check the cache dir
chat_db.cache_clear()
chat_db.clear_cache()
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
@ -440,7 +393,7 @@ class TestChatDB(unittest.TestCase):
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_add(self) -> None:
def test_chat_db_add(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -451,7 +404,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.cache_add([message1])
chat_db.add_to_cache([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@ -461,7 +414,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the DB dir
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.db_add([message2])
chat_db.add_to_db([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
@ -469,9 +422,9 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_msg_write(self) -> None:
def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -485,16 +438,16 @@ class TestChatDB(unittest.TestCase):
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.msg_write([message])
chat_db.write_messages([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.msg_write([message])
chat_db.write_messages([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_msg_update(self) -> None:
def test_chat_db_update_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
@ -507,62 +460,17 @@ class TestChatDB(unittest.TestCase):
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.msg_update([message], write=False)
chat_db.update_messages([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.db_read()
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.msg_update([message], write=True)
chat_db.db_read()
chat_db.update_messages([message], write=True)
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.msg_update([message1])
def test_msg_find(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# search for a DB file in memory
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find([self.message1.msg_id()], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1])
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1])
# and on disk
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find([self.message2.msg_id()], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2])
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2])
# now search the cache -> expect empty result
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), [])
self.assertEqual(chat_db.msg_find([self.message3.msg_id()], source='cache'), [])
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), [])
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), [])
# search for multiple messages
# -> search one twice, expect result to be unique
search_names = ['0001', '0002.yaml', self.message3.msg_id(), str(self.message3.file_path)]
expected_result = [self.message1, self.message2, self.message3]
result = chat_db.msg_find(search_names, source='all')
self.assertSequenceEqual(result, expected_result)
def test_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)
self.assertEqual(chat_db.msg_latest(source='db'), self.message4)
self.assertEqual(chat_db.msg_latest(source='disk'), self.message4)
self.assertEqual(chat_db.msg_latest(source='all'), self.message4)
# the cache is currently empty:
self.assertIsNone(chat_db.msg_latest(source='cache'))
# add new messages to the cache dir
new_message = Message(question=Question("New Question"),
answer=Answer("New Answer"))
chat_db.cache_add([new_message])
self.assertEqual(chat_db.msg_latest(source='cache'), new_message)
self.assertEqual(chat_db.msg_latest(source='mem'), new_message)
self.assertEqual(chat_db.msg_latest(source='disk'), new_message)
self.assertEqual(chat_db.msg_latest(source='all'), new_message)
# the DB does not contain the new message
self.assertEqual(chat_db.msg_latest(source='db'), self.message4)
chat_db.update_messages([message1])

View File

@ -730,7 +730,7 @@ class MessageIDTestCase(unittest.TestCase):
self.file_path.unlink()
def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.stem)
self.assertEqual(self.message.msg_id(), self.file_path.name)
def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError):

View File

@ -25,7 +25,7 @@ class TestMessageCreate(unittest.TestCase):
Answer("It is pure text"))
self.message_code = Message(Question("What is this?"),
Answer("Text\n```\nIt is embedded code\n```\ntext"))
self.chat.db_add([self.message_text, self.message_code])
self.chat.add_to_db([self.message_text, self.message_code])
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None