Compare commits

...

8 Commits

8 changed files with 1018 additions and 132 deletions

1
.gitignore vendored
View File

@ -131,3 +131,4 @@ dmypy.json
.config.yaml .config.yaml
db db
noweb noweb
Session.vim

204
chatmastermind/chat.py Normal file
View File

@ -0,0 +1,204 @@
"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
from pprint import PrettyPrinter
import pathlib
from dataclasses import dataclass, field
from typing import TypeVar, Type, Optional, ClassVar, Any
from .message import Message, MessageFilter, MessageError
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
class ChatError(Exception):
pass
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
@dataclass
class Chat:
"""
A class containing a complete chat history.
"""
messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
try:
# the message may not have an ID if it doesn't have a file_path
self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
except MessageError:
pass
def clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def add_msgs(self, msgs: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += msgs
self.sort()
def print(self, dump: bool = False) -> None:
if dump:
pp(self)
return
# for message in self.messages:
# text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
# if source_code:
# display_source_code(message['content'])
# continue
# if message['role'] == 'user':
# print('-' * terminal_width())
# if text_too_long:
# print(f"{message['role'].upper()}:")
# print(message['content'])
# else:
# print(f"{message['role'].upper()}: {message['content']}")
@dataclass
class ChatDB(Chat):
"""
A 'Chat' class that is bound to a given directory structure. Supports reading
and writing messages from / to that structure. Such a structure consists of
two directories: a 'cache directory', where all messages are temporarily
stored, and a 'DB' directory, where selected messages can be stored
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: pathlib.Path
db_path: pathlib.Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
# set containing all file names of the current messages
message_files: set[str] = field(default_factory=set, repr=False)
@classmethod
def from_dir(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob' fs specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
message_files: set[str] = set()
file_iter = db_path.glob(glob) if glob else db_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
message_files.add(file_path.name)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob, message_files)
@classmethod
def from_messages(cls: Type[ChatDBInst],
cache_path: pathlib.Path,
db_path: pathlib.Path,
messages: list[Message],
mfilter: Optional[MessageFilter]) -> ChatDBInst:
"""
Create a ChatDB instance from the given message list. Note that the next
call to 'dump()' will write all files in order to synchronize the messages.
Similarly, 'update()' will read all messages, so you may end up with a lot
of duplicates when using 'update()' first.
"""
return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int:
next_fname = self.db_path / '.next'
try:
with open(next_fname, 'r') as f:
return int(f.read()) + 1
except Exception:
return 1
def set_next_fid(self, fid: int) -> None:
next_fname = self.db_path / '.next'
with open(next_fname, 'w') as f:
f.write(f'{fid}')
def dump(self, to_db: bool = False, force_all: bool = False) -> None:
"""
Write all messages to 'cache_path' (or 'db_path' if 'to_db' is True). If a message
has no file_path, a new one will be created. By default, only messages that have
not been written (or read) before will be dumped. Use 'force_all' to force writing
all message files.
"""
for message in self.messages:
# skip messages that we have already written (or read)
if message.file_path and message.file_path in self.message_files and not force_all:
continue
file_path = message.file_path
if not file_path:
fid = self.get_next_fid()
fname = f"{fid:04d}{self.file_suffix}"
file_path = self.db_path / fname if to_db else self.cache_path / fname
self.set_next_fid(fid)
message.to_file(file_path)
def update(self, from_cache: bool = False, force_all: bool = False) -> None:
"""
Read new messages from 'db_path' (or 'cache_path' if 'from_cache' is true).
By default, only messages that have not been read (or written) before will
be read. Use 'force_all' to force reading all messages.
"""
if from_cache:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
else:
file_iter = self.cache_path.glob(self.glob) if self.glob else self.cache_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
if file_path.name in self.message_files and not force_all:
continue
try:
message = Message.from_file(file_path, self.mfilter)
if message:
self.messages.append(message)
self.message_files.add(file_path.name)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
self.sort()

View File

@ -63,4 +63,7 @@ class Config():
def to_file(self, path: str) -> None: def to_file(self, path: str) -> None:
with open(path, 'w') as f: with open(path, 'w') as f:
yaml.dump(asdict(self), f) yaml.dump(asdict(self), f, sort_keys=False)
def as_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@ -219,9 +219,12 @@ class Message():
file_path=data.get(cls.file_yaml_key, None)) file_path=data.get(cls.file_yaml_key, None))
@classmethod @classmethod
def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: def tags_from_file(cls: Type[MessageInst],
file_path: pathlib.Path,
prefix: Optional[str] = None) -> set[Tag]:
""" """
Return only the tags from the given Message file. Return only the tags from the given Message file,
optionally filtered based on prefix.
""" """
if not file_path.exists(): if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist") raise MessageError(f"Message file '{file_path}' does not exist")
@ -229,11 +232,33 @@ class Message():
raise MessageError(f"File type '{file_path.suffix}' is not supported") raise MessageError(f"File type '{file_path.suffix}' is not supported")
if file_path.suffix == '.txt': if file_path.suffix == '.txt':
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
tags = TagLine(fd.readline()).tags() tags = TagLine(fd.readline()).tags(prefix)
else: # '.yaml' else: # '.yaml'
with open(file_path, "r") as fd: with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader) data = yaml.load(fd, Loader=yaml.FullLoader)
tags = set(sorted(data[cls.tags_yaml_key])) if prefix and len(prefix) > 0:
tags = set(sorted([t.strip() for t in data[cls.tags_yaml_key] if t.startswith(prefix)]))
else:
tags = set(sorted(data[cls.tags_yaml_key]))
return tags
@classmethod
def tags_from_dir(cls: Type[MessageInst],
path: pathlib.Path,
glob: Optional[str] = None,
prefix: Optional[str] = None) -> set[Tag]:
"""
Return only the tags from message files in the given directory.
The files can be filtered using 'glob', the tags by using 'prefix'.
"""
tags: set[Tag] = set()
file_iter = path.glob(glob) if glob else path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
tags |= cls.tags_from_file(file_path, prefix)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return tags return tags
@classmethod @classmethod

View File

@ -1,7 +1,7 @@
""" """
Module implementing tag related functions and classes. Module implementing tag related functions and classes.
""" """
from typing import Type, TypeVar, Optional from typing import Type, TypeVar, Optional, Final
TagInst = TypeVar('TagInst', bound='Tag') TagInst = TypeVar('TagInst', bound='Tag')
TagLineInst = TypeVar('TagLineInst', bound='TagLine') TagLineInst = TypeVar('TagLineInst', bound='TagLine')
@ -16,9 +16,9 @@ class Tag(str):
A single tag. A string that can contain anything but the default separator (' '). A single tag. A string that can contain anything but the default separator (' ').
""" """
# default separator # default separator
default_separator = ' ' default_separator: Final[str] = ' '
# alternative separators (e. g. for backwards compatibility) # alternative separators (e. g. for backwards compatibility)
alternative_separators = [','] alternative_separators: Final[list[str]] = [',']
def __new__(cls: Type[TagInst], string: str) -> TagInst: def __new__(cls: Type[TagInst], string: str) -> TagInst:
""" """
@ -93,19 +93,21 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s
class TagLine(str): class TagLine(str):
""" """
A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by
separated by the defaut separator (' '). Any operations on a TagLine will sort a list of tags, separated by the defaut separator (' '). Any operations on a
the tags. TagLine will sort the tags.
""" """
# the prefix # the prefix
prefix = 'TAGS:' prefix: Final[str] = 'TAGS:'
def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst:
""" """
Make sure the tagline string starts with the prefix. Make sure the tagline string starts with the prefix. Also replace newlines
and multiple spaces with ' ', in order to support multiline TagLines.
""" """
if not string.startswith(cls.prefix): if not string.startswith(cls.prefix):
raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'")
string = ' '.join(string.split())
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@ -114,11 +116,12 @@ class TagLine(str):
""" """
Create a new TagLine from a set of tags. Create a new TagLine from a set of tags.
""" """
return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) return cls(' '.join([cls.prefix] + sorted([t for t in tags])))
def tags(self) -> set[Tag]: def tags(self, prefix: Optional[str] = None) -> set[Tag]:
""" """
Returns all tags contained in this line as a set. Returns all tags contained in this line as a set, optionally
filtered based on prefix.
""" """
tagstr = self[len(self.prefix):].strip() tagstr = self[len(self.prefix):].strip()
separator = Tag.default_separator separator = Tag.default_separator
@ -128,7 +131,10 @@ class TagLine(str):
if s in tagstr: if s in tagstr:
separator = s separator = s
break break
return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) if prefix and len(prefix) > 0:
return set(sorted([Tag(t.strip()) for t in tagstr.split(separator) if t.startswith(prefix)]))
else:
return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)]))
def merge(self, taglines: set['TagLine']) -> 'TagLine': def merge(self, taglines: set['TagLine']) -> 'TagLine':
""" """

View File

@ -7,7 +7,6 @@ from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from chatmastermind.tags import Tag, TagLine, TagError
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock, ANY
@ -232,116 +231,3 @@ class TestCreateParser(CmmTestCase):
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))
class TestTag(CmmTestCase):
def test_valid_tag(self) -> None:
tag = Tag('mytag')
self.assertEqual(tag, 'mytag')
def test_invalid_tag(self) -> None:
with self.assertRaises(TagError):
Tag('tag with space')
def test_default_separator(self) -> None:
self.assertEqual(Tag.default_separator, ' ')
def test_alternative_separators(self) -> None:
self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(CmmTestCase):
def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_invalid_tagline(self) -> None:
with self.assertRaises(TagError):
TagLine('tag1 tag2')
def test_prefix(self) -> None:
self.assertEqual(TagLine.prefix, 'TAGS:')
def test_from_set(self) -> None:
tags = {Tag('tag1'), Tag('tag2')}
tagline = TagLine.from_set(tags)
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_merge(self) -> None:
tagline1 = TagLine('TAGS: tag1 tag2')
tagline2 = TagLine('TAGS: tag2 tag3')
merged_tagline = tagline1.merge({tagline2})
self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3')
def test_delete_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag2')
def test_add_tags(self) -> None:
tagline = TagLine('TAGS: tag1')
new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3')
def test_rename_tags(self) -> None:
tagline = TagLine('TAGS: old1 old2')
new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))})
self.assertEqual(new_tagline, 'TAGS: new1 new2')
def test_match_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
# Test case 1: Match any tag in 'tags_or'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and: set[Tag] = set()
tags_not: set[Tag] = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 2: Match all tags in 'tags_and'
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = {Tag('tag5')}
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 5: No matching tags in 'tags_or'
tags_or = {Tag('tag4'), Tag('tag5')}
tags_and = set()
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 6: Not all tags in 'tags_and' are present
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')}
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 7: Some tags in 'tags_not' are present
tags_or = {Tag('tag1')}
tags_and = set()
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 8: 'tags_or' and 'tags_and' are None, match all tags
tags_not = set()
self.assertTrue(tagline.match_tags(None, None, tags_not))
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not))

630
tests/test_message.py Normal file
View File

@ -0,0 +1,630 @@
import pathlib
import tempfile
from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter
from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(CmmTestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" ```python\n print(\"Hello, World!\")\n ```\n",
" ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_without_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" print(\"Hello, World!\")\n",
" x = 10\n y = 20\n print(x + y)\n"
]
result = source_code(text, include_delims=False)
self.assertEqual(result, expected_result)
def test_source_code_with_single_code_block(self) -> None:
text = "```python\nprint(\"Hello, World!\")\n```"
expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_with_no_code_blocks(self) -> None:
text = "Some text without any code blocks"
expected_result: list[str] = []
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase):
def test_question_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Question("=== QUESTION === What is your name?")
def test_question_without_prefix(self) -> None:
question = Question("What is your favorite color?")
self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase):
def test_answer_with_prefix(self) -> None:
with self.assertRaises(MessageError):
Answer("=== ANSWER === Yes")
def test_answer_without_prefix(self) -> None:
answer = Answer("No")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_min = Message(Question('This is a question.'),
file_path=self.file_path)
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
"""
self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.txt_header}
This is a question.
"""
self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_type(self) -> None:
unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
def test_to_file_no_file_path(self) -> None:
"""
Provoke an exception using an empty path.
"""
with self.assertRaises(MessageError) as cm:
# clear the internal file_path
self.message_complete.file_path = None
self.message_complete.to_file(None)
self.assertEqual(str(cm.exception), "Got no valid path to write message")
# reset the internal file_path
self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_multiline = Message(Question('This is a\nmultiline question.'),
Answer('This is a\nmultiline answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_min = Message(Question('This is a question.'),
file_path=self.file_path)
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.yaml_key}: This is a question.
{Answer.yaml_key}: This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.yaml_key}: |-
This is a
multiline question.
{Answer.yaml_key}: |-
This is a
multiline answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header}
This is a question.
""")
def tearDown(self) -> None:
self.file.close()
self.file_min.close()
self.file_path.unlink()
self.file_path_min.unlink()
def test_from_file_txt_complete(self) -> None:
"""
Read a complete message (with all optional values).
"""
message = Message.from_file(self.file_path)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_min(self) -> None:
"""
Read a message with only required values.
"""
message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_txt_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message)
def test_from_file_txt_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.txt")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_txt_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_txt_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_txt_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class MessageFromFileYamlTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
""")
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
self.file_min.close()
self.file_path_min.unlink()
def test_from_file_yaml_complete(self) -> None:
"""
Read a complete message (with all optional values).
"""
message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_min(self) -> None:
"""
Read a message with only the required values.
"""
message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.yaml")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_yaml_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message)
def test_from_file_yaml_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message)
def test_from_file_yaml_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_yaml_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_yaml_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_yaml_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class TagsFromFileTestCase(CmmTestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
{Message.tags_yaml_key}:
- tag1
- tag2
- ptag3
""")
def tearDown(self) -> None:
self.file_txt.close()
self.file_path_txt.unlink()
self.file_yaml.close()
self.file_path_yaml.unlink()
def test_tags_from_file_txt(self) -> None:
tags = Message.tags_from_file(self.file_path_txt)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_yaml(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_txt_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_txt, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
def test_tags_from_file_yaml_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
class MessageIDTestCase(CmmTestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message_no_file_path = Message(Question('This is a question.'))
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.name)
def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError):
self.message_no_file_path.msg_id()
class MessageHashTestCase(CmmTestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')},
file_path=pathlib.Path('/tmp/foo/bla'))
self.message2 = Message(Question('This is a new question.'),
file_path=pathlib.Path('/tmp/foo/bla'))
self.message3 = Message(Question('This is a question.'),
Answer('This is an answer.'),
file_path=pathlib.Path('/tmp/foo/bla'))
# message4 is a copy of message1, because only question and
# answer are used for hashing and comparison
self.message4 = Message(Question('This is a question.'),
tags={Tag('tag1'), Tag('tag2')},
ai='Blabla',
file_path=pathlib.Path('foobla'))
def test_set_hashing(self) -> None:
msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4}
self.assertEqual(len(msgs), 3)
for msg in [self.message1, self.message2, self.message3]:
self.assertIn(msg, msgs)

131
tests/test_tags.py Normal file
View File

@ -0,0 +1,131 @@
from .test_main import CmmTestCase
from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(CmmTestCase):
def test_valid_tag(self) -> None:
tag = Tag('mytag')
self.assertEqual(tag, 'mytag')
def test_invalid_tag(self) -> None:
with self.assertRaises(TagError):
Tag('tag with space')
def test_default_separator(self) -> None:
self.assertEqual(Tag.default_separator, ' ')
def test_alternative_separators(self) -> None:
self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(CmmTestCase):
def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_valid_tagline_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_invalid_tagline(self) -> None:
with self.assertRaises(TagError):
TagLine('tag1 tag2')
def test_prefix(self) -> None:
self.assertEqual(TagLine.prefix, 'TAGS:')
def test_from_set(self) -> None:
tags = {Tag('tag1'), Tag('tag2')}
tagline = TagLine.from_set(tags)
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_tags_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_tags_prefix(self) -> None:
tagline = TagLine('TAGS: atag1 stag2 stag3')
tags = tagline.tags(prefix='a')
self.assertEqual(tags, {Tag('atag1')})
tags = tagline.tags(prefix='s')
self.assertEqual(tags, {Tag('stag2'), Tag('stag3')})
def test_merge(self) -> None:
tagline1 = TagLine('TAGS: tag1 tag2')
tagline2 = TagLine('TAGS: tag2 tag3')
merged_tagline = tagline1.merge({tagline2})
self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3')
def test_delete_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag2')
def test_add_tags(self) -> None:
tagline = TagLine('TAGS: tag1')
new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3')
def test_rename_tags(self) -> None:
tagline = TagLine('TAGS: old1 old2')
new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))})
self.assertEqual(new_tagline, 'TAGS: new1 new2')
def test_match_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
# Test case 1: Match any tag in 'tags_or'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and: set[Tag] = set()
tags_not: set[Tag] = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 2: Match all tags in 'tags_and'
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = {Tag('tag5')}
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 5: No matching tags in 'tags_or'
tags_or = {Tag('tag4'), Tag('tag5')}
tags_and = set()
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 6: Not all tags in 'tags_and' are present
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')}
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 7: Some tags in 'tags_not' are present
tags_or = {Tag('tag1')}
tags_and = set()
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 8: 'tags_or' and 'tags_and' are None, match all tags
tags_not = set()
self.assertTrue(tagline.match_tags(None, None, tags_not))
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not))