chat: added new functions: msg_unique_id(), msg_unique_content() and tests

This commit is contained in:
juk0de 2023-09-15 10:17:20 +02:00
parent 44fbff33fe
commit e2d9bf3f69
2 changed files with 79 additions and 7 deletions

View File

@ -146,6 +146,25 @@ class Chat:
except MessageError: except MessageError:
pass pass
def msg_unique_id(self) -> None:
    """
    Drop duplicate entries from the internal message list, comparing by msg_id
    (i.e. file_path), then re-sort the list.

    Duplicate detection is delegated to 'message_in()'; the first occurrence
    of each message is the one that survives. According to the original
    documentation, messages without a file_path are kept.
    """
    deduplicated: list[Message] = []
    for candidate in self.messages:
        # skip anything 'message_in()' already finds among the kept messages
        if message_in(candidate, deduplicated):
            continue
        deduplicated.append(candidate)
    self.messages = deduplicated
    self.msg_sort()
def msg_unique_content(self) -> None:
    """
    Remove duplicates from the internal messages, based on the content
    (i.e. question + answer), then re-sort the list.

    Two messages count as duplicates when they hash and compare equal. Of
    each group of duplicates the first occurrence is kept, and the survivors
    keep their relative order until 'msg_sort()' is applied.
    """
    # dict.fromkeys() deduplicates with the same hash/equality rules as set(),
    # but keeps insertion order, so the list handed to msg_sort() is
    # deterministic (set() iteration order is arbitrary).
    self.messages = list(dict.fromkeys(self.messages))
    self.msg_sort()
def msg_clear(self) -> None: def msg_clear(self) -> None:
""" """
Delete all messages. Delete all messages.
@ -356,7 +375,13 @@ class ChatDB(Chat):
source_messages += read_dir(self.cache_path, mfilter=mfilter) source_messages += read_dir(self.cache_path, mfilter=mfilter)
if source in ['db', 'disk', 'all']: if source in ['db', 'disk', 'all']:
source_messages += read_dir(self.db_path, mfilter=mfilter) source_messages += read_dir(self.db_path, mfilter=mfilter)
return source_messages # remove_duplicates and sort the list
unique_messages: list[Message] = []
for m in source_messages:
if not message_in(m, unique_messages):
unique_messages.append(m)
unique_messages.sort(key=lambda m: m.msg_id())
return unique_messages
def msg_find(self, def msg_find(self,
msg_names: list[str], msg_names: list[str],
@ -430,6 +455,7 @@ class ChatDB(Chat):
Write messages to the cache directory. If a message has no file_path, a new one Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to will be created. If message.file_path exists, it will be modified to point to
the cache directory. the cache directory.
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
""" """
write_dir(self.cache_path, write_dir(self.cache_path,
messages if messages else self.messages, messages if messages else self.messages,
@ -480,6 +506,7 @@ class ChatDB(Chat):
Write messages to the DB directory. If a message has no file_path, a new one Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point will be created. If message.file_path exists, it will be modified to point
to the DB directory. to the DB directory.
Does NOT add the messages to the internal list (use 'db_add()' for that)!
""" """
write_dir(self.db_path, write_dir(self.db_path,
messages if messages else self.messages, messages if messages else self.messages,

View File

@ -20,6 +20,29 @@ class TestChat(unittest.TestCase):
Answer('Answer 2'), Answer('Answer 2'),
{Tag('btag2')}, {Tag('btag2')},
file_path=pathlib.Path('0002.txt')) file_path=pathlib.Path('0002.txt'))
self.maxDiff = None
def test_unique_id(self) -> None:
    """Check that msg_unique_id() collapses a repeated message and keeps distinct ones."""
    chat = self.chat
    first, second = self.message1, self.message2

    # the same message added twice is stored twice ...
    chat.msg_add([first, first])
    self.assertSequenceEqual(chat.messages, [first, first])
    # ... until the duplicates are removed
    chat.msg_unique_id()
    self.assertSequenceEqual(chat.messages, [first])

    # a different message must survive the deduplication
    chat.msg_add([second])
    chat.msg_unique_id()
    self.assertSequenceEqual(chat.messages, [first, second])
def test_unique_content(self) -> None:
    """Check that msg_unique_content() collapses a repeated message and keeps distinct ones."""
    chat = self.chat
    first, second = self.message1, self.message2

    # the same message added twice is stored twice ...
    chat.msg_add([first, first])
    self.assertSequenceEqual(chat.messages, [first, first])
    # ... until the duplicates are removed
    chat.msg_unique_content()
    self.assertSequenceEqual(chat.messages, [first])

    # a different message must survive the deduplication
    chat.msg_add([second])
    chat.msg_unique_content()
    self.assertSequenceEqual(chat.messages, [first, second])
def test_filter(self) -> None: def test_filter(self) -> None:
self.chat.msg_add([self.message1, self.message2]) self.chat.msg_add([self.message1, self.message2])
@ -161,6 +184,7 @@ class TestChatDB(unittest.TestCase):
for file in self.trash_files: for file in self.trash_files:
with open(pathlib.Path(self.db_path.name) / file, 'w') as f: with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
f.write('test trash') f.write('test trash')
self.maxDiff = None
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
""" """
@ -249,7 +273,7 @@ class TestChatDB(unittest.TestCase):
with open(chat_db.next_path, 'r') as f: with open(chat_db.next_path, 'r') as f:
self.assertEqual(f.read(), '7') self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None: def test_chat_db_db_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -297,7 +321,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_read(self) -> None: def test_chat_db_db_read(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -360,7 +384,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_chat_db_clear(self) -> None: def test_chat_db_cache_clear(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -436,7 +460,7 @@ class TestChatDB(unittest.TestCase):
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))]) chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None: def test_chat_db_msg_write(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -459,7 +483,7 @@ class TestChatDB(unittest.TestCase):
self.assertEqual(len(cache_dir_files), 1) self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None: def test_chat_db_msg_update(self) -> None:
# create a new ChatDB instance # create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
@ -487,7 +511,28 @@ class TestChatDB(unittest.TestCase):
with self.assertRaises(ChatError): with self.assertRaises(ChatError):
chat_db.msg_update([message1]) chat_db.msg_update([message1])
def test_chat_db_msg_find(self) -> None:
    """Look messages up by full path, file name and bare ID in each source."""
    cache_dir = pathlib.Path(self.cache_path.name)
    db_dir = pathlib.Path(self.db_path.name)
    chat_db = ChatDB.from_dir(cache_dir, db_dir)

    # search for a DB file in memory
    for name in (str(self.message1.file_path), '0001.txt', '0001'):
        self.assertEqual(chat_db.msg_find([name], source='mem'), [self.message1])

    # and on disk
    for name in (str(self.message2.file_path), '0002.yaml', '0002'):
        self.assertEqual(chat_db.msg_find([name], source='db'), [self.message2])

    # now search the cache -> expect empty result
    for name in (str(self.message3.file_path), '0003.txt', '0003'):
        self.assertEqual(chat_db.msg_find([name], source='cache'), [])

    # search for multiple messages, each addressed in a different way
    wanted = ['0001', '0002.yaml', str(self.message3.file_path)]
    found = chat_db.msg_find(wanted, source='all')
    self.assertSequenceEqual(found, [self.message1, self.message2, self.message3])
def test_chat_db_msg_latest(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name)) pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4) self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)