refactor: renamed (almost) all Chat/ChatDB functions

This commit is contained in:
juk0de 2023-09-15 08:41:32 +02:00
parent aae8151a00
commit 378bba6002
5 changed files with 195 additions and 187 deletions

View File

@ -118,14 +118,14 @@ class Chat:
messages: list[Message] messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None: def msg_filter(self, mfilter: MessageFilter) -> None:
""" """
Use 'Message.match(mfilter) to remove all messages that Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements. don't fulfill the filter requirements.
""" """
self.messages = [m for m in self.messages if m.match(mfilter)] self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None: def msg_sort(self, reverse: bool = False) -> None:
""" """
Sort the messages according to 'Message.msg_id()'. Sort the messages according to 'Message.msg_id()'.
""" """
@ -135,33 +135,33 @@ class Chat:
except MessageError: except MessageError:
pass pass
def clear(self) -> None: def msg_clear(self) -> None:
""" """
Delete all messages. Delete all messages.
""" """
self.messages = [] self.messages = []
def add_messages(self, messages: list[Message]) -> None: def msg_add(self, messages: list[Message]) -> None:
""" """
Add new messages and sort them if possible. Add new messages and sort them if possible.
""" """
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
def latest_message(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]: def msg_latest(self, mfilter: Optional[MessageFilter] = None) -> Optional[Message]:
""" """
Return the last added message (according to the file ID) that matches the given filter. Return the last added message (according to the file ID) that matches the given filter.
When containing messages without a valid file_path, it returns the latest message in When containing messages without a valid file_path, it returns the latest message in
the internal list. the internal list.
""" """
if len(self.messages) > 0: if len(self.messages) > 0:
self.sort() self.msg_sort()
for m in reversed(self.messages): for m in reversed(self.messages):
if mfilter is None or m.match(mfilter): if mfilter is None or m.match(mfilter):
return m return m
return None return None
def find_messages(self, msg_names: list[str]) -> list[Message]: def msg_find(self, msg_names: list[str]) -> list[Message]:
""" """
Search and return the messages with the given names. Names can either be filenames Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
@ -170,16 +170,16 @@ class Chat:
return [m for m in self.messages return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None: def msg_remove(self, msg_names: list[str]) -> None:
""" """
Remove the messages with the given names. Names can either be filenames Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths. (incl. the suffix) or full paths.
""" """
self.messages = [m for m in self.messages self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort() self.msg_sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: def msg_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
""" """
Get the tags of all messages, optionally filtered by prefix or substring. Get the tags of all messages, optionally filtered by prefix or substring.
""" """
@ -188,7 +188,7 @@ class Chat:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return set(sorted(tags)) return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: def msg_tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
""" """
Get the frequency of all tags of all messages, optionally filtered by prefix or substring. Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
""" """
@ -292,44 +292,78 @@ class ChatDB(Chat):
with open(self.next_path, 'w') as f: with open(self.next_path, 'w') as f:
f.write(f'{fid}') f.write(f'{fid}')
def read_db(self) -> None: def msg_write(self, messages: Optional[list[Message]] = None) -> None:
""" """
Reads new messages from the DB directory. New ones are added to the internal list, Write either the given messages or the internal ones to their CURRENT file_path.
existing ones are replaced. A message is determined as 'existing' if a message with If messages are given, they all must have a valid file_path. When writing the
the same base filename (i. e. 'file_path.name') is already in the list. internal messages, the ones with a valid file_path are written, the others
are ignored.
""" """
new_messages = read_dir(self.db_path, self.glob, self.mfilter) if messages and any(m.file_path is None for m in messages):
# remove all messages from self.messages that are in the new list raise ChatError("Can't write files without a valid file_path")
self.messages = [m for m in self.messages if not message_in(m, new_messages)] msgs = iter(messages if messages else self.messages)
# copy the messages from the temporary list to self.messages and sort them while (m := next(msgs, None)):
self.messages += new_messages m.to_file()
self.sort()
def read_cache(self) -> None: def msg_update(self, messages: list[Message], write: bool = True) -> None:
""" """
Reads new messages from the cache directory. New ones are added to the internal list, Update EXISTING messages. A message is determined as 'existing' if a message with
existing ones are replaced. A message is determined as 'existing' if a message with the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
the same base filename (i. e. 'file_path.name') is already in the list. existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.msg_sort()
# write the UPDATED messages if requested
if write:
self.msg_write(messages)
def msg_latest(self,
mfilter: Optional[MessageFilter] = None,
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]:
"""
Return the last added message (according to the file ID) that matches the given filter.
Only consider messages with a valid file_path (except if source is 'mem').
Searches one of the following sources:
* 'mem' : only search messages currently in memory
* 'disk' : search messages on disk (cache + DB directory), but not in memory
* 'cache': only search messages in the cache directory
* 'db' : only search messages in the DB directory
* 'all' : search all messages ('mem' + 'disk')
"""
source_messages: list[Message] = []
if source == 'mem':
return super().msg_latest(mfilter)
if source in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter)
if source in ['all']:
# only consider messages with a valid file_path so they can be sorted
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages:
if mfilter is None or m.match(mfilter):
return m
return None
def cache_read(self) -> None:
"""
Read messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message
with the same base filename (i. e. 'file_path.name') is already in the list.
""" """
new_messages = read_dir(self.cache_path, self.glob, self.mfilter) new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list # remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)] self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them # copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages self.messages += new_messages
self.sort() self.msg_sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None: def cache_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
""" """
Write messages to the cache directory. If a message has no file_path, a new one Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to will be created. If message.file_path exists, it will be modified to point to
@ -340,36 +374,9 @@ class ChatDB(Chat):
self.file_suffix, self.file_suffix,
self.get_next_fid) self.get_next_fid)
def clear_cache(self) -> None: def cache_add(self, messages: list[Message], write: bool = True) -> None:
""" """
Deletes all Message files from the cache dir and removes those messages from Add NEW messages and set the file_path to the cache directory.
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path. Only accepts messages without a file_path.
""" """
if any(m.file_path is not None for m in messages): if any(m.file_path is not None for m in messages):
@ -383,62 +390,54 @@ class ChatDB(Chat):
for m in messages: for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages self.messages += messages
self.sort() self.msg_sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None: def cache_clear(self) -> None:
""" """
Write either the given messages or the internal ones to their current file_path. Delete all message files from the cache dir and remove them from the internal list.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
""" """
if messages and any(m.file_path is None for m in messages): clear_dir(self.cache_path, self.glob)
raise ChatError("Can't write files without a valid file_path") # only keep messages from DB dir (or those that have not yet been written)
msgs = iter(messages if messages else self.messages) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None: def db_read(self) -> None:
""" """
Update existing messages. A message is determined as 'existing' if a message with Read messages from the DB directory. New ones are added to the internal list,
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts existing ones are replaced. A message is determined as 'existing' if a message
existing messages. with the same base filename (i. e. 'file_path.name') is already in the list.
""" """
if any(not message_in(m, self.messages) for m in messages): new_messages = read_dir(self.db_path, self.glob, self.mfilter)
raise ChatError("Can't update messages that are not in the internal list") # remove all messages from self.messages that are in the new list
# remove old versions and add new ones self.messages = [m for m in self.messages if not message_in(m, new_messages)]
self.messages = [m for m in self.messages if not message_in(m, messages)] # copy the messages from the temporary list to self.messages and sort them
self.messages += messages self.messages += new_messages
self.sort() self.msg_sort()
# write the UPDATED messages if requested
def db_write(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def db_add(self, messages: list[Message], write: bool = True) -> None:
"""
Add NEW messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write: if write:
self.write_messages(messages) write_dir(self.db_path,
messages,
def latest_message(self, self.file_suffix,
mfilter: Optional[MessageFilter] = None, self.get_next_fid)
source: Literal['mem', 'disk', 'cache', 'db', 'all'] = 'mem') -> Optional[Message]: else:
""" for m in messages:
Return the last added message (according to the file ID) that matches the given filter. m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
Only consider messages with a valid file_path (except if source is 'mem'). self.messages += messages
Searches one of the following sources: self.msg_sort()
* 'mem' : only search messages currently in memory
* 'disk' : search messages on disk (cache + DB directory), but not in memory
* 'cache': only search messages in the cache directory
* 'db' : only search messages in the DB directory
* 'all' : search all messages ('mem' + 'disk')
"""
source_messages: list[Message] = []
if source == 'mem':
return super().latest_message(mfilter)
if source in ['cache', 'disk', 'all']:
source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter)
if source in ['all']:
# only consider messages with a valid file_path so they can be sorted
source_messages += [m for m in self.messages if (m.file_path is not None and (mfilter is None or m.match(mfilter)))]
source_messages.sort(key=lambda m: m.msg_id(), reverse=True)
for m in source_messages:
if mfilter is None or m.match(mfilter):
return m
return None

View File

@ -52,7 +52,8 @@ def add_file_as_code(question_parts: list[str], file: str) -> None:
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
""" """
Creates (and writes) a new message from the given arguments. Creates a new message from the given arguments and writes it
to the cache directory.
""" """
question_parts = [] question_parts = []
question_list = args.ask if args.ask is not None else [] question_list = args.ask if args.ask is not None else []
@ -73,18 +74,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
tags=args.output_tags, # FIXME tags=args.output_tags, # FIXME
ai=args.AI, ai=args.AI,
model=args.model) model=args.model)
chat.write_cache([message]) # only write the message (as a backup), don't add it
# to the current chat history
chat.cache_write([message])
return message return message
def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None: def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
"""
Make an AI request with the give AI, chat history, message and CLI arguments.
Print all answers.
"""
ai.print() ai.print()
chat.print(paged=False) chat.print(paged=False)
print(message.to_str() + '\n')
response: AIResponse = ai.request(message, response: AIResponse = ai.request(message,
chat, chat,
args.num_answers, # FIXME args.num_answers,
args.output_tags) # FIXME args.output_tags)
chat.write_cache(response.messages) # write all answers to the cache, don't add them to the chat history
chat.cache_write(response.messages)
for idx, msg in enumerate(response.messages): for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===") print(f"=== ANSWER {idx+1} ===")
print(msg.answer) print(msg.answer)
@ -117,7 +126,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None:
make_request(ai, chat, message, args) make_request(ai, chat, message, args)
# === REPEAT === # === REPEAT ===
elif args.repeat is not None: elif args.repeat is not None:
lmessage = chat.latest_message(source='cache') lmessage = chat.msg_latest(source='cache')
if lmessage is None: if lmessage is None:
print("No message found to repeat!") print("No message found to repeat!")
sys.exit(1) sys.exit(1)

View File

@ -11,7 +11,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None:
chat = ChatDB.from_dir(cache_path=Path('.'), chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db)) db_path=Path(config.db))
if args.list: if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain) tags_freq = chat.msg_tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items(): for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}") print(f"- {tag}: {freq}")
# TODO: add renaming # TODO: add renaming

View File

@ -22,78 +22,78 @@ class TestChat(unittest.TestCase):
file_path=pathlib.Path('0002.txt')) file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.chat.msg_filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1) self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None: def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1]) self.chat.msg_add([self.message2, self.message1])
self.chat.sort() self.chat.msg_sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True) self.chat.msg_sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2') self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None: def test_clear(self) -> None:
self.chat.add_messages([self.message1]) self.chat.msg_add([self.message1])
self.chat.clear() self.chat.msg_clear()
self.assertEqual(len(self.chat.messages), 0) self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None: def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2) self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2') self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None: def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_all = self.chat.tags() tags_all = self.chat.msg_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a') tags_pref = self.chat.msg_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')}) self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2') tags_cont = self.chat.msg_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None: def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
tags_freq = self.chat.tags_frequency() tags_freq = self.chat.msg_tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None: def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt']) msgs = self.chat.msg_find(['0001.txt'])
self.assertListEqual(msgs, [self.message1]) self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt']) msgs = self.chat.msg_find(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2]) self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path # add new Message with full path
message3 = Message(Question('Question 2'), message3 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt')) file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3]) self.chat.msg_add([message3])
# find new Message by full path # find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt']) msgs = self.chat.msg_find(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# find Message with full path only by filename # find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt']) msgs = self.chat.msg_find(['0003.txt'])
self.assertListEqual(msgs, [message3]) self.assertListEqual(msgs, [message3])
# remove last message # remove last message
self.chat.remove_messages(['0003.txt']) self.chat.msg_remove(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2]) self.assertListEqual(self.chat.messages, [self.message1, self.message2])
def test_latest_message(self) -> None: def test_latest_message(self) -> None:
self.assertIsNone(self.chat.latest_message()) self.assertIsNone(self.chat.msg_latest())
self.chat.add_messages([self.message1]) self.chat.msg_add([self.message1])
self.assertEqual(self.chat.latest_message(), self.message1) self.assertEqual(self.chat.msg_latest(), self.message1)
self.chat.add_messages([self.message2]) self.chat.msg_add([self.message2])
self.assertEqual(self.chat.latest_message(), self.message2) self.assertEqual(self.chat.msg_latest(), self.message2)
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False) self.chat.print(paged=False)
expected_output = f"""{Question.txt_header} expected_output = f"""{Question.txt_header}
Question 1 Question 1
@ -108,7 +108,7 @@ Answer 2
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True) self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2 expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt FILE: 0001.txt
@ -260,7 +260,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
@ -280,7 +280,7 @@ class TestChatDB(unittest.TestCase):
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory # overwrite the messages in the db directory
time.sleep(0.05) time.sleep(0.05)
chat_db.write_db() chat_db.db_write()
# check if the written files are in the DB directory # check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
@ -313,7 +313,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 6) self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
@ -328,7 +328,7 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.read_cache() chat_db.cache_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path # check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
@ -343,7 +343,7 @@ class TestChatDB(unittest.TestCase):
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated # read from the DB dir and check if the modified messages have been updated
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1') self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2') self.assertEqual(chat_db.messages[5].question, 'New Question 2')
@ -354,7 +354,7 @@ class TestChatDB(unittest.TestCase):
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them # read and check them
chat_db.read_db() chat_db.db_read()
self.assertEqual(len(chat_db.messages), 8) self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path # check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
@ -371,13 +371,13 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory # write the messages to the cache directory
chat_db.write_cache() chat_db.cache_write()
# check if the written files are in the cache directory # check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4) self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths # now rewrite them to the DB dir and check for modified paths
chat_db.write_db() chat_db.db_write()
db_dir_files = self.message_list(self.db_path) db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4) self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
@ -392,10 +392,10 @@ class TestChatDB(unittest.TestCase):
message_cache = Message(question=Question("What the hell am I doing here?"), message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"), answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt')) file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_messages([message_empty, message_cache]) chat_db.msg_add([message_empty, message_cache])
# clear the cache and check the cache dir # clear the cache and check the cache dir
chat_db.clear_cache() chat_db.cache_clear()
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0) self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there # make sure that the DB messages (and the new message) are still there
@ -416,7 +416,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the cache dir # add new messages to the cache dir
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
chat_db.add_to_cache([message1]) chat_db.cache_add([message1])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path) self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
@ -426,7 +426,7 @@ class TestChatDB(unittest.TestCase):
# add new messages to the DB dir # add new messages to the DB dir
message2 = Message(question=Question("Question 2"), message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2")) answer=Answer("Answer 2"))
chat_db.add_to_db([message2]) chat_db.db_add([message2])
# check if the file_path has been correctly set # check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path) self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
@ -434,7 +434,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(db_dir_files), 5) self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None: def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
@ -450,11 +450,11 @@ class TestChatDB(unittest.TestCase):
message = Message(question=Question("Question 1"), message = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.write_messages([message]) chat_db.msg_write([message])
# write a message with a valid file_path # write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message]) chat_db.msg_write([message])
cache_dir_files = self.message_list(self.cache_path) cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
@ -472,37 +472,37 @@ class TestChatDB(unittest.TestCase):
message = chat_db.messages[0] message = chat_db.messages[0]
message.answer = Answer("New answer") message.answer = Answer("New answer")
# update message without writing # update message without writing
chat_db.update_messages([message], write=False) chat_db.msg_update([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content # re-read the message and check for old content
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten) # now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True) chat_db.msg_update([message], write=True)
chat_db.read_db() chat_db.db_read()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error # test without file_path -> expect error
message1 = Message(question=Question("Question 1"), message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1")) answer=Answer("Answer 1"))
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.update_messages([message1]) chat_db.msg_update([message1])
def test_chat_db_latest_message(self) -> None: def test_chat_db_latest_message(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.latest_message(source='mem'), self.message4) self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)
self.assertEqual(chat_db.latest_message(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(source='db'), self.message4)
self.assertEqual(chat_db.latest_message(source='disk'), self.message4) self.assertEqual(chat_db.msg_latest(source='disk'), self.message4)
self.assertEqual(chat_db.latest_message(source='all'), self.message4) self.assertEqual(chat_db.msg_latest(source='all'), self.message4)
# the cache is currently empty: # the cache is currently empty:
self.assertIsNone(chat_db.latest_message(source='cache')) self.assertIsNone(chat_db.msg_latest(source='cache'))
# add new messages to the cache dir # add new messages to the cache dir
new_message = Message(question=Question("New Question"), new_message = Message(question=Question("New Question"),
answer=Answer("New Answer")) answer=Answer("New Answer"))
chat_db.add_to_cache([new_message]) chat_db.cache_add([new_message])
self.assertEqual(chat_db.latest_message(source='cache'), new_message) self.assertEqual(chat_db.msg_latest(source='cache'), new_message)
self.assertEqual(chat_db.latest_message(source='mem'), new_message) self.assertEqual(chat_db.msg_latest(source='mem'), new_message)
self.assertEqual(chat_db.latest_message(source='disk'), new_message) self.assertEqual(chat_db.msg_latest(source='disk'), new_message)
self.assertEqual(chat_db.latest_message(source='all'), new_message) self.assertEqual(chat_db.msg_latest(source='all'), new_message)
# the DB does not contain the new message # the DB does not contain the new message
self.assertEqual(chat_db.latest_message(source='db'), self.message4) self.assertEqual(chat_db.msg_latest(source='db'), self.message4)

View File

@ -25,7 +25,7 @@ class TestMessageCreate(unittest.TestCase):
Answer("It is pure text")) Answer("It is pure text"))
self.message_code = Message(Question("What is this?"), self.message_code = Message(Question("What is this?"),
Answer("Text\n```\nIt is embedded code\n```\ntext")) Answer("Text\n```\nIt is embedded code\n```\ntext"))
self.chat.add_to_db([self.message_text, self.message_code]) self.chat.db_add([self.message_text, self.message_code])
# create arguments mock # create arguments mock
self.args = MagicMock(spec=argparse.Namespace) self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None self.args.source_text = None