chat: added new functions: msg_unique_id(), msg_unique_content() and tests
This commit is contained in:
parent
44fbff33fe
commit
bbc51c2f51
@ -146,6 +146,25 @@ class Chat:
|
||||
except MessageError:
|
||||
pass
|
||||
|
||||
def msg_unique_id(self) -> None:
|
||||
"""
|
||||
Remove duplicates from the internal messages, based on the msg_id (i. e. file_path).
|
||||
Messages without a file_path are kept.
|
||||
"""
|
||||
old_msgs = self.messages.copy()
|
||||
self.messages = []
|
||||
for m in old_msgs:
|
||||
if not message_in(m, self.messages):
|
||||
self.messages.append(m)
|
||||
self.msg_sort()
|
||||
|
||||
def msg_unique_content(self) -> None:
|
||||
"""
|
||||
Remove duplicates from the internal messages, based on the content (i. e. question + answer).
|
||||
"""
|
||||
self.messages = list(set(self.messages))
|
||||
self.msg_sort()
|
||||
|
||||
def msg_clear(self) -> None:
|
||||
"""
|
||||
Delete all messages.
|
||||
@ -356,7 +375,13 @@ class ChatDB(Chat):
|
||||
source_messages += read_dir(self.cache_path, mfilter=mfilter)
|
||||
if source in ['db', 'disk', 'all']:
|
||||
source_messages += read_dir(self.db_path, mfilter=mfilter)
|
||||
return source_messages
|
||||
# remove_duplicates and sort the list
|
||||
unique_messages: list[Message] = []
|
||||
for m in source_messages:
|
||||
if not message_in(m, unique_messages):
|
||||
unique_messages.append(m)
|
||||
unique_messages.sort(key=lambda m: m.msg_id())
|
||||
return unique_messages
|
||||
|
||||
def msg_find(self,
|
||||
msg_names: list[str],
|
||||
@ -430,6 +455,7 @@ class ChatDB(Chat):
|
||||
Write messages to the cache directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point to
|
||||
the cache directory.
|
||||
Does NOT add the messages to the internal list (use 'cache_add()' for that)!
|
||||
"""
|
||||
write_dir(self.cache_path,
|
||||
messages if messages else self.messages,
|
||||
@ -480,6 +506,7 @@ class ChatDB(Chat):
|
||||
Write messages to the DB directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point
|
||||
to the DB directory.
|
||||
Does NOT add the messages to the internal list (use 'db_add()' for that)!
|
||||
"""
|
||||
write_dir(self.db_path,
|
||||
messages if messages else self.messages,
|
||||
|
||||
@ -20,6 +20,29 @@ class TestChat(unittest.TestCase):
|
||||
Answer('Answer 2'),
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
self.maxDiff = None
|
||||
|
||||
def test_unique_id(self) -> None:
|
||||
# test with two identical messages
|
||||
self.chat.msg_add([self.message1, self.message1])
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
|
||||
self.chat.msg_unique_id()
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1])
|
||||
# test with two different messages
|
||||
self.chat.msg_add([self.message2])
|
||||
self.chat.msg_unique_id()
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
|
||||
|
||||
def test_unique_content(self) -> None:
|
||||
# test with two identical messages
|
||||
self.chat.msg_add([self.message1, self.message1])
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message1])
|
||||
self.chat.msg_unique_content()
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1])
|
||||
# test with two different messages
|
||||
self.chat.msg_add([self.message2])
|
||||
self.chat.msg_unique_content()
|
||||
self.assertSequenceEqual(self.chat.messages, [self.message1, self.message2])
|
||||
|
||||
def test_filter(self) -> None:
|
||||
self.chat.msg_add([self.message1, self.message2])
|
||||
@ -161,6 +184,7 @@ class TestChatDB(unittest.TestCase):
|
||||
for file in self.trash_files:
|
||||
with open(pathlib.Path(self.db_path.name) / file, 'w') as f:
|
||||
f.write('test trash')
|
||||
self.maxDiff = None
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
|
||||
"""
|
||||
@ -174,7 +198,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.cache_path.cleanup()
|
||||
pass
|
||||
|
||||
def test_chat_db_from_dir(self) -> None:
|
||||
def test_from_dir(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
@ -190,7 +214,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[3].file_path,
|
||||
pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_chat_db_from_dir_glob(self) -> None:
|
||||
def test_from_dir_glob(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
glob='*.txt')
|
||||
@ -202,7 +226,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
|
||||
def test_chat_db_from_dir_filter_tags(self) -> None:
|
||||
def test_from_dir_filter_tags(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or={Tag('tag1')}))
|
||||
@ -212,7 +236,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
|
||||
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
|
||||
def test_from_dir_filter_tags_empty(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or=set(),
|
||||
@ -220,7 +244,7 @@ class TestChatDB(unittest.TestCase):
|
||||
tags_not=set()))
|
||||
self.assertEqual(len(chat_db.messages), 0)
|
||||
|
||||
def test_chat_db_from_dir_filter_answer(self) -> None:
|
||||
def test_from_dir_filter_answer(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(answer_contains='Answer 2'))
|
||||
@ -231,7 +255,7 @@ class TestChatDB(unittest.TestCase):
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||
|
||||
def test_chat_db_from_messages(self) -> None:
|
||||
def test_from_messages(self) -> None:
|
||||
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
messages=[self.message1, self.message2,
|
||||
@ -240,7 +264,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
|
||||
def test_chat_db_fids(self) -> None:
|
||||
def test_fids(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.get_next_fid(), 5)
|
||||
@ -249,7 +273,7 @@ class TestChatDB(unittest.TestCase):
|
||||
with open(chat_db.next_path, 'r') as f:
|
||||
self.assertEqual(f.read(), '7')
|
||||
|
||||
def test_chat_db_write(self) -> None:
|
||||
def test_db_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -297,7 +321,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_chat_db_read(self) -> None:
|
||||
def test_db_read(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -360,7 +384,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
|
||||
|
||||
def test_chat_db_clear(self) -> None:
|
||||
def test_cache_clear(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -405,7 +429,7 @@ class TestChatDB(unittest.TestCase):
|
||||
# but not the message with the cache dir path
|
||||
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
|
||||
|
||||
def test_chat_db_add(self) -> None:
|
||||
def test_add(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -436,7 +460,7 @@ class TestChatDB(unittest.TestCase):
|
||||
with self.assertRaises(ChatError):
|
||||
chat_db.cache_add([Message(Question("?"), file_path=pathlib.Path("foo"))])
|
||||
|
||||
def test_chat_db_write_messages(self) -> None:
|
||||
def test_msg_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -459,7 +483,7 @@ class TestChatDB(unittest.TestCase):
|
||||
self.assertEqual(len(cache_dir_files), 1)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
|
||||
|
||||
def test_chat_db_update_messages(self) -> None:
|
||||
def test_msg_update(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
@ -487,7 +511,28 @@ class TestChatDB(unittest.TestCase):
|
||||
with self.assertRaises(ChatError):
|
||||
chat_db.msg_update([message1])
|
||||
|
||||
def test_chat_db_latest_message(self) -> None:
|
||||
def test_msg_find(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# search for a DB file in memory
|
||||
self.assertEqual(chat_db.msg_find([str(self.message1.file_path)], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001.txt'], source='mem'), [self.message1])
|
||||
self.assertEqual(chat_db.msg_find(['0001'], source='mem'), [self.message1])
|
||||
# and on disk
|
||||
self.assertEqual(chat_db.msg_find([str(self.message2.file_path)], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002.yaml'], source='db'), [self.message2])
|
||||
self.assertEqual(chat_db.msg_find(['0002'], source='db'), [self.message2])
|
||||
# now search the cache -> expect empty result
|
||||
self.assertEqual(chat_db.msg_find([str(self.message3.file_path)], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003.txt'], source='cache'), [])
|
||||
self.assertEqual(chat_db.msg_find(['0003'], source='cache'), [])
|
||||
# search for multiple messages
|
||||
search_names = ['0001', '0002.yaml', str(self.message3.file_path)]
|
||||
expected_result = [self.message1, self.message2, self.message3]
|
||||
result = chat_db.msg_find(search_names, source='all')
|
||||
self.assertSequenceEqual(result, expected_result)
|
||||
|
||||
def test_msg_latest(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.msg_latest(source='mem'), self.message4)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user