From b13a68836a4ed49f3777ad4c8cf7038a776bcb3e Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 17:07:01 +0200 Subject: [PATCH 001/121] added new module 'tags.py' with classes 'Tag' and 'TagLine' --- chatmastermind/tags.py | 130 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 chatmastermind/tags.py diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py new file mode 100644 index 0000000..28583a2 --- /dev/null +++ b/chatmastermind/tags.py @@ -0,0 +1,130 @@ +""" +Module implementing tag related functions and classes. +""" +from typing import Type, TypeVar, Optional + +TagInst = TypeVar('TagInst', bound='Tag') +TagLineInst = TypeVar('TagLineInst', bound='TagLine') + + +class TagError(Exception): + pass + + +class Tag(str): + """ + A single tag. A string that can contain anything but the default separator (' '). + """ + # default separator + default_separator = ' ' + # alternative separators (e. g. for backwards compatibility) + alternative_separators = [','] + + def __new__(cls: Type[TagInst], string: str) -> TagInst: + """ + Make sure the tag string does not contain the default separator. + """ + if cls.default_separator in string: + raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'") + instance = super().__new__(cls, string) + return instance + + +class TagLine(str): + """ + A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, + separated by the defaut separator (' '). Any operations on a TagLine will sort + the tags. + """ + # the prefix + prefix = 'TAGS:' + + def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: + """ + Make sure the tagline string starts with the prefix. 
+ """ + if not string.startswith(cls.prefix): + raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst: + """ + Create a new TagLine from a set of tags. + """ + return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + + def tags(self) -> set[Tag]: + """ + Returns all tags contained in this line as a set. + """ + tagstr = self[len(self.prefix):].strip() + separator = Tag.default_separator + # look for alternative separators and use the first one found + # -> we don't support different separators in the same TagLine + for s in Tag.alternative_separators: + if s in tagstr: + separator = s + break + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + + def merge(self, taglines: set['TagLine']) -> 'TagLine': + """ + Merges the tags of all given taglines into the current one + and returns a new TagLine. + """ + merged_tags = self.tags() + for tl in taglines: + merged_tags |= tl.tags() + return self.from_set(set(sorted(merged_tags))) + + def delete_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Deletes the given tags and returns a new TagLine. + """ + return self.from_set(self.tags().difference(tags)) + + def add_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Adds the given tags and returns a new TagLine. + """ + return self.from_set(set(sorted(self.tags() | tags))) + + def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + """ + Renames the given tags and returns a new TagLine. The first + tuple element is the old name, the second one is the new name. 
+ """ + new_tags = self.tags() + for t in tags: + if t[0] in new_tags: + new_tags.remove(t[0]) + new_tags.add(t[1]) + return self.from_set(set(sorted(new_tags))) + + def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the current TagLine matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + tag_set = self.tags() + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing -- 2.36.6 From ef46f5efc942551b0ccbd37b9807eb983bcdb628 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 17 Aug 2023 08:28:15 +0200 Subject: [PATCH 002/121] added testcases for Tag and TagLine classes --- tests/test_main.py | 114 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..eb13dc5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai 
from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data +from chatmastermind.tags import Tag, TagLine, TagError from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -231,3 +232,116 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + 
self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = 
set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 604e5ccf73e2d3aafc45a48128317f5462bd5348 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 12:11:56 +0200 Subject: [PATCH 003/121] tags.py: converted most TagLine functions to module functions --- chatmastermind/tags.py | 99 ++++++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 28 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 28583a2..bfe5fd5 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -30,6 +30,67 @@ class Tag(str): return instance +def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]: + """ + Deletes the given tags and returns a new set. + """ + return tags.difference(tags_delete) + + +def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]: + """ + Adds the given tags and returns a new set. + """ + return set(sorted(tags | tags_add)) + + +def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]: + """ + Merges the tags in 'tags_merge' into the current one and returns a new set. + """ + for ts in tags_merge: + tags |= ts + return tags + + +def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]: + """ + Renames the given tags and returns a new set. The first tuple element + is the old name, the second one is the new name. 
+ """ + for t in tags_rename: + if t[0] in tags: + tags.remove(t[0]) + tags.add(t[1]) + return set(sorted(tags)) + + +def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the given set 'tags' matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tags for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing + + class TagLine(str): """ A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, @@ -71,37 +132,29 @@ class TagLine(str): def merge(self, taglines: set['TagLine']) -> 'TagLine': """ - Merges the tags of all given taglines into the current one - and returns a new TagLine. + Merges the tags of all given taglines into the current one and returns a new TagLine. 
""" - merged_tags = self.tags() - for tl in taglines: - merged_tags |= tl.tags() - return self.from_set(set(sorted(merged_tags))) + tags_merge = [tl.tags() for tl in taglines] + return self.from_set(merge_tags(self.tags(), tags_merge)) - def delete_tags(self, tags: set[Tag]) -> 'TagLine': + def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine': """ Deletes the given tags and returns a new TagLine. """ - return self.from_set(self.tags().difference(tags)) + return self.from_set(delete_tags(self.tags(), tags_delete)) - def add_tags(self, tags: set[Tag]) -> 'TagLine': + def add_tags(self, tags_add: set[Tag]) -> 'TagLine': """ Adds the given tags and returns a new TagLine. """ - return self.from_set(set(sorted(self.tags() | tags))) + return self.from_set(add_tags(self.tags(), tags_add)) - def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine': """ Renames the given tags and returns a new TagLine. The first tuple element is the old name, the second one is the new name. """ - new_tags = self.tags() - for t in tags: - if t[0] in new_tags: - new_tags.remove(t[0]) - new_tags.add(t[1]) - return self.from_set(set(sorted(new_tags))) + return self.from_set(rename_tags(self.tags(), tags_rename)) def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], tags_not: Optional[set[Tag]]) -> bool: @@ -117,14 +170,4 @@ class TagLine(str): 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag exclusion is still done if 'tags_not' is not 'None'). 
""" - tag_set = self.tags() - required_tags_present = False - excluded_tags_missing = False - if ((tags_or is None and tags_and is None) - or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 - or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 - required_tags_present = True - if ((tags_not is None) - or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 - excluded_tags_missing = True - return required_tags_present and excluded_tags_missing + return match_tags(self.tags(), tags_or, tags_and, tags_not) -- 2.36.6 From 173a46a9b52b1ff1f30d5f9acb27538daaa9379a Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:07:50 +0200 Subject: [PATCH 004/121] added new module 'message.py' --- chatmastermind/message.py | 430 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 chatmastermind/message.py diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..157cd46 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,430 @@ +""" +Module implementing message related functions and classes. +""" +import pathlib +import yaml +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from dataclasses import dataclass, asdict, field +from .tags import Tag, TagLine, TagError, match_tags + +QuestionInst = TypeVar('QuestionInst', bound='Question') +AnswerInst = TypeVar('AnswerInst', bound='Answer') +MessageInst = TypeVar('MessageInst', bound='Message') +AILineInst = TypeVar('AILineInst', bound='AILine') +ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') +YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] + + +class MessageError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. 
+ """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + +def source_code(text: str, include_delims: bool = False) -> list[str]: + """ + Extract all source code sections from the given text, i. e. all lines + surrounded by lines tarting with '```'. If 'include_delims' is True, + the surrounding lines are included, otherwise they are omitted. The + result list contains every source code section as a single string. + The order in the list represents the order of the sections in the text. + """ + code_sections: list[str] = [] + code_lines: list[str] = [] + in_code_block = False + + for line in text.split('\n'): + if line.strip().startswith('```'): + if include_delims: + code_lines.append(line) + if in_code_block: + code_sections.append('\n'.join(code_lines) + '\n') + code_lines.clear() + in_code_block = not in_code_block + elif in_code_block: + code_lines.append(line) + + return code_sections + + +@dataclass(kw_only=True) +class MessageFilter: + """ + Various filters for a Message. + """ + tags_or: Optional[set[Tag]] = None + tags_and: Optional[set[Tag]] = None + tags_not: Optional[set[Tag]] = None + ai: Optional[str] = None + model: Optional[str] = None + question_contains: Optional[str] = None + answer_contains: Optional[str] = None + answer_state: Optional[Literal['available', 'missing']] = None + ai_state: Optional[Literal['available', 'missing']] = None + model_state: Optional[Literal['available', 'missing']] = None + + +class AILine(str): + """ + A line that represents the AI name in a '.txt' file.. 
+ """ + prefix: Final[str] = 'AI:' + + def __new__(cls: Type[AILineInst], string: str) -> AILineInst: + if not string.startswith(cls.prefix): + raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def ai(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst: + return cls(' '.join([cls.prefix, ai])) + + +class ModelLine(str): + """ + A line that represents the model name in a '.txt' file.. + """ + prefix: Final[str] = 'MODEL:' + + def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: + if not string.startswith(cls.prefix): + raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def model(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst: + return cls(' '.join([cls.prefix, model])) + + +class Question(str): + """ + A single question with a defined header. + """ + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' + + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + """ + Make sure the question string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. 
+ """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +class Answer(str): + """ + A single answer with a defined header. + """ + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' + + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + """ + Make sure the answer string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +@dataclass +class Message(): + """ + Single message. Consists of a question and optionally an answer, a set of tags + and a file path. 
+ """ + question: Question + answer: Optional[Answer] = None + # metadata, ignored when comparing messages + tags: Optional[set[Tag]] = field(default=None, compare=False) + ai: Optional[str] = field(default=None, compare=False) + model: Optional[str] = field(default=None, compare=False) + file_path: Optional[pathlib.Path] = field(default=None, compare=False) + # class variables + file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] + tags_yaml_key: ClassVar[str] = 'tags' + file_yaml_key: ClassVar[str] = 'file_path' + ai_yaml_key: ClassVar[str] = 'ai' + model_yaml_key: ClassVar[str] = 'model' + + def __hash__(self) -> int: + """ + The hash value is computed based on immutable members. + """ + return hash((self.question, self.answer)) + + @classmethod + def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: + """ + Create a Message from the given dict. + """ + return cls(question=data[Question.yaml_key], + answer=data.get(Answer.yaml_key, None), + tags=set(data.get(cls.tags_yaml_key, [])), + ai=data.get(cls.ai_yaml_key, None), + model=data.get(cls.model_yaml_key, None), + file_path=data.get(cls.file_yaml_key, None)) + + @classmethod + def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + """ + Return only the tags from the given Message file. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + if file_path.suffix == '.txt': + with open(file_path, "r") as fd: + tags = TagLine(fd.readline()).tags() + else: # '.yaml' + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + tags = set(sorted(data[cls.tags_yaml_key])) + return tags + + @classmethod + def from_file(cls: Type[MessageInst], file_path: pathlib.Path, + mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: + """ + Create a Message from the given file. 
Returns 'None' if the message does + not fulfill the filter requirements. For TXT files, the tags are matched + before building the whole message. The other filters are applied afterwards. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + + if file_path.suffix == '.txt': + message = cls.__from_file_txt(file_path, + mfilter.tags_or if mfilter else None, + mfilter.tags_and if mfilter else None, + mfilter.tags_not if mfilter else None) + else: + message = cls.__from_file_yaml(file_path) + if message and (not mfilter or (mfilter and message.match(mfilter))): + return message + else: + return None + + @classmethod + def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 + tags_or: Optional[set[Tag]] = None, + tags_and: Optional[set[Tag]] = None, + tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: + """ + Create a Message from the given TXT file. Expects the following file structures: + For '.txt': + * TagLine [Optional] + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header [Optional] + * Answer [Optional] + + Returns 'None' if the message does not fulfill the tag requirements. 
+ """ + tags: set[Tag] = set() + question: Question + answer: Optional[Answer] = None + ai: Optional[str] = None + model: Optional[str] = None + with open(file_path, "r") as fd: + # TagLine (Optional) + try: + pos = fd.tell() + tags = TagLine(fd.readline()).tags() + except TagError: + fd.seek(pos) + if tags_or or tags_and or tags_not: + # match with an empty set if the file has no tags + if not match_tags(tags, tags_or, tags_and, tags_not): + return None + # AILine (Optional) + try: + pos = fd.tell() + ai = AILine(fd.readline()).ai() + except TagError: + fd.seek(pos) + # ModelLine (Optional) + try: + pos = fd.tell() + model = ModelLine(fd.readline()).model() + except TagError: + fd.seek(pos) + # Question and Answer + text = fd.read().strip().split('\n') + question_idx = text.index(Question.txt_header) + 1 + try: + answer_idx = text.index(Answer.txt_header) + question = Question.from_list(text[question_idx:answer_idx]) + answer = Answer.from_list(text[answer_idx + 1:]) + except ValueError: + question = Question.from_list(text[question_idx:]) + return cls(question, answer, tags, ai, model, file_path) + + @classmethod + def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: + """ + Create a Message from the given YAML file. Expects the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string [Optional] + * Message.tags_yaml_key: list of strings [Optional] + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + data[cls.file_yaml_key] = file_path + return cls.from_dict(data) + + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + """ + Write a Message to the given file. Type is determined based on the suffix. 
+ Currently supported suffixes: ['.txt', '.yaml'] + """ + if file_path: + self.file_path = file_path + if not self.file_path: + raise MessageError("Got no valid path to write message") + if self.file_path.suffix not in self.file_suffixes: + raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + # TXT + if self.file_path.suffix == '.txt': + return self.__to_file_txt(self.file_path) + elif self.file_path.suffix == '.yaml': + return self.__to_file_yaml(self.file_path) + + def __to_file_txt(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in TXT format. + Creates the following file structures: + * TagLine + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header + * Answer + """ + with open(file_path, "w") as fd: + if self.tags: + fd.write(f'{TagLine.from_set(self.tags)}\n') + if self.ai: + fd.write(f'{AILine.from_ai(self.ai)}\n') + if self.model: + fd.write(f'{ModelLine.from_model(self.model)}\n') + fd.write(f'{Question.txt_header}\n{self.question}\n') + if self.answer: + fd.write(f'{Answer.txt_header}\n{self.answer}\n') + + def __to_file_yaml(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in YAML format. 
+ Creates the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string + * Message.tags_yaml_key: list of strings + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "w") as fd: + data: YamlDict = {Question.yaml_key: str(self.question)} + if self.answer: + data[Answer.yaml_key] = str(self.answer) + if self.ai: + data[self.ai_yaml_key] = self.ai + if self.model: + data[self.model_yaml_key] = self.model + if self.tags: + data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) + yaml.dump(data, fd, sort_keys=False) + + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 + """ + Matches the current Message to the given filter atttributes. + Return True if all attributes match, else False. + """ + mytags = self.tags or set() + if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 + or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 + or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 + or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 + or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 + or (mfilter.model_state == 'available' and not self.model) # noqa: W503 + or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503 + or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503 + or (mfilter.model_state == 'missing' and self.model)): # noqa: W503 + return False + return True + + def msg_id(self) -> str: + """ + Returns an ID that is unique throughout all messages in the same (DB) directory. 
+ Currently this is the file name. The ID is also used for sorting messages. + """ + if self.file_path: + return self.file_path.name + else: + raise MessageError("Can't create file ID without a file path") + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From dfc12619319626757c6e776431a9581b32e4d984 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:08:22 +0200 Subject: [PATCH 005/121] added testcases for messages.py --- tests/test_main.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..8ce06cb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,79 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after 
the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") -- 2.36.6 From 879831d7f50f6dd39a3571933453c0e8406ab3f9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:04:41 +0200 Subject: [PATCH 006/121] configuration: added 'as_dict()' as an instance function --- chatmastermind/configuration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0037916..5ae32d6 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -63,4 +63,7 @@ class Config(): def 
to_file(self, path: str) -> None: with open(path, 'w') as f: - yaml.dump(asdict(self), f) + yaml.dump(asdict(self), f, sort_keys=False) + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 580c86e948bd5ac1b83209e8dbeafb4ebc6d7385 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:30:24 +0200 Subject: [PATCH 007/121] tags: TagLine constructor now supports multiline taglines and multiple spaces --- chatmastermind/tags.py | 20 +++++++++++--------- tests/test_main.py | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bfe5fd5..544270c 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -1,7 +1,7 @@ """ Module implementing tag related functions and classes. """ -from typing import Type, TypeVar, Optional +from typing import Type, TypeVar, Optional, Final TagInst = TypeVar('TagInst', bound='Tag') TagLineInst = TypeVar('TagLineInst', bound='TagLine') @@ -16,9 +16,9 @@ class Tag(str): A single tag. A string that can contain anything but the default separator (' '). """ # default separator - default_separator = ' ' + default_separator: Final[str] = ' ' # alternative separators (e. g. for backwards compatibility) - alternative_separators = [','] + alternative_separators: Final[list[str]] = [','] def __new__(cls: Type[TagInst], string: str) -> TagInst: """ @@ -93,19 +93,21 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s class TagLine(str): """ - A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, - separated by the defaut separator (' '). Any operations on a TagLine will sort - the tags. + A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by + a list of tags, separated by the defaut separator (' '). Any operations on a + TagLine will sort the tags. 
""" # the prefix - prefix = 'TAGS:' + prefix: Final[str] = 'TAGS:' def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: """ - Make sure the tagline string starts with the prefix. + Make sure the tagline string starts with the prefix. Also replace newlines + and multiple spaces with ' ', in order to support multiline TagLines. """ if not string.startswith(cls.prefix): raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + string = ' '.join(string.split()) instance = super().__new__(cls, string) return instance @@ -114,7 +116,7 @@ class TagLine(str): """ Create a new TagLine from a set of tags. """ - return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) def tags(self) -> set[Tag]: """ diff --git a/tests/test_main.py b/tests/test_main.py index 8ce06cb..25cdc37 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -256,6 +256,10 @@ class TestTagLine(CmmTestCase): tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_invalid_tagline(self) -> None: with self.assertRaises(TagError): TagLine('tag1 tag2') @@ -273,6 +277,11 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 0d6a6dd6043651ef33b4072b8298ca19a5dd507d Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 21 Aug 2023 08:29:48 +0200 Subject: [PATCH 008/121] gitignore: added vim session file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore 
index 4ade1df..89bf5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .config.yaml db -noweb \ No newline at end of file +noweb +Session.vim -- 2.36.6 From aa89270876c622fbf7205133f3af99283c7ef472 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 08:46:03 +0200 Subject: [PATCH 009/121] tests: splitted 'test_main.py' into 3 modules --- tests/test_main.py | 200 ------------------------------------------ tests/test_message.py | 78 ++++++++++++++++ tests/test_tags.py | 124 ++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 200 deletions(-) create mode 100644 tests/test_message.py create mode 100644 tests/test_tags.py diff --git a/tests/test_main.py b/tests/test_main.py index 25cdc37..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,8 +7,6 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from chatmastermind.tags import Tag, TagLine, TagError -from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -233,201 +231,3 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) - - -class TestTag(CmmTestCase): - def test_valid_tag(self) -> None: - tag = Tag('mytag') - self.assertEqual(tag, 'mytag') - - def test_invalid_tag(self) -> None: - with self.assertRaises(TagError): - Tag('tag with space') - - def test_default_separator(self) -> None: - self.assertEqual(Tag.default_separator, ' ') - - def test_alternative_separators(self) -> None: - self.assertEqual(Tag.alternative_separators, [',']) - - -class TestTagLine(CmmTestCase): - def 
test_valid_tagline(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_valid_tagline_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_invalid_tagline(self) -> None: - with self.assertRaises(TagError): - TagLine('tag1 tag2') - - def test_prefix(self) -> None: - self.assertEqual(TagLine.prefix, 'TAGS:') - - def test_from_set(self) -> None: - tags = {Tag('tag1'), Tag('tag2')} - tagline = TagLine.from_set(tags) - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_tags_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_merge(self) -> None: - tagline1 = TagLine('TAGS: tag1 tag2') - tagline2 = TagLine('TAGS: tag2 tag3') - merged_tagline = tagline1.merge({tagline2}) - self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') - - def test_delete_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag2') - - def test_add_tags(self) -> None: - tagline = TagLine('TAGS: tag1') - new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') - - def test_rename_tags(self) -> None: - tagline = TagLine('TAGS: old1 old2') - new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) - self.assertEqual(new_tagline, 'TAGS: new1 new2') - - def test_match_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - - # Test case 1: Match any tag in 'tags_or' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and: set[Tag] = set() - tags_not: set[Tag] = set() - 
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 2: Match all tags in 'tags_and' - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = {Tag('tag5')} - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 5: No matching tags in 'tags_or' - tags_or = {Tag('tag4'), Tag('tag5')} - tags_and = set() - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 6: Not all tags in 'tags_and' are present - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 7: Some tags in 'tags_not' are present - tags_or = {Tag('tag1')} - tags_and = set() - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 8: 'tags_or' and 'tags_and' are None, match all tags - tags_not = set() - self.assertTrue(tagline.match_tags(None, None, tags_not)) - - # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(None, None, tags_not)) - - -class SourceCodeTestCase(CmmTestCase): - def test_source_code_with_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - 
print(x + y) - ``` - """ - expected_result = [ - " ```python\n print(\"Hello, World!\")\n ```\n", - " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" - ] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_without_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " print(\"Hello, World!\")\n", - " x = 10\n y = 20\n print(x + y)\n" - ] - result = source_code(text, include_delims=False) - self.assertEqual(result, expected_result) - - def test_source_code_with_single_code_block(self) -> None: - text = "```python\nprint(\"Hello, World!\")\n```" - expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_with_no_code_blocks(self) -> None: - text = "Some text without any code blocks" - expected_result: list[str] = [] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - -class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") - - def test_question_without_prefix(self) -> None: - question = Question("What is your favorite color?") - self.assertIsInstance(question, Question) - self.assertEqual(question, "What is your favorite color?") - - -class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") - - def test_answer_without_prefix(self) -> None: - answer = Answer("No") - self.assertIsInstance(answer, Answer) - self.assertEqual(answer, "No") diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..220fef2 --- 
/dev/null +++ b/tests/test_message.py @@ -0,0 +1,78 @@ +from .test_main import CmmTestCase +from chatmastermind.message import source_code, MessageError, Question, Answer + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your 
favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") diff --git a/tests/test_tags.py b/tests/test_tags.py new file mode 100644 index 0000000..9ac9746 --- /dev/null +++ b/tests/test_tags.py @@ -0,0 +1,124 @@ +from .test_main import CmmTestCase +from chatmastermind.tags import Tag, TagLine, TagError + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = 
TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + 
# Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From fc1b8006a0298bac1392756275a08f65e6be4db4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 19:59:38 +0200 Subject: [PATCH 010/121] tests: added testcases for Message.from/to_file() and others --- tests/test_message.py | 545 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 544 insertions(+), 1 deletion(-) diff --git a/tests/test_message.py b/tests/test_message.py index 220fef2..0e326b4 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,5 +1,9 @@ +import pathlib +import tempfile +from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, MessageError, Question, Answer +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.tags import Tag, TagLine class SourceCodeTestCase(CmmTestCase): @@ -76,3 +80,542 @@ class AnswerTestCase(CmmTestCase): answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") + + +class MessageToFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = 
Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_txt_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""" + self.assertEqual(content, expected_content) + + def test_to_file_txt_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.txt_header} +This is a question. +""" + self.assertEqual(content, expected_content) + + def test_to_file_unsupported_file_type(self) -> None: + unsupported_file_path = pathlib.Path("example.doc") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_path) + self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + + def test_to_file_no_file_path(self) -> None: + """ + Provoke an exception using an empty path. 
+ """ + with self.assertRaises(MessageError) as cm: + # clear the internal file_path + self.message_complete.file_path = None + self.message_complete.to_file(None) + self.assertEqual(str(cm.exception), "Got no valid path to write message") + # reset the internal file_path + self.message_complete.file_path = self.file_path + + +class MessageToFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_multiline = Message(Question('This is a\nmultiline question.'), + Answer('This is a\nmultiline answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_yaml_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: This is a question. +{Answer.yaml_key}: This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_multiline(self) -> None: + self.message_multiline.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: |- + This is a + multiline question. +{Answer.yaml_key}: |- + This is a + multiline answer. 
+{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"{Question.yaml_key}: This is a question.\n" + self.assertEqual(content, expected_content) + + +class MessageFromFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_min.close() + self.file_path.unlink() + self.file_path_min.unlink() + + def test_from_file_txt_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_min(self) -> None: + """ + Read a message with only required values. 
+ """ + message = Message.from_file(self.file_path_min) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_txt_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.txt") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_txt_question_match(self) -> None: + message = Message.from_file(self.file_path, + 
MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_txt_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + 
MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_txt_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class MessageFromFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + self.file_min.close() + self.file_path_min.unlink() + + def test_from_file_yaml_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_min(self) -> None: + """ + Read a message with only the required values. 
+ """ + message = Message.from_file(self.file_path_min) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.yaml") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_yaml_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_yaml_question_match(self) -> None: + message = 
Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_yaml_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_ai_doesnt_match(self) -> None: + 
message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_yaml_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class TagsFromFileTestCase(CmmTestCase): + def setUp(self) -> None: + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt = pathlib.Path(self.file_txt.name) + with open(self.file_path_txt, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml = pathlib.Path(self.file_yaml.name) + with open(self.file_path_yaml, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. 
+{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + + def tearDown(self) -> None: + self.file_txt.close() + self.file_path_txt.unlink() + self.file_yaml.close() + self.file_path_yaml.unlink() + + def test_tags_from_file_txt(self) -> None: + tags = Message.tags_from_file(self.file_path_txt) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_from_file_yaml(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + +class MessageIDTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message_no_file_path = Message(Question('This is a question.')) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_msg_id_txt(self) -> None: + self.assertEqual(self.message.msg_id(), self.file_path.name) + + def test_msg_id_txt_exception(self) -> None: + with self.assertRaises(MessageError): + self.message_no_file_path.msg_id() + + +class MessageHashTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a new question.'), + file_path=pathlib.Path('/tmp/foo/bla')) + self.message3 = Message(Question('This is a question.'), + Answer('This is an answer.'), + file_path=pathlib.Path('/tmp/foo/bla')) + # message4 is a copy of message1, because only question and + # answer are used for hashing and comparison + self.message4 = Message(Question('This is a question.'), + tags={Tag('tag1'), Tag('tag2')}, + ai='Blabla', + file_path=pathlib.Path('foobla')) + + def test_set_hashing(self) -> None: + msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4} + self.assertEqual(len(msgs), 3) 
+ for msg in [self.message1, self.message2, self.message3]: + self.assertIn(msg, msgs) -- 2.36.6 From 7f91a2b567c721f96e11ffc0156a90d3f59a5032 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 26 Aug 2023 12:50:47 +0200 Subject: [PATCH 011/121] Added tags filtering (prefix and contained string) to TagLine and Message --- chatmastermind/message.py | 71 ++++++++++++++++++++++-- chatmastermind/tags.py | 12 +++- tests/test_message.py | 113 +++++++++++++++++++++++++++++++++++++- tests/test_tags.py | 22 +++++++- 4 files changed, 204 insertions(+), 14 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..902aaa2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -219,21 +219,57 @@ class Message(): file_path=data.get(cls.file_yaml_key, None)) @classmethod - def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + def tags_from_file(cls: Type[MessageInst], + file_path: pathlib.Path, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: """ - Return only the tags from the given Message file. + Return only the tags from the given Message file, + optionally filtered based on prefix or contained string. 
""" + tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") + # for TXT, it's enough to read the TagLine if file_path.suffix == '.txt': with open(file_path, "r") as fd: - tags = TagLine(fd.readline()).tags() + try: + tags = TagLine(fd.readline()).tags(prefix, contain) + except TagError: + pass # message without tags else: # '.yaml' - with open(file_path, "r") as fd: - data = yaml.load(fd, Loader=yaml.FullLoader) - tags = set(sorted(data[cls.tags_yaml_key])) + try: + message = cls.from_file(file_path) + if message: + msg_tags = message.filter_tags(prefix=prefix, contain=contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + if msg_tags: + tags = msg_tags + return tags + + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: + + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix' + and 'contain'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix, contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod @@ -395,6 +431,29 @@ class Message(): data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) yaml.dump(data, fd, sort_keys=False) + def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Filter tags based on their prefix (i. e. the tag starts with a given string) + or some contained string. 
+ """ + res_tags = self.tags + if res_tags: + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() + + def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: + """ + Returns all tags as a string with the TagLine prefix. Optionally filtered + using 'Message.filter_tags()'. + """ + if self.tags: + return str(TagLine.from_set(self.filter_tags(prefix, contain))) + else: + return str(TagLine.from_set(set())) + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 """ Matches the current Message to the given filter atttributes. diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 544270c..c438db9 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -118,9 +118,10 @@ class TagLine(str): """ return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) - def tags(self) -> set[Tag]: + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ - Returns all tags contained in this line as a set. + Returns all tags contained in this line as a set, optionally + filtered based on prefix or contained string. 
""" tagstr = self[len(self.prefix):].strip() separator = Tag.default_separator @@ -130,7 +131,12 @@ class TagLine(str): if s in tagstr: separator = s break - return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() def merge(self, taglines: set['TagLine']) -> 'TagLine': """ diff --git a/tests/test_message.py b/tests/test_message.py index 0e326b4..7b8aee9 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -543,11 +543,19 @@ class TagsFromFileTestCase(CmmTestCase): self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: - fd.write(f"""{TagLine.prefix} tag1 tag2 + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 {Question.txt_header} This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) + with open(self.file_path_txt_no_tags, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -560,6 +568,16 @@ This is an answer. {Message.tags_yaml_key}: - tag1 - tag2 + - ptag3 +""") + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) + with open(self.file_path_yaml_no_tags, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. 
+{Answer.yaml_key}: |- + This is an answer. """) def tearDown(self) -> None: @@ -570,11 +588,90 @@ This is an answer. def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_no_tags) + self.assertSetEqual(tags, set()) def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_yaml_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, contain='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, contain='R') + self.assertSetEqual(tags, set()) + + +class TagsFromDirTestCase(CmmTestCase): + def setUp(self) -> None: + 
self.temp_dir = tempfile.TemporaryDirectory() + self.temp_dir_no_tags = tempfile.TemporaryDirectory() + self.tag_sets = [ + {Tag('atag1'), Tag('atag2')}, + {Tag('btag3'), Tag('btag4')}, + {Tag('ctag5'), Tag('ctag6')} + ] + self.files = [ + pathlib.Path(self.temp_dir.name, 'file1.txt'), + pathlib.Path(self.temp_dir.name, 'file2.yaml'), + pathlib.Path(self.temp_dir.name, 'file3.txt') + ] + self.files_no_tags = [ + pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), + pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), + pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + ] + for file, tags in zip(self.files, self.tag_sets): + message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags) + message.to_file(file) + for file in self.files_no_tags: + message = Message(Question('This is a question.'), + Answer('This is an answer.')) + message.to_file(file) + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def test_tags_from_dir(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) + expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2] + self.assertEqual(all_tags, expected_tags) + + def test_tags_from_dir_prefix(self) -> None: + atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a') + expected_tags = self.tag_sets[0] + self.assertEqual(atags, expected_tags) + + def test_tags_from_dir_no_tags(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name)) + self.assertSetEqual(all_tags, set()) class MessageIDTestCase(CmmTestCase): @@ -619,3 +716,13 @@ class MessageHashTestCase(CmmTestCase): self.assertEqual(len(msgs), 3) for msg in [self.message1, self.message2, self.message3]: self.assertIn(msg, msgs) + + +class MessageTagsStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_tags_str(self) -> 
None: + self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') diff --git a/tests/test_tags.py b/tests/test_tags.py index 9ac9746..bd2b685 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -40,15 +40,33 @@ class TestTagLine(CmmTestCase): self.assertEqual(tagline, 'TAGS: tag1 tag2') def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') + tagline = TagLine('TAGS: atag1 btag2') tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_contain(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(contain='t') + self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(contain='1') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(contain='R') + self.assertSetEqual(tags, set()) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 169f1bb4585c495d8ce34856bf044af6da4bcc50 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 27 Aug 2023 18:07:38 +0200 Subject: [PATCH 012/121] fixed handling empty tags in TXT file --- chatmastermind/tags.py | 2 ++ tests/test_message.py | 13 +++++++++++++ tests/test_tags.py | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index c438db9..bb45a08 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -124,6 +124,8 @@ class TagLine(str): 
filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() + if tagstr == '': + return set() # no tags, only prefix separator = Tag.default_separator # look for alternative separators and use the first one found # -> we don't support different separators in the same TagLine diff --git a/tests/test_message.py b/tests/test_message.py index 7b8aee9..9cfb30a 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -556,6 +556,15 @@ This is an answer. This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) + with open(self.file_path_txt_tags_empty, "w") as fd: + fd.write(f"""TAGS: +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -594,6 +603,10 @@ This is an answer. 
tags = Message.tags_from_file(self.file_path_txt_no_tags) self.assertSetEqual(tags, set()) + def test_tags_from_file_txt_tags_empty(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_tags_empty) + self.assertSetEqual(tags, set()) + def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) diff --git a/tests/test_tags.py b/tests/test_tags.py index bd2b685..eeab199 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -44,6 +44,10 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) + def test_tags_empty(self) -> None: + tagline = TagLine('TAGS:') + self.assertSetEqual(tagline.tags(), set()) + def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() -- 2.36.6 From 73d2a9ea3b866d7780b101ed255d7a4a198969fc Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 29 Aug 2023 11:35:18 +0200 Subject: [PATCH 013/121] fixed test case file cleanup --- tests/test_message.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_message.py b/tests/test_message.py index 9cfb30a..83a73ea 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -594,6 +594,12 @@ This is an answer. 
self.file_path_txt.unlink() self.file_yaml.close() self.file_path_yaml.unlink() + self.file_txt_no_tags.close + self.file_path_txt_no_tags.unlink() + self.file_txt_tags_empty.close + self.file_path_txt_tags_empty.unlink() + self.file_yaml_no_tags.close() + self.file_path_yaml_no_tags.unlink() def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) @@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase): def tearDown(self) -> None: self.temp_dir.cleanup() + self.temp_dir_no_tags.cleanup() def test_tags_from_dir(self) -> None: all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) -- 2.36.6 From 8e1cdee3bfca4c6b26c5a086feb9ac3671395c1c Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 30 Aug 2023 08:20:25 +0200 Subject: [PATCH 014/121] fixed Message.filter_tags --- chatmastermind/message.py | 15 ++++++++------- tests/test_message.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 902aaa2..820d104 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -436,13 +436,14 @@ class Message(): Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string. 
""" - res_tags = self.tags - if res_tags: - if prefix and len(prefix) > 0: - res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} - if contain and len(contain) > 0: - res_tags -= {tag for tag in res_tags if contain not in tag} - return res_tags or set() + if not self.tags: + return set() + res_tags = self.tags.copy() + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ diff --git a/tests/test_message.py b/tests/test_message.py index 83a73ea..2a9d0ff 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -746,3 +746,18 @@ class MessageTagsStrTestCase(CmmTestCase): def test_tags_str(self) -> None: self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) -- 2.36.6 From b83cbb719bc7ca617a6ab4c5e05bd94a8a5ef0d8 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 09:19:38 +0200 Subject: [PATCH 015/121] added 'message_in()' function and test --- chatmastermind/message.py | 16 +++++++++++++++- tests/test_message.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 820d104..3eca26e 100644 --- a/chatmastermind/message.py +++ 
b/chatmastermind/message.py @@ -3,7 +3,7 @@ Module implementing message related functions and classes. """ import pathlib import yaml -from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags @@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]: return code_sections +def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool: + """ + Searches the given message list for a message with the same file + name as the given one (i. e. it compares Message.file_path.name). + If the given message has no file_path, False is returned. + """ + if not message.file_path: + return False + for m in messages: + if m.file_path and m.file_path.name == message.file_path.name: + return True + return False + + @dataclass(kw_only=True) class MessageFilter: """ diff --git a/tests/test_message.py b/tests/test_message.py index 2a9d0ff..0d7953e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -2,7 +2,7 @@ import pathlib import tempfile from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine @@ -761,3 +761,17 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_pref, {Tag('atag1')}) tags_cont = self.message.filter_tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + + +class MessageInTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + 
self.message2 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/bla/foo')) + + def test_message_in(self) -> None: + self.assertTrue(message_in(self.message1, [self.message1])) + self.assertFalse(message_in(self.message1, [self.message2])) -- 2.36.6 From 214a6919db1051437f2b0f05b1ce8ababd05a8b0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 15:47:29 +0200 Subject: [PATCH 016/121] tags: some clarification and new tests --- chatmastermind/tags.py | 3 ++- tests/test_tags.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bb45a08..5ea1a3a 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -77,7 +77,8 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' or all of the tags in 'tags_and' but it must never contain any of the tags in 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag - exclusion is still done if 'tags_not' is not 'None'). + exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), + they match no tags. 
""" required_tags_present = False excluded_tags_missing = False diff --git a/tests/test_tags.py b/tests/test_tags.py index eeab199..aa89a06 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -144,3 +144,20 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags + self.assertFalse(tagline.match_tags(set(), set(), None)) + + # Test case 11: 'tags_or' is empty, match no tags + self.assertFalse(tagline.match_tags(set(), None, None)) + + # Test case 12: 'tags_and' is empty, match no tags + self.assertFalse(tagline.match_tags(None, set(), None)) + + # Test case 13: 'tags_or' is empty, match 'tags_and' + tags_and = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(None, tags_and, None)) + + # Test case 14: 'tags_and' is empty, match 'tags_or' + tags_or = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(tags_or, None, None)) -- 2.36.6 From 9f4897a5b8e94bb347b03da63e089b7f18eb6e77 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 24 Aug 2023 16:49:54 +0200 Subject: [PATCH 017/121] added new module 'chat.py' --- chatmastermind/chat.py | 278 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 chatmastermind/chat.py diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py new file mode 100644 index 0000000..c5d8bf3 --- /dev/null +++ b/chatmastermind/chat.py @@ -0,0 +1,278 @@ +""" +Module implementing various chat classes and functions for managing a chat history. 
+""" +import shutil +import pathlib +from pprint import PrettyPrinter +from pydoc import pager +from dataclasses import dataclass +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .tags import Tag + +ChatInst = TypeVar('ChatInst', bound='Chat') +ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') + + +class ChatError(Exception): + pass + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args: Any, **kwargs: Any) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def print_paged(text: str) -> None: + pager(text) + + +def read_dir(dir_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> list[Message]: + """ + Reads the messages from the given folder. + Parameters: + * 'dir_path': source directory + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages: list[Message] = [] + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + message = Message.from_file(file_path, mfilter) + if message: + messages.append(message) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return messages + + +def write_dir(dir_path: pathlib.Path, + messages: list[Message], + file_suffix: str, + next_fid: Callable[[], int]) -> None: + """ + Write all messages to the given directory. If a message has no file_path, + a new one will be created. If message.file_path exists, it will be modified + to point to the given directory. 
+ Parameters: + * 'dir_path': destination directory + * 'messages': list of messages to write + * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] + * 'next_fid': callable that returns the next file ID + """ + for message in messages: + file_path = message.file_path + # message has no file_path: create one + if not file_path: + fid = next_fid() + fname = f"{fid:04d}{file_suffix}" + file_path = dir_path / fname + # file_path does not point to given directory: modify it + elif not file_path.parent.samefile(dir_path): + file_path = dir_path / file_path.name + message.to_file(file_path) + + +@dataclass +class Chat: + """ + A class containing a complete chat history. + """ + + messages: list[Message] + + def filter(self, mfilter: MessageFilter) -> None: + """ + Use 'Message.match(mfilter) to remove all messages that + don't fulfill the filter requirements. + """ + self.messages = [m for m in self.messages if m.match(mfilter)] + + def sort(self, reverse: bool = False) -> None: + """ + Sort the messages according to 'Message.msg_id()'. + """ + try: + # the message may not have an ID if it doesn't have a file_path + self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) + except MessageError: + pass + + def clear(self) -> None: + """ + Delete all messages. + """ + self.messages = [] + + def add_msgs(self, msgs: list[Message]) -> None: + """ + Add new messages and sort them if possible. + """ + self.messages += msgs + self.sort() + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Get the tags of all messages, optionally filtered by prefix or substring. 
+ """ + tags: set[Tag] = set() + for m in self.messages: + tags |= m.filter_tags(prefix, contain) + return tags + + def print(self, dump: bool = False, source_code_only: bool = False, + with_tags: bool = False, with_file: bool = False, + paged: bool = True) -> None: + if dump: + pp(self) + return + output: list[str] = [] + for message in self.messages: + if source_code_only: + output.extend(source_code(message.question, include_delims=True)) + continue + output.append('-' * terminal_width()) + output.append(Question.txt_header) + output.append(message.question) + if message.answer: + output.append(Answer.txt_header) + output.append(message.answer) + if with_tags: + output.append(message.tags_str()) + if with_file: + output.append('FILE: ' + str(message.file_path)) + if paged: + print_paged('\n'.join(output)) + else: + print(*output, sep='\n') + + +@dataclass +class ChatDB(Chat): + """ + A 'Chat' class that is bound to a given directory structure. Supports reading + and writing messages from / to that structure. Such a structure consists of + two directories: a 'cache directory', where all messages are temporarily + stored, and a 'DB' directory, where selected messages can be stored + persistently. 
+ """ + + default_file_suffix: ClassVar[str] = '.txt' + + cache_path: pathlib.Path + db_path: pathlib.Path + # a MessageFilter that all messages must match (if given) + mfilter: Optional[MessageFilter] = None + file_suffix: str = default_file_suffix + # the glob pattern for all messages + glob: Optional[str] = None + + def __post_init__(self) -> None: + # contains the latest message ID + self.next_fname = self.db_path / '.next' + # make all paths absolute + self.cache_path = self.cache_path.absolute() + self.db_path = self.db_path.absolute() + + @classmethod + def from_dir(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a 'ChatDB' instance from the given directory structure. + Reads all messages from 'db_path' into the local message list. + Parameters: + * 'cache_path': path to the directory for temporary messages + * 'db_path': path to the directory for persistent messages + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages = read_dir(db_path, glob, mfilter) + return cls(messages, cache_path, db_path, mfilter, + cls.default_file_suffix, glob) + + @classmethod + def from_messages(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + messages: list[Message], + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a ChatDB instance from the given message list. 
+ """ + return cls(messages, cache_path, db_path, mfilter) + + def get_next_fid(self) -> int: + try: + with open(self.next_fname, 'r') as f: + next_fid = int(f.read()) + 1 + self.set_next_fid(next_fid) + return next_fid + except Exception: + self.set_next_fid(1) + return 1 + + def set_next_fid(self, fid: int) -> None: + with open(self.next_fname, 'w') as f: + f.write(f'{fid}') + + def read_db(self) -> None: + """ + Reads new messages from the DB directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.db_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def read_cache(self) -> None: + """ + Reads new messages from the cache directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.cache_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def write_db(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the DB directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point + to the DB directory. 
+ """ + write_dir(self.db_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) + + def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the cache directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point to + the cache directory. + """ + write_dir(self.cache_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) -- 2.36.6 From 93290da5b5badac83c1319bd5643475894b77697 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 28 Aug 2023 14:24:24 +0200 Subject: [PATCH 018/121] added tests for 'chat.py' --- tests/test_chat.py | 297 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 tests/test_chat.py diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..2d0ffa0 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,297 @@ +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width +from .test_main import CmmTestCase + + +class TestChat(CmmTestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_msgs([self.message2, self.message1]) + 
self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_msgs([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_msgs(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def test_tags(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_file=True) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{TagLine.prefix} atag1 +FILE: 0001.txt +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} 
+Answer 2 +{TagLine.prefix} btag2 +FILE: 0002.txt +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(CmmTestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = 
ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_filter(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messges(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 1) + self.assertEqual(chat_db.get_next_fid(), 2) + self.assertEqual(chat_db.get_next_fid(), 3) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '3') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that 
Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + 
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_read(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + + # create 2 new files in the DB directory + new_message1 = Message(Question('Question 5'), + Answer('Answer 5'), + {Tag('tag5')}) + new_message2 = Message(Question('Question 6'), + Answer('Answer 6'), + {Tag('tag6')}) + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 6) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # create 2 new files in the cache directory + new_message3 = Message(Question('Question 7'), + Answer('Answer 5'), + {Tag('tag7')}) + new_message4 = Message(Question('Question 8'), + Answer('Answer 6'), + {Tag('tag8')}) + new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) + 
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + # read and check them + chat_db.read_cache() + self.assertEqual(len(chat_db.messages), 8) + # check that the new messages have the cache dir path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + # and the old ones keep their path (since they have not been replaced) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now overwrite two messages in the DB directory + new_message1.question = Question('New Question 1') + new_message2.question = Question('New Question 2') + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read from the DB dir and check if the modified messages have been updated + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + self.assertEqual(chat_db.messages[4].question, 'New Question 1') + self.assertEqual(chat_db.messages[5].question, 'New Question 2') + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now write the messages from the cache to the DB directory + new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + # check that they now have the DB path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) -- 2.36.6 From 
7f612bfc1745711334dac3f427d2cd63b988eda1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 08:57:54 +0200 Subject: [PATCH 019/121] added tokens() function to Message and Chat --- chatmastermind/chat.py | 7 +++++++ chatmastermind/message.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c5d8bf3..4a458df 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -129,6 +129,13 @@ class Chat: tags |= m.filter_tags(prefix, contain) return tags + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by all messages in this chat. + If unknown, 0 is returned. + """ + return sum(m.tokens() for m in self.messages) + def print(self, dump: bool = False, source_code_only: bool = False, with_tags: bool = False, with_file: bool = False, paged: bool = True) -> None: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 3eca26e..675ab3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -132,6 +132,7 @@ class Question(str): """ A single question with a defined header. """ + tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'question' @@ -165,6 +166,7 @@ class Answer(str): """ A single answer with a defined header. """ + tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '=== ANSWER ===' yaml_key: ClassVar[str] = 'answer' @@ -502,3 +504,13 @@ class Message(): def as_dict(self) -> dict[str, Any]: return asdict(self) + + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by this message. + If unknown, 0 is returned. 
+ """ + if self.answer: + return self.question.tokens + self.answer.tokens + else: + return self.question.tokens -- 2.36.6 From d93598a74fa7490f79158b219ac5a22f2310ccb1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:07:58 +0200 Subject: [PATCH 020/121] configuration: added AIConfig class --- chatmastermind/configuration.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 5ae32d6..0780604 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') @dataclass -class OpenAIConfig(): +class AIConfig: + """ + The base class of all AI configurations. + """ + name: str + + +@dataclass +class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ @@ -25,6 +33,7 @@ class OpenAIConfig(): Create OpenAIConfig from a dict. """ return cls( + name='OpenAI', api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -36,7 +45,7 @@ class OpenAIConfig(): @dataclass -class Config(): +class Config: """ The configuration file structure. """ @@ -47,7 +56,7 @@ class Config(): @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create OpenAIConfig from a dict. + Create Config from a dict. 
""" return cls( system=str(source['system']), -- 2.36.6 From ddfe29b9511e7e77c4df45e3e2ac55b8d10c5a36 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:35:32 +0200 Subject: [PATCH 021/121] chat: added tags_frequency() function and test --- chatmastermind/chat.py | 11 ++++++++++- tests/test_chat.py | 9 +++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4a458df..759467d 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -127,7 +127,16 @@ class Chat: tags: set[Tag] = set() for m in self.messages: tags |= m.filter_tags(prefix, contain) - return tags + return set(sorted(tags)) + + def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: + """ + Get the frequency of all tags of all messages, optionally filtered by prefix or substring. + """ + tags: list[Tag] = [] + for m in self.messages: + tags += [tag for tag in m.filter_tags(prefix, contain)] + return {tag: tags.count(tag) for tag in sorted(tags)} def tokens(self) -> int: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index 2d0ffa0..5f1fcb6 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -14,7 +14,7 @@ class TestChat(CmmTestCase): self.chat = Chat([]) self.message1 = Message(Question('Question 1'), Answer('Answer 1'), - {Tag('atag1')}, + {Tag('atag1'), Tag('btag2')}, file_path=pathlib.Path('0001.txt')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), @@ -57,6 +57,11 @@ class TestChat(CmmTestCase): tags_cont = self.chat.tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + def test_tags_frequency(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_freq = self.chat.tags_frequency() + self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) @@ 
-83,7 +88,7 @@ Answer 2 Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 +{TagLine.prefix} atag1 btag2 FILE: 0001.txt {'-'*terminal_width()} {Question.txt_header} -- 2.36.6 From d80c3962bd9451ef45681ca8df6ef5780dc55d5f Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:44:27 +0200 Subject: [PATCH 022/121] chat: fixed handling of unsupported files in DB and cache dir --- chatmastermind/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 759467d..11f1d74 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path, messages: list[Message] = [] file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in sorted(file_iter): - if file_path.is_file(): + if file_path.is_file() and file_path.suffix in Message.file_suffixes: try: message = Message.from_file(file_path, mfilter) if message: -- 2.36.6 From ba56caf01309c50a46ff679f77fef3c2037c2a0a Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:18:41 +0200 Subject: [PATCH 023/121] chat: improved history printing --- chatmastermind/chat.py | 15 ++++++--------- tests/test_chat.py | 10 +++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 11f1d74..e4e8ab6 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -145,27 +145,24 @@ class Chat: """ return sum(m.tokens() for m in self.messages) - def print(self, dump: bool = False, source_code_only: bool = False, - with_tags: bool = False, with_file: bool = False, + def print(self, source_code_only: bool = False, + with_tags: bool = False, with_files: bool = False, paged: bool = True) -> None: - if dump: - pp(self) - return output: list[str] = [] for message in self.messages: if source_code_only: output.extend(source_code(message.question, include_delims=True)) continue output.append('-' * terminal_width()) 
+ if with_tags: + output.append(message.tags_str()) + if with_files: + output.append('FILE: ' + str(message.file_path)) output.append(Question.txt_header) output.append(message.question) if message.answer: output.append(Answer.txt_header) output.append(message.answer) - if with_tags: - output.append(message.tags_str()) - if with_file: - output.append('FILE: ' + str(message.file_path)) if paged: print_paged('\n'.join(output)) else: diff --git a/tests/test_chat.py b/tests/test_chat.py index 5f1fcb6..8e1ad0d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -82,21 +82,21 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) - self.chat.print(paged=False, with_tags=True, with_file=True) + self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} +{TagLine.prefix} atag1 btag2 +FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 btag2 -FILE: 0001.txt {'-'*terminal_width()} +{TagLine.prefix} btag2 +FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 -{TagLine.prefix} btag2 -FILE: 0002.txt """ self.assertEqual(mock_stdout.getvalue(), expected_output) -- 2.36.6 From f9d749cdd8f3f921b275c89302fedc8f844caa4a Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 09:19:47 +0200 Subject: [PATCH 024/121] chat: added clear_cache() function and test --- chatmastermind/chat.py | 20 +++++++++++++++++++ tests/test_chat.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index e4e8ab6..9fc0a27 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -82,6 +82,17 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) +def clear_dir(dir_path: pathlib.Path, + glob: Optional[str] = None) -> None: + """ + Deletes 
all Message files in the given directory. + """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + @dataclass class Chat: """ @@ -289,3 +300,12 @@ class ChatDB(Chat): msgs if msgs else self.messages, self.file_suffix, self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e1ad0d..9e74061 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -300,3 +300,48 @@ class TestChatDB(CmmTestCase): # check that they now have the DB path self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + 
self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_msgs([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) -- 2.36.6 From fa292fb73a97e167ba79d894af62d3cee40202d0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 16:00:24 +0200 Subject: [PATCH 025/121] message: improved robustness of Question and Answer content checks and tests --- chatmastermind/message.py | 48 +++++++++++++++++++++------------------ tests/test_message.py | 29 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 
675ab3a..384fb96 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -128,29 +128,29 @@ class ModelLine(str): return cls(' '.join([cls.prefix, model])) -class Question(str): +class Answer(str): """ - A single question with a defined header. + A single answer with a defined header. """ - tokens: int = 0 # tokens used by this question - txt_header: ClassVar[str] = '=== QUESTION ===' - yaml_key: ClassVar[str] = 'question' + tokens: int = 0 # tokens used by this answer + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ - Make sure the question string does not contain the header. + Make sure the answer string does not contain the header as a whole line. """ - if cls.txt_header in string: - raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if cls.txt_header in string.split('\n'): + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance @@ -162,29 +162,33 @@ class Question(str): return source_code(self, include_delims) -class Answer(str): +class Question(str): """ - A single answer with a defined header. + A single question with a defined header. 
""" - tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' - yaml_key: ClassVar[str] = 'answer' + tokens: int = 0 # tokens used by this question + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ - Make sure the answer string does not contain the header. + Make sure the question string does not contain the header as a whole line + (also not that from 'Answer', so it's always clear where the answer starts). """ - if cls.txt_header in string: - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + string_lines = string.split('\n') + if cls.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if Answer.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. 
""" - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance diff --git a/tests/test_message.py b/tests/test_message.py index 0d7953e..e01de66 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: + def test_question_with_header(self) -> None: with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") + Question(f"{Question.txt_header}\nWhat is your name?") - def test_question_without_prefix(self) -> None: + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. 
+ """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: question = Question("What is your favorite color?") self.assertIsInstance(question, Question) self.assertEqual(question, "What is your favorite color?") class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: + def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") + Answer(f"{Answer.txt_header}\nno") - def test_answer_without_prefix(self) -> None: + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") -- 2.36.6 From 4b0f40bccdf5a1f10caf037cd41b726830ecef90 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:00:08 +0200 Subject: [PATCH 026/121] message: fixed Answer header for TXT format --- chatmastermind/message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 384fb96..87de8e2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -96,7 +96,7 @@ class AILine(str): def __new__(cls: Type[AILineInst], string: str) -> AILineInst: if not string.startswith(cls.prefix): - raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -116,7 +116,7 @@ class ModelLine(str): def 
__new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: if not string.startswith(cls.prefix): - raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -133,7 +133,7 @@ class Answer(str): A single answer with a defined header. """ tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' + txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: @@ -355,17 +355,20 @@ class Message(): try: pos = fd.tell() ai = AILine(fd.readline()).ai() - except TagError: + except MessageError: fd.seek(pos) # ModelLine (Optional) try: pos = fd.tell() model = ModelLine(fd.readline()).model() - except TagError: + except MessageError: fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') - question_idx = text.index(Question.txt_header) + 1 + try: + question_idx = text.index(Question.txt_header) + 1 + except ValueError: + raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) -- 2.36.6 From 44cd1fab4587ce9dc2b1b0f7f5a2a66d023a1ef0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:19:14 +0200 Subject: [PATCH 027/121] message: added rename_tags() function and test --- chatmastermind/message.py | 10 +++++++++- tests/test_message.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 87de8e2..0fb949c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,7 @@ import pathlib import yaml from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field 
-from .tags import Tag, TagLine, TagError, match_tags +from .tags import Tag, TagLine, TagError, match_tags, rename_tags QuestionInst = TypeVar('QuestionInst', bound='Question') AnswerInst = TypeVar('AnswerInst', bound='Answer') @@ -499,6 +499,14 @@ class Message(): return False return True + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None: + """ + Renames the given tags. The first tuple element is the old name, + the second one is the new name. + """ + if self.tags: + self.tags = rename_tags(self.tags, tags_rename) + def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. diff --git a/tests/test_message.py b/tests/test_message.py index e01de66..e860538 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -792,3 +792,15 @@ class MessageInTestCase(CmmTestCase): def test_message_in(self) -> None: self.assertTrue(message_in(self.message1, [self.message1])) self.assertFalse(message_in(self.message1, [self.message2])) + + +class MessageRenameTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_rename_tags(self) -> None: + self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) + self.assertIsNotNone(self.message.tags) + self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -- 2.36.6 From 6e2d5009c15768e1c66f396e00b0b3d68391432d Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH 028/121] chat: new possibilities for adding messages and better tests --- chatmastermind/chat.py | 75 ++++++++++++++++++++++++---- tests/test_chat.py | 109 ++++++++++++++++++++++++++++++++--------- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..7e6df8f 100644 --- 
a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. 
""" write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,52 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. 
+ """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..a1c020e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -5,7 +5,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from .test_main import CmmTestCase @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) 
self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, 
with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase): self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. + """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) def tearDown(self) -> None: self.db_path.cleanup() @@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase): def test_chat_db_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.get_next_fid(), 1) - self.assertEqual(chat_db.get_next_fid(), 2) - self.assertEqual(chat_db.get_next_fid(), 3) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) with open(chat_db.next_fname, 'r') as f: - self.assertEqual(f.read(), '3') + self.assertEqual(f.read(), '7') def test_chat_db_write(self) -> None: # create a new ChatDB instance @@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) @@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase): 
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) # check the timestamp of the files in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} # overwrite the messages in the db directory time.sleep(0.05) chat_db.write_db() # check if the written files are in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) # now rewrite them to the DB dir and check for modified paths chat_db.write_db() - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -333,15 +344,69 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + 
cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) # make sure that the DB messages (and the new message) are still there self.assertEqual(len(chat_db.messages), 5) - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + def test_chat_db_write_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = 
self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + # try to write a message without a valid file_path + message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.write_messages([message]) + + # write a message with a valid file_path + message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' + chat_db.write_messages([message]) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) -- 2.36.6 From 63040b368895fb8065c0c03a15b0f40beb561339 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 08:49:43 +0200 Subject: [PATCH 029/121] message / chat: output improvements --- chatmastermind/chat.py | 16 ++++------------ chatmastermind/message.py | 24 ++++++++++++++++++++++++ tests/test_chat.py | 16 ++++++++++++---- tests/test_message.py | 24 ++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 7e6df8f..c631dab 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -7,7 +7,7 @@ from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, ClassVar, Any, Callable -from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -170,18 +170,10 @@ class Chat: output: list[str] = [] for message in self.messages: if source_code_only: - output.extend(source_code(message.question, include_delims=True)) + output.append(message.to_str(source_code_only=True)) continue - output.append('-' * 
terminal_width()) - if with_tags: - output.append(message.tags_str()) - if with_files: - output.append('FILE: ' + str(message.file_path)) - output.append(Question.txt_header) - output.append(message.question) - if message.answer: - output.append(Answer.txt_header) - output.append(message.answer) + output.append(message.to_str(with_tags, with_files)) + output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 0fb949c..35de3b9 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -392,6 +392,30 @@ class Message(): data[cls.file_yaml_key] = file_path return cls.from_dict(data) + def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: + """ + Return the current Message as a string. + """ + output: list[str] = [] + if source_code_only: + # use the source code from answer only + if self.answer: + output.extend(self.answer.source_code(include_delims=True)) + return '\n'.join(output) if len(output) > 0 else '' + if with_tags: + output.append(self.tags_str()) + if with_file: + output.append('FILE: ' + str(self.file_path)) + output.append(Question.txt_header) + output.append(self.question) + if self.answer: + output.append(Answer.txt_header) + output.append(self.answer) + return '\n'.join(output) + + def __str__(self) -> str: + return self.to_str(False, False, False) + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ Write a Message to the given file. Type is determined based on the suffix. 
diff --git a/tests/test_chat.py b/tests/test_chat.py index a1c020e..f8302eb 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -66,16 +66,20 @@ class TestChat(CmmTestCase): def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) - expected_output = f"""{'-'*terminal_width()} -{Question.txt_header} + expected_output = f"""{Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) @@ -83,20 +87,24 @@ Answer 2 def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) - expected_output = f"""{'-'*terminal_width()} -{TagLine.prefix} atag1 btag2 + expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {TagLine.prefix} btag2 FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) diff --git a/tests/test_message.py b/tests/test_message.py index e860538..a49c893 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase): self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) self.assertIsNotNone(self.message.tags) self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] + + +class MessageToStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_to_str(self) 
-> None: + expected_output = f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(), expected_output) + + def test_to_str_with_tags_and_file(self) -> None: + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: /tmp/foo/bla +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) -- 2.36.6 From 7e25a08d6e8d1fd7e3a3ba782f4b8d20e67c8ef0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 08:16:55 +0200 Subject: [PATCH 030/121] chat: added functions for finding and deleting messages --- chatmastermind/chat.py | 52 ++++++++++++++++++++++++++++++++---------- tests/test_chat.py | 22 ++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c631dab..4e8fb20 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -2,7 +2,7 @@ Module implementing various chat classes and functions for managing a chat history. """ import shutil -import pathlib +from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass @@ -30,7 +30,7 @@ def print_paged(text: str) -> None: pager(text) -def read_dir(dir_path: pathlib.Path, +def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path, return messages -def make_file_path(dir_path: pathlib.Path, +def make_file_path(dir_path: Path, file_suffix: str, - next_fid: Callable[[], int]) -> pathlib.Path: + next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. 
@@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path, return dir_path / f"{next_fid():04d}{file_suffix}" -def write_dir(dir_path: pathlib.Path, +def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: @@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) -def clear_dir(dir_path: pathlib.Path, +def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. @@ -139,6 +139,34 @@ class Chat: self.messages += messages self.sort() + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. + """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. 
@@ -192,8 +220,8 @@ class ChatDB(Chat): default_file_suffix: ClassVar[str] = '.txt' - cache_path: pathlib.Path - db_path: pathlib.Path + cache_path: Path + db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix @@ -209,8 +237,8 @@ class ChatDB(Chat): @classmethod def from_dir(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ @@ -230,8 +258,8 @@ class ChatDB(Chat): @classmethod def from_messages(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index f8302eb..d81a97a 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -62,6 +62,28 @@ class TestChat(CmmTestCase): tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + 
self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) -- 2.36.6 From eb0d97ddc8cad58626d85d6a32eb10085e850128 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:46:23 +0200 Subject: [PATCH 031/121] cmm: the 'tags' command now uses the new 'ChatDB' --- chatmastermind/main.py | 34 +++++++++++++++++++++------------- chatmastermind/utils.py | 5 ----- tests/test_main.py | 2 +- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 7866179..3f31aee 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,10 +7,11 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType +from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config +from .chat import ChatDB from itertools import zip_longest from typing import Any @@ -57,12 +58,17 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: Config) -> None: +def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ - Handler for the 'tag' command. + Handler for the 'tags' command. 
""" + chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), + db_path=pathlib.Path(config.db)) if args.list: - print_tags_frequency(get_tags(config, None)) + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming def config_cmd(args: argparse.Namespace, config: Config) -> None: @@ -187,14 +193,16 @@ def create_parser() -> argparse.ArgumentParser: hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') - # 'tag' command parser - tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.", - aliases=['t']) - tag_cmd_parser.set_defaults(func=tag_cmd) - tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) - tag_group.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + # 'tags' command parser + tags_cmd_parser = cmdparser.add_parser('tags', + help="Manage tags.", + aliases=['t']) + tags_cmd_parser.set_defaults(func=tags_cmd) + tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) + tags_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") + tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index bd80e4f..e6eeb97 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals print(message['content']) else: print(f"{message['role'].upper()}: {message['content']}") - - -def print_tags_frequency(tags: list[str]) -> None: - for tag in sorted(set(tags)): - print(f"- {tag}: {tags.count(tag)}") diff --git a/tests/test_main.py b/tests/test_main.py 
index db5fcdb..23c3d00 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -227,7 +227,7 @@ class TestCreateParser(CmmTestCase): mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From b0504aedbef6e167469c703174d57164fc637595 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:21:49 +0200 Subject: [PATCH 032/121] cmm: the 'hist' command now uses the new 'ChatDB' --- chatmastermind/main.py | 60 +++++++++++++++++++++++------------------- tests/test_main.py | 15 ++++++----- 2 files changed, 42 insertions(+), 33 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 3f31aee..08c5e3e 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -12,6 +12,7 @@ from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB +from .message import MessageFilter from itertools import zip_longest from typing import Any @@ -32,11 +33,11 @@ def create_question_with_hist(args: argparse.Namespace, by the specified tags. 
""" tags = args.tags or [] - extags = args.extags or [] + etags = args.etags or [] otags = args.output_tags or [] - if not args.only_source_code: - print_tag_args(tags, extags, otags) + if not args.source_code_only: + print_tag_args(tags, etags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -53,8 +54,10 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, extags, config, - args.match_all_tags, False, False) + chat = create_chat_hist(full_question, tags, etags, config, + match_all_tags=True if args.atags else False, # FIXME + with_tags=False, + with_file=False) return chat, full_question, tags @@ -95,7 +98,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: if args.model: config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.only_source_code) + print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -107,14 +110,18 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. 
""" - tags = args.tags or [] - extags = args.extags or [] - chat = create_chat_hist(None, tags, extags, config, - args.match_all_tags, - args.with_tags, - args.with_files) - print_chat_hist(chat, args.dump, args.only_source_code) + mfilter = MessageFilter(tags_or=args.tags, + tags_and=args.atags, + tags_not=args.etags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) def print_cmd(args: argparse.Namespace, config: Config) -> None: @@ -130,7 +137,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: else: print(f"Unknown file type: {args.file}") sys.exit(1) - if args.only_source_code: + if args.source_code_only: display_source_code(data['answer']) else: print(dump_data(data).strip()) @@ -150,18 +157,17 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names', metavar='TAGS') + help='List of tag names (one must match)', metavar='TAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', - help='List of tag names to exclude', metavar='EXTAGS') - extag_arg.completer = tags_completer # type: ignore + atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', + help='List of tag names (all must match)', metavar='TAGS') + atag_arg.completer = tags_completer # type: ignore + etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', + help='List of tag names to exclude', metavar='ETAGS') + etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', help='List of output tag names, default is input', metavar='OTAGS') otag_arg.completer = tags_completer # type: ignore - 
tag_parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries", - action='store_true') - # enable autocompletion for tags # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], @@ -176,7 +182,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', + ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') # 'hist' command parser @@ -184,14 +190,14 @@ def create_parser() -> argparse.ArgumentParser: help="Print chat history.", aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) - hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", - action='store_true') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') - hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') + hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') + hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', @@ -222,7 +228,7 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - 
print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') argcomplete.autocomplete(parser) diff --git a/tests/test_main.py b/tests/test_main.py index 23c3d00..bb9aa2a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase): self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], - extags=['extag1'], + atags=None, + etags=['etag1'], output_tags=None, question=[self.question], source=None, - only_source_code=False, + source_code_only=False, number=3, max_tokens=None, temperature=None, @@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase): with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.extags, + self.args.etags, []) mock_create_chat_hist.assert_called_once_with(self.question, self.args.tags, - self.args.extags, + self.args.etags, self.config, - False, False, False) + match_all_tags=False, + with_tags=False, + with_file=False) mock_print_chat_hist.assert_called_once_with('test_chat', False, - self.args.only_source_code) + self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, self.args.number) -- 2.36.6 From f93a57c00da39ff658a261970501d0c8c5140ec2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:42:59 +0200 Subject: [PATCH 033/121] cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) --- chatmastermind/main.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 08c5e3e..b3bd1b8 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,13 +6,13 @@ import yaml import sys import argcomplete import argparse -import pathlib +from pathlib import Path from .utils import 
terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data +from .storage import save_answers, create_chat_hist, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import MessageFilter +from .message import Message, MessageFilter from itertools import zip_longest from typing import Any @@ -20,9 +20,8 @@ default_config = '.config.yaml' def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - return get_tags_unique(config, prefix) + config = Config.from_file(parsed_args.config) + return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) def create_question_with_hist(args: argparse.Namespace, @@ -65,8 +64,8 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. """ - chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), - db_path=pathlib.Path(config.db)) + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) if args.list: tags_freq = chat.tags_frequency(args.prefix, args.contain) for tag, freq in tags_freq.items(): @@ -128,7 +127,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. 
""" - fname = pathlib.Path(args.file) + fname = Path(args.file) if fname.suffix == '.yaml': with open(args.file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) -- 2.36.6 From bf1cbff6a2c11411097cca53229ac6b1c6ecae06 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:07:02 +0200 Subject: [PATCH 034/121] cmm: the 'print' command now uses 'Message.from_file()' --- chatmastermind/main.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index b3bd1b8..951d3cf 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -2,17 +2,16 @@ # -*- coding: utf-8 -*- # vim: set fileencoding=utf-8 : -import yaml import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType +from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter +from .message import Message, MessageFilter, MessageError from itertools import zip_longest from typing import Any @@ -128,18 +127,13 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'print' command. 
""" fname = Path(args.file) - if fname.suffix == '.yaml': - with open(args.file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif fname.suffix == '.txt': - data = read_file(fname) - else: - print(f"Unknown file type: {args.file}") + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") sys.exit(1) - if args.source_code_only: - display_source_code(data['answer']) - else: - print(dump_data(data).strip()) def create_parser() -> argparse.ArgumentParser: @@ -223,11 +217,11 @@ def create_parser() -> argparse.ArgumentParser: # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.", + help="Print message files.", aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', action='store_true') argcomplete.autocomplete(parser) -- 2.36.6 From aa322de71866ed513f56b10ba5564b5a482c888b Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:00:15 +0200 Subject: [PATCH 035/121] added new module 'ai.py' --- chatmastermind/ai.py | 63 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 chatmastermind/ai.py diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py new file mode 100644 index 0000000..4a8b914 --- /dev/null +++ b/chatmastermind/ai.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Protocol, Optional, Union +from .configuration import AIConfig +from .tags import Tag +from .message import Message +from .chat import Chat + + +class AIError(Exception): + pass + + +@dataclass +class Tokens: + prompt: int = 0 + 
completion: int = 0 + total: int = 0 + + +@dataclass +class AIResponse: + """ + The response to an AI request. Consists of one or more messages + (each containing the question and a single answer) and the nr. + of used tokens. + """ + messages: list[Message] + tokens: Optional[Tokens] = None + + +class AI(Protocol): + """ + The base class for AI clients. + """ + + name: str + config: AIConfig + + def request(self, + question: Message, + context: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + context (i. e. chat history). The nr. of requested answers + corresponds to the nr. of messages in the 'AIResponse'. + """ + raise NotImplementedError + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + """ + Computes the nr. of AI language tokens for the given message + or chat. Note that the computation may not be 100% accurate + and is not implemented for all AIs. + """ + raise NotImplementedError -- 2.36.6 From b7e3ca7ca77d65c069f05b4ca005385793026c68 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 10:18:09 +0200 Subject: [PATCH 036/121] added new module 'openai.py' --- chatmastermind/ais/openai.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 chatmastermind/ais/openai.py diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py new file mode 100644 index 0000000..74438b8 --- /dev/null +++ b/chatmastermind/ais/openai.py @@ -0,0 +1,96 @@ +""" +Implements the OpenAI client classes and functions. 
+""" +import openai +from typing import Optional, Union +from ..tags import Tag +from ..message import Message, Answer +from ..chat import Chat +from ..ai import AI, AIResponse, Tokens +from ..configuration import OpenAIConfig + +ChatType = list[dict[str, str]] + + +class OpenAI(AI): + """ + The OpenAI AI client. + """ + + def __init__(self, name: str, config: OpenAIConfig) -> None: + self.name = name + self.config = config + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + chat history. The nr. of requested answers corresponds to the + nr. of messages in the 'AIResponse'. + """ + # FIXME: use real 'system' message (store in OpenAIConfig) + oai_chat = self.openai_chat(chat, "system", question) + response = openai.ChatCompletion.create( + model=self.config.model, + messages=oai_chat, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, + n=num_answers, + frequency_penalty=self.config.frequency_penalty, + presence_penalty=self.config.presence_penalty) + answers: list[Message] = [] + for choice in response['choices']: # type: ignore + answers.append(Message(question=question.question, + answer=Answer(choice['message']['content']), + tags=otags, + ai=self.name, + model=self.config.model)) + return AIResponse(answers, Tokens(response['usage']['prompt'], + response['usage']['completion'], + response['usage']['total'])) + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def print_models(self) -> None: + """ + Print all models supported by the current AI. 
+ """ + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def openai_chat(self, chat: Chat, system: str, + question: Optional[Message] = None) -> ChatType: + """ + Create a chat history with system message in OpenAI format. + Optionally append a new question. + """ + oai_chat: ChatType = [] + + def append(role: str, content: str) -> None: + oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + + append('system', system) + for message in chat.messages: + if message.answer: + append('user', message.question) + append('assistant', message.answer) + if question: + append('user', question.question) + return oai_chat + + def tokens(self, data: Union[Message, Chat]) -> int: + raise NotImplementedError -- 2.36.6 From eb2fcba99d6918edf89b42ff8f2c171e49532c4a Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 5 Sep 2023 23:24:20 +0200 Subject: [PATCH 037/121] added new module 'ai_factory' --- chatmastermind/ai_factory.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 chatmastermind/ai_factory.py diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py new file mode 100644 index 0000000..c90366b --- /dev/null +++ b/chatmastermind/ai_factory.py @@ -0,0 +1,20 @@ +""" +Creates different AI instances, based on the given configuration. +""" + +import argparse +from .configuration import Config +from .ai import AI, AIError +from .ais.openai import OpenAI + + +def create_ai(args: argparse.Namespace, config: Config) -> AI: + """ + Creates an AI subclass instance from the given args and configuration. 
+ """ + if args.ai == 'openai': + # FIXME: create actual 'OpenAIConfig' and set values from 'args' + # FIXME: use actual name from config + return OpenAI("openai", config.openai) + else: + raise AIError(f"AI '{args.ai}' is not supported") -- 2.36.6 From ba5aa1fbc73013cee81c7bb27b0a970866b6bf25 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:35:53 +0200 Subject: [PATCH 038/121] cmm: added 'question' command --- chatmastermind/main.py | 103 +++++++++++++++++++++++++++++++++-------- tests/test_main.py | 18 +++---- 2 files changed, 93 insertions(+), 28 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 951d3cf..b10b97b 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -11,7 +11,9 @@ from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter, MessageError +from .message import Message, MessageFilter, MessageError, Question +from .ai_factory import create_ai +from .ai import AI, AIResponse from itertools import zip_longest from typing import Any @@ -30,12 +32,12 @@ def create_question_with_hist(args: argparse.Namespace, Creates the "AI request", including the question and chat history as determined by the specified tags. 
""" - tags = args.tags or [] - etags = args.etags or [] + tags = args.or_tags or [] + xtags = args.exclude_tags or [] otags = args.output_tags or [] if not args.source_code_only: - print_tag_args(tags, etags, otags) + print_tag_args(tags, xtags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +54,8 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, etags, config, - match_all_tags=True if args.atags else False, # FIXME + chat = create_chat_hist(full_question, tags, xtags, config, + match_all_tags=True if args.and_tags else False, # FIXME with_tags=False, with_file=False) return chat, full_question, tags @@ -85,6 +87,47 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: config.to_file(args.config) +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = Message(question=Question(args.question), + tags=args.ouput_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass + + def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. @@ -98,7 +141,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] - answers, usage = ai(chat, config, args.number) + answers, usage = ai(chat, config, args.num_answers) save_answers(question, answers, tags, otags, config) print("-" * terminal_width()) print(f"Usage: {usage}") @@ -109,9 +152,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'hist' command. 
""" - mfilter = MessageFilter(tags_or=args.tags, - tags_and=args.atags, - tags_not=args.etags, + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, question_contains=args.question, answer_contains=args.answer) chat = ChatDB.from_dir(Path('.'), @@ -139,7 +182,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-c', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -149,19 +192,41 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names (one must match)', metavar='TAGS') + tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', + help='List of tag names (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore - atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', - help='List of tag names (all must match)', metavar='TAGS') + atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', + help='List of tag names (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore - etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', - help='List of tag names to exclude', metavar='ETAGS') + etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', + help='List of tag names to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, 
default is input', metavar='OTAGS') + help='List of output tag names, default is input', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # 'question' command parser + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + help="ask, create and process questions.", + aliases=['q']) + question_cmd_parser.set_defaults(func=question_cmd) + question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) + question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') + question_group.add_argument('-c', '--create', nargs='+', help='Create a question') + question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') + question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') + question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', + action='store_true') + question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) + question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) + question_cmd_parser.add_argument('-A', '--AI', help='AI to use') + question_cmd_parser.add_argument('-M', '--model', help='Model to use') + question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, + default=1) + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', + action='store_true') + # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], help="Ask a question.", @@ -172,7 +237,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', 
type=float) ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, + ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', diff --git a/tests/test_main.py b/tests/test_main.py index bb9aa2a..ce9121a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -114,14 +114,14 @@ class TestHandleQuestion(CmmTestCase): def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( - tags=['tag1'], - atags=None, - etags=['etag1'], + or_tags=['tag1'], + and_tags=None, + exclude_tags=['xtag1'], output_tags=None, question=[self.question], source=None, source_code_only=False, - number=3, + num_answers=3, max_tokens=None, temperature=None, model=None, @@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.etags, + mock_print_tag_args.assert_called_once_with(self.args.or_tags, + self.args.exclude_tags, []) mock_create_chat_hist.assert_called_once_with(self.question, - self.args.tags, - self.args.etags, + self.args.or_tags, + self.args.exclude_tags, self.config, match_all_tags=False, with_tags=False, @@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase): self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, - self.args.number) + self.args.num_answers) expected_calls = [] for num, answer in enumerate(mock_ai.return_value[0], start=1): title = f'-- ANSWER {num} ' -- 2.36.6 From 893917e455f87b7059657a5d2f01854f096ac5bc Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 
Sep 2023 22:12:05 +0200 Subject: [PATCH 039/121] test_main: temporarily disabled all testcases --- tests/test_chat.py | 6 +- tests/test_main.py | 468 +++++++++++++++++++++--------------------- tests/test_message.py | 34 +-- tests/test_tags.py | 6 +- 4 files changed, 257 insertions(+), 257 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index d81a97a..8e4aa8c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,3 +1,4 @@ +import unittest import pathlib import tempfile import time @@ -6,10 +7,9 @@ from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError -from .test_main import CmmTestCase -class TestChat(CmmTestCase): +class TestChat(unittest.TestCase): def setUp(self) -> None: self.chat = Chat([]) self.message1 = Message(Question('Question 1'), @@ -131,7 +131,7 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) -class TestChatDB(CmmTestCase): +class TestChatDB(unittest.TestCase): def setUp(self) -> None: self.db_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory() diff --git a/tests/test_main.py b/tests/test_main.py index ce9121a..91e6462 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,236 +1,236 @@ -import unittest -import io -import pathlib -import argparse -from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, ask_cmd -from chatmastermind.api_client import ai -from chatmastermind.configuration import Config -from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from unittest import mock -from unittest.mock import patch, MagicMock, Mock, ANY +# import unittest +# import io +# import pathlib +# import argparse +# from chatmastermind.utils import terminal_width +# from chatmastermind.main import create_parser, ask_cmd +# from 
chatmastermind.api_client import ai +# from chatmastermind.configuration import Config +# from chatmastermind.storage import create_chat_hist, save_answers, dump_data +# from unittest import mock +# from unittest.mock import patch, MagicMock, Mock, ANY -class CmmTestCase(unittest.TestCase): - """ - Base class for all cmm testcases. - """ - def dummy_config(self, db: str) -> Config: - """ - Creates a dummy configuration. - """ - return Config.from_dict( - {'system': 'dummy_system', - 'db': db, - 'openai': {'api_key': 'dummy_key', - 'model': 'dummy_model', - 'max_tokens': 4000, - 'temperature': 1.0, - 'top_p': 1, - 'frequency_penalty': 0, - 'presence_penalty': 0}} - ) - - -class TestCreateChat(CmmTestCase): - - def setUp(self) -> None: - self.config = self.dummy_config(db='test_files') - self.question = "test question" - self.tags = ['test_tag'] - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['test_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 4) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> 
None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['other_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 2) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.side_effect = ( - io.StringIO(dump_data({'question': 'test_content', - 'answer': 'some answer', - 'tags': ['test_tag']})), - io.StringIO(dump_data({'question': 'test_content2', - 'answer': 'some answer2', - 'tags': ['test_tag2']})), - ) - - test_chat = create_chat_hist(self.question, [], None, self.config) - - self.assertEqual(len(test_chat), 6) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': 'test_content2'}) - self.assertEqual(test_chat[4], - {'role': 'assistant', 'content': 'some answer2'}) - - -class TestHandleQuestion(CmmTestCase): - - def setUp(self) -> None: - self.question = "test question" - self.args = argparse.Namespace( - or_tags=['tag1'], - and_tags=None, - exclude_tags=['xtag1'], - output_tags=None, - question=[self.question], - source=None, - source_code_only=False, - 
num_answers=3, - max_tokens=None, - temperature=None, - model=None, - match_all_tags=False, - with_tags=False, - with_file=False, - ) - self.config = self.dummy_config(db='test_files') - - @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.main.print_chat_hist") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) - @patch("chatmastermind.utils.pp") - @patch("builtins.print") - def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, - mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, - mock_create_chat_hist: MagicMock) -> None: - open_mock = MagicMock() - with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.or_tags, - self.args.exclude_tags, - []) - mock_create_chat_hist.assert_called_once_with(self.question, - self.args.or_tags, - self.args.exclude_tags, - self.config, - match_all_tags=False, - with_tags=False, - with_file=False) - mock_print_chat_hist.assert_called_once_with('test_chat', - False, - self.args.source_code_only) - mock_ai.assert_called_with("test_chat", - self.config, - self.args.num_answers) - expected_calls = [] - for num, answer in enumerate(mock_ai.return_value[0], start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width() - len(title)) - expected_calls.append(((f'{title}{title_end}',),)) - expected_calls.append(((answer,),)) - expected_calls.append((("-" * terminal_width(),),)) - expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) - open_mock.assert_has_calls(open_expected_calls, any_order=True) - - -class TestSaveAnswers(CmmTestCase): - @mock.patch('builtins.open') - 
@mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: - question = "Test question?" - answers = ["Answer 1", "Answer 2"] - tags = ["tag1", "tag2"] - otags = ["otag1", "otag2"] - config = self.dummy_config(db='test_db') - - with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ - mock.patch('chatmastermind.storage.yaml.dump'), \ - mock.patch('io.StringIO') as stringio_mock: - stringio_instance = stringio_mock.return_value - stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] - save_answers(question, answers, tags, otags, config) - - open_calls = [ - mock.call(pathlib.Path('test_db/.next'), 'r'), - mock.call(pathlib.Path('test_db/.next'), 'w'), - ] - open_mock.assert_has_calls(open_calls, any_order=True) - - -class TestAI(CmmTestCase): - - @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock) -> None: - mock_create.return_value = { - 'choices': [ - {'message': {'content': 'response_text_1'}}, - {'message': {'content': 'response_text_2'}} - ], - 'usage': {'tokens': 10} - } - - chat = [{"role": "system", "content": "hello ai"}] - config = self.dummy_config(db='dummy') - config.openai.model = "text-davinci-002" - config.openai.max_tokens = 150 - config.openai.temperature = 0.5 - - result = ai(chat, config, 2) - expected_result = (['response_text_1', 'response_text_2'], - {'tokens': 10}) - self.assertEqual(result, expected_result) - - -class TestCreateParser(CmmTestCase): - def test_create_parser(self) -> None: - with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: - mock_cmdparser = Mock() - mock_add_subparsers.return_value = mock_cmdparser - parser = create_parser() - self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - 
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) - self.assertTrue('.config.yaml' in parser.get_default('config')) +# class CmmTestCase(unittest.TestCase): +# """ +# Base class for all cmm testcases. +# """ +# def dummy_config(self, db: str) -> Config: +# """ +# Creates a dummy configuration. +# """ +# return Config.from_dict( +# {'system': 'dummy_system', +# 'db': db, +# 'openai': {'api_key': 'dummy_key', +# 'model': 'dummy_model', +# 'max_tokens': 4000, +# 'temperature': 1.0, +# 'top_p': 1, +# 'frequency_penalty': 0, +# 'presence_penalty': 0}} +# ) +# +# +# class TestCreateChat(CmmTestCase): +# +# def setUp(self) -> None: +# self.config = self.dummy_config(db='test_files') +# self.question = "test question" +# self.tags = ['test_tag'] +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['test_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 4) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# 
self.assertEqual(test_chat[3], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['other_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 2) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.side_effect = ( +# io.StringIO(dump_data({'question': 'test_content', +# 'answer': 'some answer', +# 'tags': ['test_tag']})), +# io.StringIO(dump_data({'question': 'test_content2', +# 'answer': 'some answer2', +# 'tags': ['test_tag2']})), +# ) +# +# test_chat = create_chat_hist(self.question, [], None, self.config) +# +# self.assertEqual(len(test_chat), 6) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': 'test_content2'}) +# self.assertEqual(test_chat[4], +# {'role': 
'assistant', 'content': 'some answer2'}) +# +# +# class TestHandleQuestion(CmmTestCase): +# +# def setUp(self) -> None: +# self.question = "test question" +# self.args = argparse.Namespace( +# or_tags=['tag1'], +# and_tags=None, +# exclude_tags=['xtag1'], +# output_tags=None, +# question=[self.question], +# source=None, +# source_code_only=False, +# num_answers=3, +# max_tokens=None, +# temperature=None, +# model=None, +# match_all_tags=False, +# with_tags=False, +# with_file=False, +# ) +# self.config = self.dummy_config(db='test_files') +# +# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") +# @patch("chatmastermind.main.print_tag_args") +# @patch("chatmastermind.main.print_chat_hist") +# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) +# @patch("chatmastermind.utils.pp") +# @patch("builtins.print") +# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, +# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, +# mock_create_chat_hist: MagicMock) -> None: +# open_mock = MagicMock() +# with patch("chatmastermind.storage.open", open_mock): +# ask_cmd(self.args, self.config) +# mock_print_tag_args.assert_called_once_with(self.args.or_tags, +# self.args.exclude_tags, +# []) +# mock_create_chat_hist.assert_called_once_with(self.question, +# self.args.or_tags, +# self.args.exclude_tags, +# self.config, +# match_all_tags=False, +# with_tags=False, +# with_file=False) +# mock_print_chat_hist.assert_called_once_with('test_chat', +# False, +# self.args.source_code_only) +# mock_ai.assert_called_with("test_chat", +# self.config, +# self.args.num_answers) +# expected_calls = [] +# for num, answer in enumerate(mock_ai.return_value[0], start=1): +# title = f'-- ANSWER {num} ' +# title_end = '-' * (terminal_width() - len(title)) +# expected_calls.append(((f'{title}{title_end}',),)) +# expected_calls.append(((answer,),)) +# expected_calls.append((("-" * 
terminal_width(),),)) +# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) +# self.assertEqual(mock_print.call_args_list, expected_calls) +# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) +# open_mock.assert_has_calls(open_expected_calls, any_order=True) +# +# +# class TestSaveAnswers(CmmTestCase): +# @mock.patch('builtins.open') +# @mock.patch('chatmastermind.storage.print') +# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: +# question = "Test question?" +# answers = ["Answer 1", "Answer 2"] +# tags = ["tag1", "tag2"] +# otags = ["otag1", "otag2"] +# config = self.dummy_config(db='test_db') +# +# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ +# mock.patch('chatmastermind.storage.yaml.dump'), \ +# mock.patch('io.StringIO') as stringio_mock: +# stringio_instance = stringio_mock.return_value +# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] +# save_answers(question, answers, tags, otags, config) +# +# open_calls = [ +# mock.call(pathlib.Path('test_db/.next'), 'r'), +# mock.call(pathlib.Path('test_db/.next'), 'w'), +# ] +# open_mock.assert_has_calls(open_calls, any_order=True) +# +# +# class TestAI(CmmTestCase): +# +# @patch("openai.ChatCompletion.create") +# def test_ai(self, mock_create: MagicMock) -> None: +# mock_create.return_value = { +# 'choices': [ +# {'message': {'content': 'response_text_1'}}, +# {'message': {'content': 'response_text_2'}} +# ], +# 'usage': {'tokens': 10} +# } +# +# chat = [{"role": "system", "content": "hello ai"}] +# config = self.dummy_config(db='dummy') +# config.openai.model = "text-davinci-002" +# config.openai.max_tokens = 150 +# config.openai.temperature = 0.5 +# +# result = ai(chat, config, 2) +# expected_result = (['response_text_1', 'response_text_2'], +# {'tokens': 10}) +# self.assertEqual(result, expected_result) +# +# +# class TestCreateParser(CmmTestCase): +# def 
test_create_parser(self) -> None: +# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: +# mock_cmdparser = Mock() +# mock_add_subparsers.return_value = mock_cmdparser +# parser = create_parser() +# self.assertIsInstance(parser, argparse.ArgumentParser) +# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) +# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) +# self.assertTrue('.config.yaml' in parser.get_default('config')) diff --git a/tests/test_message.py b/tests/test_message.py index a49c893..57d5982 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,12 +1,12 @@ +import unittest import pathlib import tempfile from typing import cast -from .test_main import CmmTestCase from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine -class SourceCodeTestCase(CmmTestCase): +class SourceCodeTestCase(unittest.TestCase): def test_source_code_with_include_delims(self) -> None: text = """ Some text before the code block @@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase): self.assertEqual(result, expected_result) -class QuestionTestCase(CmmTestCase): +class QuestionTestCase(unittest.TestCase): def test_question_with_header(self) -> None: with self.assertRaises(MessageError): Question(f"{Question.txt_header}\nWhat is your name?") @@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase): self.assertEqual(question, "What is your favorite color?") -class AnswerTestCase(CmmTestCase): +class 
AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): Answer(f"{Answer.txt_header}\nno") @@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase): self.assertEqual(answer, "No") -class MessageToFileTxtTestCase(CmmTestCase): +class MessageToFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -160,7 +160,7 @@ This is a question. self.message_complete.file_path = self.file_path -class MessageToFileYamlTestCase(CmmTestCase): +class MessageToFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase): self.assertEqual(content, expected_content) -class MessageFromFileTxtTestCase(CmmTestCase): +class MessageFromFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -388,7 +388,7 @@ This is a question. self.assertIsNone(message) -class MessageFromFileYamlTestCase(CmmTestCase): +class MessageFromFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase): self.assertIsNone(message) -class TagsFromFileTestCase(CmmTestCase): +class TagsFromFileTestCase(unittest.TestCase): def setUp(self) -> None: self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) @@ -663,7 +663,7 @@ This is an answer. 
self.assertSetEqual(tags, set()) -class TagsFromDirTestCase(CmmTestCase): +class TagsFromDirTestCase(unittest.TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory() @@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase): self.assertSetEqual(all_tags, set()) -class MessageIDTestCase(CmmTestCase): +class MessageIDTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase): self.message_no_file_path.msg_id() -class MessageHashTestCase(CmmTestCase): +class MessageHashTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase): self.assertIn(msg, msgs) -class MessageTagsStrTestCase(CmmTestCase): +class MessageTagsStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase): self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') -class MessageFilterTagsTestCase(CmmTestCase): +class MessageFilterTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) -class MessageInTestCase(CmmTestCase): +class MessageInTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase): self.assertFalse(message_in(self.message1, [self.message2])) -class MessageRenameTagsTestCase(CmmTestCase): +class 
MessageRenameTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase): self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -class MessageToStrTestCase(CmmTestCase): +class MessageToStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), Answer('This is an answer.'), diff --git a/tests/test_tags.py b/tests/test_tags.py index aa89a06..edd3c05 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -1,8 +1,8 @@ -from .test_main import CmmTestCase +import unittest from chatmastermind.tags import Tag, TagLine, TagError -class TestTag(CmmTestCase): +class TestTag(unittest.TestCase): def test_valid_tag(self) -> None: tag = Tag('mytag') self.assertEqual(tag, 'mytag') @@ -18,7 +18,7 @@ class TestTag(CmmTestCase): self.assertEqual(Tag.alternative_separators, [',']) -class TestTagLine(CmmTestCase): +class TestTagLine(unittest.TestCase): def test_valid_tagline(self) -> None: tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') -- 2.36.6 From 74a26b8c2f42ec59916e3e744c9db23d40ee6fa4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:23:29 +0200 Subject: [PATCH 040/121] setup: added 'ais' subfolder --- chatmastermind/ais/__init__.py | 0 setup.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 chatmastermind/ais/__init__.py diff --git a/chatmastermind/ais/__init__.py b/chatmastermind/ais/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 02d9ab1..8484629 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages(), + packages=find_packages() + 
["chatmastermind.ais"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,7 +32,7 @@ setup( "openai", "PyYAML", "argcomplete", - "pytest" + "pytest", ], python_requires=">=3.9", test_suite="tests", -- 2.36.6 From 2df9dd64274a9f0c3214281d7e032ea8e131432a Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:43:23 +0200 Subject: [PATCH 041/121] cmm: removed all the old code and modules --- chatmastermind/api_client.py | 45 ------- chatmastermind/main.py | 104 ++------------- chatmastermind/storage.py | 121 ------------------ chatmastermind/utils.py | 80 ------------ tests/test_main.py | 236 ----------------------------------- 5 files changed, 12 insertions(+), 574 deletions(-) delete mode 100644 chatmastermind/api_client.py delete mode 100644 chatmastermind/storage.py delete mode 100644 chatmastermind/utils.py delete mode 100644 tests/test_main.py diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. 
- """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index b10b97b..857bb5a 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,61 +6,19 @@ import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType -from .storage import save_answers, create_chat_hist -from .api_client import ai, openai_api_key, print_models -from .configuration import Config +from .configuration import Config, default_config_path from .chat import ChatDB from .message import Message, MessageFilter, MessageError, Question from .ai_factory import create_ai from .ai import AI, AIResponse -from itertools import zip_longest from typing import Any -default_config = '.config.yaml' - def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) -def create_question_with_hist(args: argparse.Namespace, - config: Config, - ) -> tuple[ChatType, str, list[str]]: - """ - Creates the "AI request", including the question and chat history as determined - by the specified tags. 
- """ - tags = args.or_tags or [] - xtags = args.exclude_tags or [] - otags = args.output_tags or [] - - if not args.source_code_only: - print_tag_args(tags, xtags, otags) - - question_parts = [] - question_list = args.question if args.question is not None else [] - source_list = args.source if args.source is not None else [] - - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: - question_parts.append(question) - elif source is not None: - with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") - - full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, xtags, config, - match_all_tags=True if args.and_tags else False, # FIXME - with_tags=False, - with_file=False) - return chat, full_question, tags - - def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. @@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: # TODO: add renaming -def config_cmd(args: argparse.Namespace, config: Config) -> None: +def config_cmd(args: argparse.Namespace) -> None: """ Handler for the 'config' command. 
""" - if args.list_models: - print_models() - elif args.print_model: - print(config.openai.model) - elif args.model: - config.openai.model = args.model - config.to_file(args.config) + if args.create: + Config.create_default(Path(args.create)) def question_cmd(args: argparse.Namespace, config: Config) -> None: @@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: db_path=Path(config.db)) # if it's a new question, create and store it immediately if args.ask or args.create: + # FIXME: add sources to the question message = Message(question=Question(args.question), tags=args.ouput_tags, # FIXME ai=args.ai, @@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: pass -def ask_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'ask' command. - """ - if args.max_tokens: - config.openai.max_tokens = args.max_tokens - if args.temperature: - config.openai.temperature = args.temperature - if args.model: - config.openai.model = args.model - chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.source_code_only) - otags = args.output_tags or [] - answers, usage = ai(chat, config, args.num_answers) - save_answers(question, answers, tags, otags, config) - print("-" * terminal_width()) - print(f"Usage: {usage}") - - def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. 
@@ -182,7 +117,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-C', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -227,22 +162,6 @@ def create_parser() -> argparse.ArgumentParser: question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') - # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", @@ -278,7 +197,7 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config 
file") + config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', @@ -297,11 +216,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py deleted file mode 100644 index 8b9ed97..0000000 --- a/chatmastermind/storage.py +++ /dev/null @@ -1,121 +0,0 @@ -import yaml -import io -import pathlib -from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config -from typing import Any, Optional - - -def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: - with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() - # also support tags separated by ',' (old format) - separator = ',' if ',' in tagline else ' ' - tags = [t.strip() for t in tagline.split(separator)] - if tags_only: - return {"tags": tags} - text = fd.read().strip().split('\n') - question_idx = text.index("=== QUESTION ===") + 1 - answer_idx = text.index("==== ANSWER ====") - question = "\n".join(text[question_idx:answer_idx]).strip() - answer = "\n".join(text[answer_idx + 1:]).strip() - return {"question": question, "answer": answer, "tags": tags, - "file": fname.name} - - -def dump_data(data: dict[str, Any]) -> str: - with io.StringIO() as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: 
- fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == 
'.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index e6eeb97..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,80 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. 
- """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_end = content.rindex('```') - if content_start + 3 < content_end: - print(content[content_start + 3:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 91e6462..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,236 +0,0 @@ -# import unittest -# import io -# import pathlib -# import argparse -# from chatmastermind.utils import terminal_width -# from chatmastermind.main import create_parser, ask_cmd -# from 
chatmastermind.api_client import ai -# from chatmastermind.configuration import Config -# from chatmastermind.storage import create_chat_hist, save_answers, dump_data -# from unittest import mock -# from unittest.mock import patch, MagicMock, Mock, ANY - - -# class CmmTestCase(unittest.TestCase): -# """ -# Base class for all cmm testcases. -# """ -# def dummy_config(self, db: str) -> Config: -# """ -# Creates a dummy configuration. -# """ -# return Config.from_dict( -# {'system': 'dummy_system', -# 'db': db, -# 'openai': {'api_key': 'dummy_key', -# 'model': 'dummy_model', -# 'max_tokens': 4000, -# 'temperature': 1.0, -# 'top_p': 1, -# 'frequency_penalty': 0, -# 'presence_penalty': 0}} -# ) -# -# -# class TestCreateChat(CmmTestCase): -# -# def setUp(self) -> None: -# self.config = self.dummy_config(db='test_files') -# self.question = "test question" -# self.tags = ['test_tag'] -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['test_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 4) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_other_tags(self, open_mock: 
MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['other_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 2) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.side_effect = ( -# io.StringIO(dump_data({'question': 'test_content', -# 'answer': 'some answer', -# 'tags': ['test_tag']})), -# io.StringIO(dump_data({'question': 'test_content2', -# 'answer': 'some answer2', -# 'tags': ['test_tag2']})), -# ) -# -# test_chat = create_chat_hist(self.question, [], None, self.config) -# -# self.assertEqual(len(test_chat), 6) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': 'test_content2'}) -# self.assertEqual(test_chat[4], -# {'role': 'assistant', 'content': 'some answer2'}) -# -# -# class TestHandleQuestion(CmmTestCase): -# -# def setUp(self) -> None: -# self.question = "test question" -# self.args = argparse.Namespace( -# or_tags=['tag1'], -# and_tags=None, 
-# exclude_tags=['xtag1'], -# output_tags=None, -# question=[self.question], -# source=None, -# source_code_only=False, -# num_answers=3, -# max_tokens=None, -# temperature=None, -# model=None, -# match_all_tags=False, -# with_tags=False, -# with_file=False, -# ) -# self.config = self.dummy_config(db='test_files') -# -# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") -# @patch("chatmastermind.main.print_tag_args") -# @patch("chatmastermind.main.print_chat_hist") -# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) -# @patch("chatmastermind.utils.pp") -# @patch("builtins.print") -# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, -# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, -# mock_create_chat_hist: MagicMock) -> None: -# open_mock = MagicMock() -# with patch("chatmastermind.storage.open", open_mock): -# ask_cmd(self.args, self.config) -# mock_print_tag_args.assert_called_once_with(self.args.or_tags, -# self.args.exclude_tags, -# []) -# mock_create_chat_hist.assert_called_once_with(self.question, -# self.args.or_tags, -# self.args.exclude_tags, -# self.config, -# match_all_tags=False, -# with_tags=False, -# with_file=False) -# mock_print_chat_hist.assert_called_once_with('test_chat', -# False, -# self.args.source_code_only) -# mock_ai.assert_called_with("test_chat", -# self.config, -# self.args.num_answers) -# expected_calls = [] -# for num, answer in enumerate(mock_ai.return_value[0], start=1): -# title = f'-- ANSWER {num} ' -# title_end = '-' * (terminal_width() - len(title)) -# expected_calls.append(((f'{title}{title_end}',),)) -# expected_calls.append(((answer,),)) -# expected_calls.append((("-" * terminal_width(),),)) -# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) -# self.assertEqual(mock_print.call_args_list, expected_calls) -# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 
5)]) -# open_mock.assert_has_calls(open_expected_calls, any_order=True) -# -# -# class TestSaveAnswers(CmmTestCase): -# @mock.patch('builtins.open') -# @mock.patch('chatmastermind.storage.print') -# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: -# question = "Test question?" -# answers = ["Answer 1", "Answer 2"] -# tags = ["tag1", "tag2"] -# otags = ["otag1", "otag2"] -# config = self.dummy_config(db='test_db') -# -# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ -# mock.patch('chatmastermind.storage.yaml.dump'), \ -# mock.patch('io.StringIO') as stringio_mock: -# stringio_instance = stringio_mock.return_value -# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] -# save_answers(question, answers, tags, otags, config) -# -# open_calls = [ -# mock.call(pathlib.Path('test_db/.next'), 'r'), -# mock.call(pathlib.Path('test_db/.next'), 'w'), -# ] -# open_mock.assert_has_calls(open_calls, any_order=True) -# -# -# class TestAI(CmmTestCase): -# -# @patch("openai.ChatCompletion.create") -# def test_ai(self, mock_create: MagicMock) -> None: -# mock_create.return_value = { -# 'choices': [ -# {'message': {'content': 'response_text_1'}}, -# {'message': {'content': 'response_text_2'}} -# ], -# 'usage': {'tokens': 10} -# } -# -# chat = [{"role": "system", "content": "hello ai"}] -# config = self.dummy_config(db='dummy') -# config.openai.model = "text-davinci-002" -# config.openai.max_tokens = 150 -# config.openai.temperature = 0.5 -# -# result = ai(chat, config, 2) -# expected_result = (['response_text_1', 'response_text_2'], -# {'tokens': 10}) -# self.assertEqual(result, expected_result) -# -# -# class TestCreateParser(CmmTestCase): -# def test_create_parser(self) -> None: -# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: -# mock_cmdparser = Mock() -# mock_add_subparsers.return_value = mock_cmdparser -# parser = create_parser() -# 
self.assertIsInstance(parser, argparse.ArgumentParser) -# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) -# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) -# self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From b1a23394fc741f5038a4ac0b9d6772448d077f9d Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 13:31:01 +0200 Subject: [PATCH 042/121] cmm: splitted commands into separate modules (and more cleanup) --- chatmastermind/commands/config.py | 11 +++ chatmastermind/commands/hist.py | 23 +++++ chatmastermind/commands/print.py | 19 ++++ chatmastermind/commands/question.py | 57 ++++++++++++ chatmastermind/commands/tags.py | 17 ++++ chatmastermind/main.py | 131 +++++----------------------- setup.py | 2 +- tests/test_ai_factory.py | 48 ++++++++++ 8 files changed, 196 insertions(+), 112 deletions(-) create mode 100644 chatmastermind/commands/config.py create mode 100644 chatmastermind/commands/hist.py create mode 100644 chatmastermind/commands/print.py create mode 100644 chatmastermind/commands/question.py create mode 100644 chatmastermind/commands/tags.py create mode 100644 tests/test_ai_factory.py diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py new file mode 100644 index 0000000..262164c --- /dev/null +++ b/chatmastermind/commands/config.py @@ -0,0 +1,11 @@ +import argparse +from pathlib import Path +from ..configuration import Config + + +def config_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'config' command. 
+ """ + if args.create: + Config.create_default(Path(args.create)) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py new file mode 100644 index 0000000..88ed3be --- /dev/null +++ b/chatmastermind/commands/hist.py @@ -0,0 +1,23 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import MessageFilter + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. + """ + + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py new file mode 100644 index 0000000..51e76f8 --- /dev/null +++ b/chatmastermind/commands/print.py @@ -0,0 +1,19 @@ +import sys +import argparse +from pathlib import Path +from ..configuration import Config +from ..message import Message, MessageError + + +def print_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'print' command. 
+ """ + fname = Path(args.file) + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") + sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py new file mode 100644 index 0000000..9c56ced --- /dev/null +++ b/chatmastermind/commands/question.py @@ -0,0 +1,57 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import Message, Question +from ..ai_factory import create_ai +from ..ai import AI, AIResponse + + +def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: + """ + Creates (and writes) a new message from the given arguments. + """ + # FIXME: add sources to the question + message = Message(question=Question(args.question), + tags=args.output_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + return message + + +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = create_message(chat, args) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py new file mode 100644 index 0000000..2906a5b --- /dev/null +++ b/chatmastermind/commands/tags.py @@ -0,0 +1,17 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB + + +def tags_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'tags' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + if args.list: + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 857bb5a..88121b4 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,12 +6,14 @@ import sys import argcomplete import argparse from pathlib import Path -from .configuration import Config, default_config_path -from .chat import ChatDB -from .message import Message, MessageFilter, MessageError, Question -from .ai_factory import create_ai -from .ai import AI, AIResponse from typing import Any +from .configuration import Config, default_config_path +from .message import Message +from .commands.question import question_cmd +from .commands.tags import tags_cmd +from .commands.config import config_cmd +from .commands.hist import hist_cmd +from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: @@ -19,101 +21,6 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) -def tags_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'tags' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - if args.list: - tags_freq = chat.tags_frequency(args.prefix, args.contain) - for tag, freq in tags_freq.items(): - print(f"- {tag}: {freq}") - # TODO: add renaming - - -def config_cmd(args: argparse.Namespace) -> None: - """ - Handler for the 'config' command. - """ - if args.create: - Config.create_default(Path(args.create)) - - -def question_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'question' command. 
- """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - # if it's a new question, create and store it immediately - if args.ask or args.create: - # FIXME: add sources to the question - message = Message(question=Question(args.question), - tags=args.ouput_tags, # FIXME - ai=args.ai, - model=args.model) - chat.add_to_cache([message]) - if args.create: - return - - # create the correct AI instance - ai: AI = create_ai(args, config) - if args.ask: - response: AIResponse = ai.request(message, - chat, - args.num_answers, # FIXME - args.otags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass - elif args.repeat: - lmessage = chat.latest_message() - assert lmessage - # TODO: repeat either the last question or the - # one(s) given in 'args.repeat' (overwrite - # existing ones if 'args.overwrite' is True) - pass - elif args.process: - # TODO: process either all questions without an - # answer or the one(s) given in 'args.process' - pass - - -def hist_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'hist' command. - """ - - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags, - question_contains=args.question, - answer_contains=args.answer) - chat = ChatDB.from_dir(Path('.'), - Path(config.db), - mfilter=mfilter) - chat.print(args.source_code_only, - args.with_tags, - args.with_files) - - -def print_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'print' command. 
- """ - fname = Path(args.file) - try: - message = Message.from_file(fname) - if message: - print(message.to_str(source_code_only=args.source_code_only)) - except MessageError: - print(f"File is not a valid message: {args.file}") - sys.exit(1) - - def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") @@ -128,20 +35,28 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', - help='List of tag names (one must match)', metavar='OTAGS') + help='List of tags (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', - help='List of tag names (all must match)', metavar='ATAGS') + help='List of tags (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', - help='List of tag names to exclude', metavar='XTAGS') + help='List of tags to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OUTTAGS') + help='List of output tags (default: use input tags)', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # a parent parser for all commands that support AI configuration + ai_parser = argparse.ArgumentParser(add_help=False) + ai_parser.add_argument('-A', '--AI', help='AI ID to use') + ai_parser.add_argument('-M', '--model', help='Model to use') + ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) + ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. 
of tokens', type=int) + ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) + # 'question' command parser - question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], help="ask, create and process questions.", aliases=['q']) question_cmd_parser.set_defaults(func=question_cmd) @@ -152,12 +67,6 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - question_cmd_parser.add_argument('-A', '--AI', help='AI to use') - question_cmd_parser.add_argument('-M', '--model', help='Model to use') - question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') diff --git a/setup.py b/setup.py index 8484629..a311605 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages() + ["chatmastermind.ais"], + packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py new file mode 100644 index 0000000..d63970e --- /dev/null +++ 
b/tests/test_ai_factory.py @@ -0,0 +1,48 @@ +import argparse +import unittest +from unittest.mock import MagicMock +from chatmastermind.ai_factory import create_ai +from chatmastermind.configuration import Config +from chatmastermind.ai import AIError +from chatmastermind.ais.openai import OpenAI + + +class TestCreateAI(unittest.TestCase): + def setUp(self) -> None: + self.args = MagicMock(spec=argparse.Namespace) + self.args.ai = 'default' + self.args.model = None + self.args.max_tokens = None + self.args.temperature = None + + def test_create_ai_from_args(self) -> None: + # Create an AI with the default configuration + config = Config() + self.args.ai = 'default' + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_ai_from_default(self) -> None: + self.args.ai = None + # Create an AI with the default configuration + config = Config() + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_empty_ai_error(self) -> None: + self.args.ai = None + # Create Config with empty AIs + config = Config() + config.ais = {} + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) + + def test_create_unsupported_ai_error(self) -> None: + # Mock argparse.Namespace with ai='invalid_ai' + self.args.ai = 'invalid_ai' + # Create default Config + config = Config() + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) -- 2.36.6 From eaa399bcb90fa90ec5c2e6bc6e1fb100f8ade338 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:52:03 +0200 Subject: [PATCH 043/121] configuration et al: implemented new Config format --- chatmastermind/ai.py | 13 ++-- chatmastermind/ai_factory.py | 29 ++++++-- chatmastermind/ais/openai.py | 9 +-- chatmastermind/configuration.py | 119 ++++++++++++++++++++++++++------ 4 files changed, 134 insertions(+), 36 deletions(-) diff --git 
a/chatmastermind/ai.py b/chatmastermind/ai.py index 4a8b914..e94de8e 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -33,18 +33,23 @@ class AI(Protocol): The base class for AI clients. """ + ID: str name: str config: AIConfig def request(self, question: Message, - context: Chat, + chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ - Make an AI request, asking the given question with the given - context (i. e. chat history). The nr. of requested answers - corresponds to the nr. of messages in the 'AIResponse'. + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain """ raise NotImplementedError diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c90366b..c4a063a 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,35 @@ Creates different AI instances, based on the given configuration. """ import argparse -from .configuration import Config +from typing import cast +from .configuration import Config, OpenAIConfig, default_ai_ID from .ai import AI, AIError from .ais.openai import OpenAI def create_ai(args: argparse.Namespace, config: Config) -> AI: """ - Creates an AI subclass instance from the given args and configuration. + Creates an AI subclass instance from the given arguments + and configuration file. 
""" - if args.ai == 'openai': - # FIXME: create actual 'OpenAIConfig' and set values from 'args' - # FIXME: use actual name from config - return OpenAI("openai", config.openai) + if args.ai: + try: + ai_conf = config.ais[args.ai] + except KeyError: + raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + elif default_ai_ID in config.ais: + ai_conf = config.ais[default_ai_ID] + else: + raise AIError("No AI name given and no default exists") + + if ai_conf.name == 'openai': + ai = OpenAI(cast(OpenAIConfig, ai_conf)) + if args.model: + ai.config.model = args.model + if args.max_tokens: + ai.config.max_tokens = args.max_tokens + if args.temperature: + ai.config.temperature = args.temperature + return ai else: raise AIError(f"AI '{args.ai}' is not supported") diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 74438b8..14ce33f 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -17,9 +17,11 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name self.config = config + openai.api_key = config.api_key def request(self, question: Message, @@ -31,8 +33,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. 
""" - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..d82f913 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,17 +1,40 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_ai_ID: str = 'default' +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ - name: str + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) @dataclass @@ -19,21 +42,27 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. 
""" - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'default' + api_key: str = '0123456789' + system: str = 'You are an assistant' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( - name='OpenAI', + res = cls( + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -42,6 +71,30 @@ class OpenAIConfig(AIConfig): frequency_penalty=float(source['frequency_penalty']), presence_penalty=float(source['presence_penalty']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI '{name}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass @@ -49,30 +102,52 @@ class Config: """ The configuration file structure. 
""" - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create Config from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: - return asdict(self) + res = asdict(self) + for ID, conf in res['ais'].items(): + conf.update({'name': self.ais[ID].name}) + return res -- 2.36.6 From 76f23733972f6add31353ca987303e590c3e5b76 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 10:40:22 +0200 Subject: [PATCH 044/121] configuration: added tests --- chatmastermind/configuration.py | 2 +- tests/test_configuration.py | 160 
++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_configuration.py diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d82f913..398fa03 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -87,7 +87,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> else: return OpenAIConfig.from_dict(conf_dict) else: - raise ConfigError(f"AI '{name}' is not supported") + raise ConfigError(f"Unknown AI '{name}'") def create_default_ai_configs() -> dict[str, AIConfig]: diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..f3f9a98 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def 
test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['default'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['default'].ID, 'default') + + def test_create_default_should_create_default_config(self) -> None: + 
Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'default': OpenAIConfig( + ID='default', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, + presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + 
with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name) -- 2.36.6 From c0b7d17587f45a4b28a8083cd067b7b427816627 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:51:17 +0200 Subject: [PATCH 045/121] question_cmd: fixes --- chatmastermind/commands/question.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 9c56ced..1709a3c 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path +from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB from ..message import Message, Question @@ -11,8 +12,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. """ - # FIXME: add sources to the question - message = Message(question=Question(args.question), + question_parts = [] + question_list = args.question if args.question is not None else [] + source_list = args.source if args.source is not None else [] + + # FIXME: don't surround all sourced files with ``` + # -> do it only if '--source-code-only' is True and no source code + # could be extracted from that file + for question, source in zip_longest(question_list, source_list, fillvalue=None): + if question is not None and source is not None: + with open(source) as r: + question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") + elif question is not None: + question_parts.append(question) + elif source is not None: + with open(source) as r: + question_parts.append(f"```\n{r.read().strip()}\n```") + + full_question = '\n\n'.join(question_parts) + + message = Message(question=Question(full_question), tags=args.output_tags, # FIXME ai=args.ai, model=args.model) -- 2.36.6 From 
5fb5dde550539d0612d9ce3b9ae223bcebdef6a2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:31:30 +0200 Subject: [PATCH 046/121] question cmd: added tests --- tests/test_question_cmd.py | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_question_cmd.py diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..96b2fdf --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,111 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source = None + self.args.source_code_only = False + self.args.ai = None + self.args.model = None + self.args.output_tags = None + # create some files for sourcing + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. 
+``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + + def tearDown(self) -> None: + os.remove(self.source_file1.name) + os.remove(self.source_file2.name) + os.remove(self.source_file3.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.question = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.question = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? 
+ +Is it good?""")) + + def test_single_question_with_text_only_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains no source code + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file2_content}""")) -- 2.36.6 From 7cf62c54efd0a6ca6a23eacbd7a7bd716962f97f Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:16:17 +0200 Subject: [PATCH 047/121] Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. 
--- chatmastermind/commands/question.py | 22 ++++++++++++---------- chatmastermind/main.py | 5 ++--- tests/test_question_cmd.py | 22 ++++++++++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 1709a3c..818b1de 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -15,19 +15,21 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: question_parts = [] question_list = args.question if args.question is not None else [] source_list = args.source if args.source is not None else [] + code_list = args.source_code if args.source_code is not None else [] - # FIXME: don't surround all sourced files with ``` - # -> do it only if '--source-code-only' is True and no source code - # could be extracted from that file - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: + for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + if question is not None and len(question.strip()) > 0: question_parts.append(question) - elif source is not None: + if source is not None and len(source) > 0: with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 88121b4..f7163ab 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,9 +67,8 @@ def create_parser() -> argparse.ArgumentParser: 
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 96b2fdf..06cc527 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -23,7 +23,7 @@ class TestMessageCreate(unittest.TestCase): # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source = None - self.args.source_code_only = False + self.args.source_code = None self.args.ai = None self.args.model = None self.args.output_tags = None @@ -94,11 +94,11 @@ Is it good?""")) # source file contains no source code # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? 
{self.source_file1_content}""")) - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_embedded_source_source(self) -> None: self.args.question = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) @@ -106,6 +106,20 @@ Is it good?""")) # source file contains 1 source code block # -> expect it in the question self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file2_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question(f"""What is this? + +``` +{self.source_file2_content} +```""")) -- 2.36.6 From d22877a0f1a206ed7697fe4d773ef576bbf30aa3 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:38:40 +0200 Subject: [PATCH 048/121] Port print arguments -q/-a/-S from main to restructuring. 
--- chatmastermind/commands/print.py | 10 +++++++++- chatmastermind/main.py | 6 ++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py index 51e76f8..3d2b990 100644 --- a/chatmastermind/commands/print.py +++ b/chatmastermind/commands/print.py @@ -13,7 +13,15 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: try: message = Message.from_file(fname) if message: - print(message.to_str(source_code_only=args.source_code_only)) + if args.question: + print(message.question) + elif args.answer: + print(message.answer) + elif message.answer and args.only_source_code: + for code in message.answer.source_code(): + print(code) + else: + print(message.to_str()) except MessageError: print(f"File is not a valid message: {args.file}") sys.exit(1) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index f7163ab..eadb095 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -113,8 +113,10 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', - action='store_true') + print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() + print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') + print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') + print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') argcomplete.autocomplete(parser) return parser -- 2.36.6 From 39b518a8a60f335fb952995ee151440f899c7f85 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 16:05:27 +0200 Subject: [PATCH 049/121] Small fixes. 
--- chatmastermind/ai_factory.py | 8 ++++---- chatmastermind/commands/question.py | 6 +++--- tests/test_ai_factory.py | 10 +++++----- tests/test_question_cmd.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c4a063a..bc4583c 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -14,11 +14,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: Creates an AI subclass instance from the given arguments and configuration file. """ - if args.ai: + if args.AI: try: - ai_conf = config.ais[args.ai] + ai_conf = config.ais[args.AI] except KeyError: - raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") elif default_ai_ID in config.ais: ai_conf = config.ais[default_ai_ID] else: @@ -34,4 +34,4 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: ai.config.temperature = args.temperature return ai else: - raise AIError(f"AI '{args.ai}' is not supported") + raise AIError(f"AI '{args.AI}' is not supported") diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 818b1de..90b782b 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -13,7 +13,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: Creates (and writes) a new message from the given arguments. 
""" question_parts = [] - question_list = args.question if args.question is not None else [] + question_list = args.ask if args.ask is not None else [] source_list = args.source if args.source is not None else [] code_list = args.source_code if args.source_code is not None else [] @@ -35,7 +35,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: message = Message(question=Question(full_question), tags=args.output_tags, # FIXME - ai=args.ai, + ai=args.AI, model=args.model) chat.add_to_cache([message]) return message @@ -59,7 +59,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME - args.otags) # FIXME + args.output_tags) # FIXME assert response # TODO: # * add answer to the message above (and create diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d63970e..d00b319 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.ai = 'default' + self.args.AI = 'default' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,19 +18,19 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.ai = 'default' + self.args.AI = 'default' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_ai_from_default(self) -> None: - self.args.ai = None + self.args.AI = None # Create an AI with the default configuration config = Config() ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_empty_ai_error(self) -> None: - self.args.ai = None + self.args.AI = None # Create Config with empty AIs config = Config() config.ais = {} @@ -40,7 +40,7 @@ class 
TestCreateAI(unittest.TestCase): def test_create_unsupported_ai_error(self) -> None: # Mock argparse.Namespace with ai='invalid_ai' - self.args.ai = 'invalid_ai' + self.args.AI = 'invalid_ai' # Create default Config config = Config() # Call create_ai function and assert that it raises AIError diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 06cc527..aa0dc25 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -24,7 +24,7 @@ class TestMessageCreate(unittest.TestCase): self.args = MagicMock(spec=argparse.Namespace) self.args.source = None self.args.source_code = None - self.args.ai = None + self.args.AI = None self.args.model = None self.args.output_tags = None # create some files for sourcing @@ -59,7 +59,7 @@ Language is called 'brainfart'.""" return list(Path(tmp_dir.name).glob('*.[ty]*')) def test_message_file_created(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) create_message(self.chat, self.args) @@ -70,14 +70,14 @@ Language is called 'brainfart'.""" self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] def test_single_question(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("What is this?")) self.assertEqual(len(message.question.source_code()), 0) def test_multipart_question(self) -> None: - self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + self.args.ask = ["What is this", "'bard' thing?", "Is it good?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("""What is this @@ -87,7 +87,7 @@ Language is called 'brainfart'.""" Is it good?""")) def 
test_single_question_with_text_only_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -99,7 +99,7 @@ Is it good?""")) {self.source_file1_content}""")) def test_single_question_with_embedded_source_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -111,7 +111,7 @@ Is it good?""")) {self.source_file2_content}""")) def test_single_question_with_embedded_source_code_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) -- 2.36.6 From 53582a71239e01b0c2a6cef6a0529e2a082d3118 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 18:28:10 +0200 Subject: [PATCH 050/121] question_cmd: fixed source code extraction and added a testcase --- chatmastermind/commands/question.py | 17 +++++-- chatmastermind/main.py | 2 +- chatmastermind/message.py | 2 +- tests/test_question_cmd.py | 79 +++++++++++++++++++++-------- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 90b782b..756a051 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question +from ..message import Message, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: 
argparse.Namespace) -> Message: """ question_parts = [] question_list = args.ask if args.ask is not None else [] - source_list = args.source if args.source is not None else [] - code_list = args.source_code if args.source_code is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) if source is not None and len(source) > 0: @@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code is not None and len(code) > 0: with open(code) as r: content = r.read().strip() - if len(content) > 0: + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + # if there's none, add the whole file + else: question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index eadb095..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file 
content to the chat history') # 'hist' command parser diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 35de3b9..7107c13 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -414,7 +414,7 @@ class Message(): return '\n'.join(output) def __str__(self) -> str: - return self.to_str(False, False, False) + return self.to_str(True, True, False) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index aa0dc25..40ea4d8 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase): db_path=Path(self.db_path.name)) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) - self.args.source = None + self.args.source_text = None self.args.source_code = None self.args.AI = None self.args.model = None self.args.output_tags = None - # create some files for sourcing + # File 1 : no source code block, only text self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!""" with open(self.source_file1.name, 'w') as f: f.write(self.source_file1_content) + # File 2 : one embedded source code block self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2_content = """This is just text. ``` @@ -42,12 +43,26 @@ This is embedded source code. And some text again.""" with open(self.source_file2.name, 'w') as f: f.write(self.source_file2_content) + # File 3 : all source code self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3_content = """This is all source code. Yes, really. 
Language is called 'brainfart'.""" with open(self.source_file3.name, 'w') as f: f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) def tearDown(self) -> None: os.remove(self.source_file1.name) @@ -86,40 +101,62 @@ Language is called 'brainfart'.""" Is it good?""")) - def test_single_question_with_text_only_source(self) -> None: + def test_single_question_with_text_only_file(self) -> None: self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file1.name}"] + self.args.source_text = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains no source code + # file contains no source code (only text) # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_source(self) -> None: - self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file2.name}"] - message = create_message(self.chat, self.args) - self.assertIsInstance(message, Message) - # source file contains 1 source code block - # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question(f"""What is this? 
- -{self.source_file2_content}""")) - - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_text_file_and_embedded_code(self) -> None: self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains 1 source code block + # file contains 1 source code block # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) self.assertEqual(message.question, Question(f"""What is this? ``` -{self.source_file2_content} +{self.source_file3_content} ```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. 
+``` +""")) -- 2.36.6 From 1e3bfdd67fc13437f6f7da72468572e24f9e9818 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:39:00 +0200 Subject: [PATCH 051/121] chat: added 'update_messages()' function and test --- chatmastermind/chat.py | 16 ++++++++++++++++ tests/test_chat.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4e8fb20..ddabb56 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -386,3 +386,19 @@ class ChatDB(Chat): msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. + """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e4aa8c..ed630a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -440,3 +440,31 @@ class TestChatDB(unittest.TestCase): cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = 
self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1]) -- 2.36.6 From dd3d3ffc82abd110b66a4e0af6ce2d990d702b7c Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:18:14 +0200 Subject: [PATCH 052/121] chat: added check for existing files when creating new filenames --- chatmastermind/chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ddabb56..7c4dd35 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -62,7 +62,10 @@ def make_file_path(dir_path: Path, Create a file_path for the given directory using the given file_suffix and ID generator function. 
""" - return dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + while file_path.exists(): + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + return file_path def write_dir(dir_path: Path, -- 2.36.6 From cf50818f28f389e2a66ad919f78c83aa41933dfe Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:52:07 +0200 Subject: [PATCH 053/121] question_cmd: fixed '--ask' command --- chatmastermind/ai.py | 6 ++++++ chatmastermind/ais/openai.py | 19 ++++++++++++++----- chatmastermind/commands/question.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index e94de8e..b97b5f1 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -66,3 +66,9 @@ class AI(Protocol): and is not implemented for all AIs. """ raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. + """ + pass diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 14ce33f..1db4d20 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -43,16 +43,20 @@ class OpenAI(AI): n=num_answers, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - answers: list[Message] = [] - for choice in response['choices']: # type: ignore + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.name + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, ai=self.name, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt'], - response['usage']['completion'], - response['usage']['total'])) + return AIResponse(answers, 
Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) def models(self) -> list[str]: """ @@ -95,3 +99,8 @@ class OpenAI(AI): def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 756a051..fdabd62 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # create the correct AI instance ai: AI = create_ai(args, config) if args.ask: + ai.print() + chat.print(paged=False) response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME args.output_tags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) elif args.repeat: lmessage = chat.latest_message() assert lmessage -- 2.36.6 From 533ee1c1a94d40f0ed9fa4ed70f09947010ba65b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:54:17 +0200 Subject: [PATCH 054/121] question_cmd: added message filtering by tags --- chatmastermind/commands/question.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index fdabd62..f439447 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat 
import ChatDB -from ..message import Message, Question, source_code +from ..message import Message, MessageFilter, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -52,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags) chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) + db_path=Path(config.db), + mfilter=mfilter) # if it's a new question, create and store it immediately if args.ask or args.create: message = create_message(chat, args) @@ -77,14 +81,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: if response.tokens: print("===============") print(response.tokens) - elif args.repeat: + elif args.repeat is not None: lmessage = chat.latest_message() assert lmessage # TODO: repeat either the last question or the # one(s) given in 'args.repeat' (overwrite # existing ones if 'args.overwrite' is True) pass - elif args.process: + elif args.process is not None: # TODO: process either all questions without an # answer or the one(s) given in 'args.process' pass -- 2.36.6 From b48667bfa0347e1237bb555fd5b2fb2e7514c621 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:55:47 +0200 Subject: [PATCH 055/121] openai: stores AI.ID instead of AI.name in message --- chatmastermind/ais/openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 1db4d20..a388a7a 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -45,14 +45,14 @@ class OpenAI(AI): presence_penalty=self.config.presence_penalty) question.answer = Answer(response['choices'][0]['message']['content']) question.tags = otags - question.ai = self.name + question.ai = self.ID question.model = self.config.model answers: list[Message] = 
[question] for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, - ai=self.name, + ai=self.ID, model=self.config.model)) return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], response['usage']['completion_tokens'], -- 2.36.6 From eca44b14cb9b810c62dc896c30b6994a3eb0f757 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:24:20 +0200 Subject: [PATCH 056/121] message: fixed matching with empty tag sets --- chatmastermind/message.py | 4 ++-- tests/test_chat.py | 22 ++++++++++++++++++++-- tests/test_message.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 7107c13..df59ed6 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -312,7 +312,7 @@ class Message(): mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) - if message and (not mfilter or (mfilter and message.match(mfilter))): + if message and (mfilter is None or message.match(mfilter)): return message else: return None @@ -508,7 +508,7 @@ class Message(): Return True if all attributes match, else False. 
""" mytags = self.tags or set() - if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 diff --git a/tests/test_chat.py b/tests/test_chat.py index ed630a4..1916a2b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_filter(self) -> None: + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messges(self) -> None: + def test_chat_db_from_messages(self) -> 
None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, diff --git a/tests/test_message.py b/tests/test_message.py index 57d5982..1f440df 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -300,6 +300,12 @@ This is a question. MessageFilter(tags_or={Tag('tag1')})) self.assertIsNone(message) + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + def test_from_file_txt_no_tags_match_tags_not(self) -> None: message = Message.from_file(self.file_path_min, MessageFilter(tags_not={Tag('tag1')})) -- 2.36.6 From 6f71a2ff691105b25593ae00d5053443a1ab768b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:56:50 +0200 Subject: [PATCH 057/121] message: to_file() now uses intermediate temporary file --- chatmastermind/message.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index df59ed6..64929a3 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,6 +3,8 @@ Module implementing message related functions and classes. 
""" import pathlib import yaml +import tempfile +import shutil from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -445,16 +447,18 @@ class Message(): * Answer.txt_header * Answer """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) if self.tags: - fd.write(f'{TagLine.from_set(self.tags)}\n') + temp_fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: - fd.write(f'{AILine.from_ai(self.ai)}\n') + temp_fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: - fd.write(f'{ModelLine.from_model(self.model)}\n') - fd.write(f'{Question.txt_header}\n{self.question}\n') + temp_fd.write(f'{ModelLine.from_model(self.model)}\n') + temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ @@ -466,7 +470,8 @@ class Message(): * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) @@ -476,7 +481,8 @@ class Message(): data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) - yaml.dump(data, fd, sort_keys=False) + yaml.dump(data, temp_fd, sort_keys=False) + shutil.move(temp_file_path, file_path) def filter_tags(self, prefix: Optional[str] = None, contain: 
Optional[str] = None) -> set[Tag]: """ -- 2.36.6 From 59b851650ad59ea61df4774c33ed7e624e98e13b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:25:33 +0200 Subject: [PATCH 058/121] question_cmd: when no tags are specified, no tags are selected --- chatmastermind/commands/question.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index f439447..4936d8f 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags) + mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), + tags_and=args.and_tags if args.and_tags is not None else set(), + tags_not=args.exclude_tags if args.exclude_tags is not None else set()) chat = ChatDB.from_dir(cache_path=Path('.'), db_path=Path(config.db), mfilter=mfilter) -- 2.36.6 From c143c001f905dda3154a153ad8ccaed1bc24a5f4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:37:06 +0200 Subject: [PATCH 059/121] configuration: improved config file format --- chatmastermind/configuration.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 398fa03..08f6cbe 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -17,6 +17,18 @@ class ConfigError(Exception): pass +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. 
+ """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass class AIConfig: """ @@ -48,13 +60,13 @@ class OpenAIConfig(AIConfig): # a default configuration ID: str = 'default' api_key: str = '0123456789' - system: str = 'You are an assistant' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 max_tokens: int = 4000 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -62,14 +74,14 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. """ res = cls( - system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) # overwrite default ID if provided if 'ID' in source: @@ -148,6 +160,8 @@ class Config: def as_dict(self) -> dict[str, Any]: res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) for ID, conf in res['ais'].items(): - conf.update({'name': self.ais[ID].name}) + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} return res -- 2.36.6 From d4021eeb110c4d7e9ac0ee41f68e92ad1e12cf22 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 11 Sep 2023 07:38:49 +0200 Subject: [PATCH 060/121] configuration: made 'default' AI ID optional --- chatmastermind/ai_factory.py | 18 ++++++++++++------ chatmastermind/configuration.py | 3 +-- tests/test_ai_factory.py | 4 ++-- tests/test_configuration.py | 14 +++++++------- 4 
files changed, 22 insertions(+), 17 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index bc4583c..420b287 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration. import argparse from typing import cast -from .configuration import Config, OpenAIConfig, default_ai_ID +from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 """ Creates an AI subclass instance from the given arguments - and configuration file. + and configuration file. If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. """ + ai_conf: AIConfig if args.AI: try: ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") - elif default_ai_ID in config.ais: - ai_conf = config.ais[default_ai_ID] + elif 'default' in config.ais: + ai_conf = config.ais['default'] else: - raise AIError("No AI name given and no default exists") + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 08f6cbe..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') supported_ais: list[str] = ['openai'] -default_ai_ID: str = 'default' default_config_path = '.config.yaml' @@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig): # all members have default values, so we can easily create # a default configuration - ID: str = 
'default' + ID: str = 'myopenai' api_key: str = '0123456789' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d00b319..9cb94d3 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.AI = 'default' + self.args.AI = 'myopenai' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.AI = 'default' + self.args.AI = 'myopenai' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f3f9a98..ba8a5aa 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase): source_dict = { 'db': './test_db/', 'ais': { - 'default': { + 'myopenai': { 'name': 'openai', 'system': 'Custom system', 'api_key': '9876543210', @@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase): config = Config.from_dict(source_dict) self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) - self.assertEqual(config.ais['default'].name, 'openai') - self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') # check that 'ID' has been added - self.assertEqual(config.ais['default'].ID, 'default') + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') def test_create_default_should_create_default_config(self) -> None: Config.create_default(Path(self.test_file.name)) @@ -117,8 +117,8 @@ class 
TestConfig(unittest.TestCase): config = Config( db='./test_db/', ais={ - 'default': OpenAIConfig( - ID='default', + 'myopenai': OpenAIConfig( + ID='myopenai', system='Custom system', api_key='9876543210', model='custom_model', @@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase): saved_config = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) - self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = { -- 2.36.6 From 8bd659e888b37a45faf133b2ac2f4eaaca825a39 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 17:07:01 +0200 Subject: [PATCH 061/121] added new module 'tags.py' with classes 'Tag' and 'TagLine' --- chatmastermind/tags.py | 130 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 chatmastermind/tags.py diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py new file mode 100644 index 0000000..28583a2 --- /dev/null +++ b/chatmastermind/tags.py @@ -0,0 +1,130 @@ +""" +Module implementing tag related functions and classes. +""" +from typing import Type, TypeVar, Optional + +TagInst = TypeVar('TagInst', bound='Tag') +TagLineInst = TypeVar('TagLineInst', bound='TagLine') + + +class TagError(Exception): + pass + + +class Tag(str): + """ + A single tag. A string that can contain anything but the default separator (' '). + """ + # default separator + default_separator = ' ' + # alternative separators (e. g. for backwards compatibility) + alternative_separators = [','] + + def __new__(cls: Type[TagInst], string: str) -> TagInst: + """ + Make sure the tag string does not contain the default separator. 
+ """ + if cls.default_separator in string: + raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'") + instance = super().__new__(cls, string) + return instance + + +class TagLine(str): + """ + A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, + separated by the defaut separator (' '). Any operations on a TagLine will sort + the tags. + """ + # the prefix + prefix = 'TAGS:' + + def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: + """ + Make sure the tagline string starts with the prefix. + """ + if not string.startswith(cls.prefix): + raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst: + """ + Create a new TagLine from a set of tags. + """ + return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + + def tags(self) -> set[Tag]: + """ + Returns all tags contained in this line as a set. + """ + tagstr = self[len(self.prefix):].strip() + separator = Tag.default_separator + # look for alternative separators and use the first one found + # -> we don't support different separators in the same TagLine + for s in Tag.alternative_separators: + if s in tagstr: + separator = s + break + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + + def merge(self, taglines: set['TagLine']) -> 'TagLine': + """ + Merges the tags of all given taglines into the current one + and returns a new TagLine. + """ + merged_tags = self.tags() + for tl in taglines: + merged_tags |= tl.tags() + return self.from_set(set(sorted(merged_tags))) + + def delete_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Deletes the given tags and returns a new TagLine. + """ + return self.from_set(self.tags().difference(tags)) + + def add_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Adds the given tags and returns a new TagLine. 
+ """ + return self.from_set(set(sorted(self.tags() | tags))) + + def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + """ + Renames the given tags and returns a new TagLine. The first + tuple element is the old name, the second one is the new name. + """ + new_tags = self.tags() + for t in tags: + if t[0] in new_tags: + new_tags.remove(t[0]) + new_tags.add(t[1]) + return self.from_set(set(sorted(new_tags))) + + def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the current TagLine matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). 
+ """ + tag_set = self.tags() + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing -- 2.36.6 From 2d456e68f187cbd89cf048865bb5256abf2630c0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 17 Aug 2023 08:28:15 +0200 Subject: [PATCH 062/121] added testcases for Tag and TagLine classes --- tests/test_main.py | 114 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..eb13dc5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data +from chatmastermind.tags import Tag, TagLine, TagError from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -231,3 +232,116 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, 
[',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: 
Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 061e5f8682be086d641db1dc6c0e02a10910ee83 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 12:11:56 +0200 Subject: [PATCH 063/121] tags.py: converted most TagLine functions to module functions --- chatmastermind/tags.py | 99 ++++++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 28 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 28583a2..bfe5fd5 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -30,6 +30,67 @@ class Tag(str): return instance +def 
delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]: + """ + Deletes the given tags and returns a new set. + """ + return tags.difference(tags_delete) + + +def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]: + """ + Adds the given tags and returns a new set. + """ + return set(sorted(tags | tags_add)) + + +def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]: + """ + Merges the tags in 'tags_merge' into the current one and returns a new set. + """ + for ts in tags_merge: + tags |= ts + return tags + + +def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]: + """ + Renames the given tags and returns a new set. The first tuple element + is the old name, the second one is the new name. + """ + for t in tags_rename: + if t[0] in tags: + tags.remove(t[0]) + tags.add(t[1]) + return set(sorted(tags)) + + +def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the given set 'tags' matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). 
+ """ + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tags for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing + + class TagLine(str): """ A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, @@ -71,37 +132,29 @@ class TagLine(str): def merge(self, taglines: set['TagLine']) -> 'TagLine': """ - Merges the tags of all given taglines into the current one - and returns a new TagLine. + Merges the tags of all given taglines into the current one and returns a new TagLine. """ - merged_tags = self.tags() - for tl in taglines: - merged_tags |= tl.tags() - return self.from_set(set(sorted(merged_tags))) + tags_merge = [tl.tags() for tl in taglines] + return self.from_set(merge_tags(self.tags(), tags_merge)) - def delete_tags(self, tags: set[Tag]) -> 'TagLine': + def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine': """ Deletes the given tags and returns a new TagLine. """ - return self.from_set(self.tags().difference(tags)) + return self.from_set(delete_tags(self.tags(), tags_delete)) - def add_tags(self, tags: set[Tag]) -> 'TagLine': + def add_tags(self, tags_add: set[Tag]) -> 'TagLine': """ Adds the given tags and returns a new TagLine. """ - return self.from_set(set(sorted(self.tags() | tags))) + return self.from_set(add_tags(self.tags(), tags_add)) - def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine': """ Renames the given tags and returns a new TagLine. The first tuple element is the old name, the second one is the new name. 
""" - new_tags = self.tags() - for t in tags: - if t[0] in new_tags: - new_tags.remove(t[0]) - new_tags.add(t[1]) - return self.from_set(set(sorted(new_tags))) + return self.from_set(rename_tags(self.tags(), tags_rename)) def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], tags_not: Optional[set[Tag]]) -> bool: @@ -117,14 +170,4 @@ class TagLine(str): 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag exclusion is still done if 'tags_not' is not 'None'). """ - tag_set = self.tags() - required_tags_present = False - excluded_tags_missing = False - if ((tags_or is None and tags_and is None) - or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 - or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 - required_tags_present = True - if ((tags_not is None) - or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 - excluded_tags_missing = True - return required_tags_present and excluded_tags_missing + return match_tags(self.tags(), tags_or, tags_and, tags_not) -- 2.36.6 From 264979a60dff629abf36ea464c921c422c64c0f0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:07:50 +0200 Subject: [PATCH 064/121] added new module 'message.py' --- chatmastermind/message.py | 430 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 chatmastermind/message.py diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..157cd46 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,430 @@ +""" +Module implementing message related functions and classes. 
+""" +import pathlib +import yaml +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from dataclasses import dataclass, asdict, field +from .tags import Tag, TagLine, TagError, match_tags + +QuestionInst = TypeVar('QuestionInst', bound='Question') +AnswerInst = TypeVar('AnswerInst', bound='Answer') +MessageInst = TypeVar('MessageInst', bound='Message') +AILineInst = TypeVar('AILineInst', bound='AILine') +ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') +YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] + + +class MessageError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + +def source_code(text: str, include_delims: bool = False) -> list[str]: + """ + Extract all source code sections from the given text, i. e. all lines + surrounded by lines tarting with '```'. If 'include_delims' is True, + the surrounding lines are included, otherwise they are omitted. The + result list contains every source code section as a single string. + The order in the list represents the order of the sections in the text. + """ + code_sections: list[str] = [] + code_lines: list[str] = [] + in_code_block = False + + for line in text.split('\n'): + if line.strip().startswith('```'): + if include_delims: + code_lines.append(line) + if in_code_block: + code_sections.append('\n'.join(code_lines) + '\n') + code_lines.clear() + in_code_block = not in_code_block + elif in_code_block: + code_lines.append(line) + + return code_sections + + +@dataclass(kw_only=True) +class MessageFilter: + """ + Various filters for a Message. 
+ """ + tags_or: Optional[set[Tag]] = None + tags_and: Optional[set[Tag]] = None + tags_not: Optional[set[Tag]] = None + ai: Optional[str] = None + model: Optional[str] = None + question_contains: Optional[str] = None + answer_contains: Optional[str] = None + answer_state: Optional[Literal['available', 'missing']] = None + ai_state: Optional[Literal['available', 'missing']] = None + model_state: Optional[Literal['available', 'missing']] = None + + +class AILine(str): + """ + A line that represents the AI name in a '.txt' file.. + """ + prefix: Final[str] = 'AI:' + + def __new__(cls: Type[AILineInst], string: str) -> AILineInst: + if not string.startswith(cls.prefix): + raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def ai(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst: + return cls(' '.join([cls.prefix, ai])) + + +class ModelLine(str): + """ + A line that represents the model name in a '.txt' file.. + """ + prefix: Final[str] = 'MODEL:' + + def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: + if not string.startswith(cls.prefix): + raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def model(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst: + return cls(' '.join([cls.prefix, model])) + + +class Question(str): + """ + A single question with a defined header. + """ + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' + + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + """ + Make sure the question string does not contain the header. 
+ """ + if cls.txt_header in string: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +class Answer(str): + """ + A single answer with a defined header. + """ + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' + + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + """ + Make sure the answer string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +@dataclass +class Message(): + """ + Single message. Consists of a question and optionally an answer, a set of tags + and a file path. 
+ """ + question: Question + answer: Optional[Answer] = None + # metadata, ignored when comparing messages + tags: Optional[set[Tag]] = field(default=None, compare=False) + ai: Optional[str] = field(default=None, compare=False) + model: Optional[str] = field(default=None, compare=False) + file_path: Optional[pathlib.Path] = field(default=None, compare=False) + # class variables + file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] + tags_yaml_key: ClassVar[str] = 'tags' + file_yaml_key: ClassVar[str] = 'file_path' + ai_yaml_key: ClassVar[str] = 'ai' + model_yaml_key: ClassVar[str] = 'model' + + def __hash__(self) -> int: + """ + The hash value is computed based on immutable members. + """ + return hash((self.question, self.answer)) + + @classmethod + def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: + """ + Create a Message from the given dict. + """ + return cls(question=data[Question.yaml_key], + answer=data.get(Answer.yaml_key, None), + tags=set(data.get(cls.tags_yaml_key, [])), + ai=data.get(cls.ai_yaml_key, None), + model=data.get(cls.model_yaml_key, None), + file_path=data.get(cls.file_yaml_key, None)) + + @classmethod + def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + """ + Return only the tags from the given Message file. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + if file_path.suffix == '.txt': + with open(file_path, "r") as fd: + tags = TagLine(fd.readline()).tags() + else: # '.yaml' + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + tags = set(sorted(data[cls.tags_yaml_key])) + return tags + + @classmethod + def from_file(cls: Type[MessageInst], file_path: pathlib.Path, + mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: + """ + Create a Message from the given file. 
Returns 'None' if the message does + not fulfill the filter requirements. For TXT files, the tags are matched + before building the whole message. The other filters are applied afterwards. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + + if file_path.suffix == '.txt': + message = cls.__from_file_txt(file_path, + mfilter.tags_or if mfilter else None, + mfilter.tags_and if mfilter else None, + mfilter.tags_not if mfilter else None) + else: + message = cls.__from_file_yaml(file_path) + if message and (not mfilter or (mfilter and message.match(mfilter))): + return message + else: + return None + + @classmethod + def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 + tags_or: Optional[set[Tag]] = None, + tags_and: Optional[set[Tag]] = None, + tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: + """ + Create a Message from the given TXT file. Expects the following file structures: + For '.txt': + * TagLine [Optional] + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header [Optional] + * Answer [Optional] + + Returns 'None' if the message does not fulfill the tag requirements. 
+ """ + tags: set[Tag] = set() + question: Question + answer: Optional[Answer] = None + ai: Optional[str] = None + model: Optional[str] = None + with open(file_path, "r") as fd: + # TagLine (Optional) + try: + pos = fd.tell() + tags = TagLine(fd.readline()).tags() + except TagError: + fd.seek(pos) + if tags_or or tags_and or tags_not: + # match with an empty set if the file has no tags + if not match_tags(tags, tags_or, tags_and, tags_not): + return None + # AILine (Optional) + try: + pos = fd.tell() + ai = AILine(fd.readline()).ai() + except TagError: + fd.seek(pos) + # ModelLine (Optional) + try: + pos = fd.tell() + model = ModelLine(fd.readline()).model() + except TagError: + fd.seek(pos) + # Question and Answer + text = fd.read().strip().split('\n') + question_idx = text.index(Question.txt_header) + 1 + try: + answer_idx = text.index(Answer.txt_header) + question = Question.from_list(text[question_idx:answer_idx]) + answer = Answer.from_list(text[answer_idx + 1:]) + except ValueError: + question = Question.from_list(text[question_idx:]) + return cls(question, answer, tags, ai, model, file_path) + + @classmethod + def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: + """ + Create a Message from the given YAML file. Expects the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string [Optional] + * Message.tags_yaml_key: list of strings [Optional] + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + data[cls.file_yaml_key] = file_path + return cls.from_dict(data) + + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + """ + Write a Message to the given file. Type is determined based on the suffix. 
+ Currently supported suffixes: ['.txt', '.yaml'] + """ + if file_path: + self.file_path = file_path + if not self.file_path: + raise MessageError("Got no valid path to write message") + if self.file_path.suffix not in self.file_suffixes: + raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + # TXT + if self.file_path.suffix == '.txt': + return self.__to_file_txt(self.file_path) + elif self.file_path.suffix == '.yaml': + return self.__to_file_yaml(self.file_path) + + def __to_file_txt(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in TXT format. + Creates the following file structures: + * TagLine + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header + * Answer + """ + with open(file_path, "w") as fd: + if self.tags: + fd.write(f'{TagLine.from_set(self.tags)}\n') + if self.ai: + fd.write(f'{AILine.from_ai(self.ai)}\n') + if self.model: + fd.write(f'{ModelLine.from_model(self.model)}\n') + fd.write(f'{Question.txt_header}\n{self.question}\n') + if self.answer: + fd.write(f'{Answer.txt_header}\n{self.answer}\n') + + def __to_file_yaml(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in YAML format. 
+ Creates the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string + * Message.tags_yaml_key: list of strings + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "w") as fd: + data: YamlDict = {Question.yaml_key: str(self.question)} + if self.answer: + data[Answer.yaml_key] = str(self.answer) + if self.ai: + data[self.ai_yaml_key] = self.ai + if self.model: + data[self.model_yaml_key] = self.model + if self.tags: + data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) + yaml.dump(data, fd, sort_keys=False) + + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 + """ + Matches the current Message to the given filter atttributes. + Return True if all attributes match, else False. + """ + mytags = self.tags or set() + if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 + or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 + or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 + or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 + or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 + or (mfilter.model_state == 'available' and not self.model) # noqa: W503 + or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503 + or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503 + or (mfilter.model_state == 'missing' and self.model)): # noqa: W503 + return False + return True + + def msg_id(self) -> str: + """ + Returns an ID that is unique throughout all messages in the same (DB) directory. 
+ Currently this is the file name. The ID is also used for sorting messages. + """ + if self.file_path: + return self.file_path.name + else: + raise MessageError("Can't create file ID without a file path") + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 33567df15fcfbcf84118c59c15ecb055bf9b05da Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:08:22 +0200 Subject: [PATCH 065/121] added testcases for messages.py --- tests/test_main.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..8ce06cb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,79 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after 
the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") -- 2.36.6 From 09da312657537d4eb802e7a79e4e2a9ef1f72e90 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:04:41 +0200 Subject: [PATCH 066/121] configuration: added 'as_dict()' as an instance function --- chatmastermind/configuration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0037916..5ae32d6 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -63,4 +63,7 @@ class Config(): def 
to_file(self, path: str) -> None: with open(path, 'w') as f: - yaml.dump(asdict(self), f) + yaml.dump(asdict(self), f, sort_keys=False) + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 30ccec2462a7610cf707cc13584df9bc3497b342 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:30:24 +0200 Subject: [PATCH 067/121] tags: TagLine constructor now supports multiline taglines and multiple spaces --- chatmastermind/tags.py | 20 +++++++++++--------- tests/test_main.py | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bfe5fd5..544270c 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -1,7 +1,7 @@ """ Module implementing tag related functions and classes. """ -from typing import Type, TypeVar, Optional +from typing import Type, TypeVar, Optional, Final TagInst = TypeVar('TagInst', bound='Tag') TagLineInst = TypeVar('TagLineInst', bound='TagLine') @@ -16,9 +16,9 @@ class Tag(str): A single tag. A string that can contain anything but the default separator (' '). """ # default separator - default_separator = ' ' + default_separator: Final[str] = ' ' # alternative separators (e. g. for backwards compatibility) - alternative_separators = [','] + alternative_separators: Final[list[str]] = [','] def __new__(cls: Type[TagInst], string: str) -> TagInst: """ @@ -93,19 +93,21 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s class TagLine(str): """ - A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, - separated by the defaut separator (' '). Any operations on a TagLine will sort - the tags. + A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by + a list of tags, separated by the defaut separator (' '). Any operations on a + TagLine will sort the tags. 
""" # the prefix - prefix = 'TAGS:' + prefix: Final[str] = 'TAGS:' def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: """ - Make sure the tagline string starts with the prefix. + Make sure the tagline string starts with the prefix. Also replace newlines + and multiple spaces with ' ', in order to support multiline TagLines. """ if not string.startswith(cls.prefix): raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + string = ' '.join(string.split()) instance = super().__new__(cls, string) return instance @@ -114,7 +116,7 @@ class TagLine(str): """ Create a new TagLine from a set of tags. """ - return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) def tags(self) -> set[Tag]: """ diff --git a/tests/test_main.py b/tests/test_main.py index 8ce06cb..25cdc37 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -256,6 +256,10 @@ class TestTagLine(CmmTestCase): tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_invalid_tagline(self) -> None: with self.assertRaises(TagError): TagLine('tag1 tag2') @@ -273,6 +277,11 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From c0f50bace5d8f0e55ca1ed46cd7ff9c589c15a83 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 21 Aug 2023 08:29:48 +0200 Subject: [PATCH 068/121] gitignore: added vim session file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore 
index 4ade1df..89bf5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .config.yaml db -noweb \ No newline at end of file +noweb +Session.vim -- 2.36.6 From acec5f1d552d5537120cc0b496f97cf3c9fadefe Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 08:46:03 +0200 Subject: [PATCH 069/121] tests: split 'test_main.py' into 3 modules --- tests/test_main.py | 200 ------------------------------------------ tests/test_message.py | 78 ++++++++++++++++ tests/test_tags.py | 124 ++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 200 deletions(-) create mode 100644 tests/test_message.py create mode 100644 tests/test_tags.py diff --git a/tests/test_main.py b/tests/test_main.py index 25cdc37..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,8 +7,6 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from chatmastermind.tags import Tag, TagLine, TagError -from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -233,201 +231,3 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) - - -class TestTag(CmmTestCase): - def test_valid_tag(self) -> None: - tag = Tag('mytag') - self.assertEqual(tag, 'mytag') - - def test_invalid_tag(self) -> None: - with self.assertRaises(TagError): - Tag('tag with space') - - def test_default_separator(self) -> None: - self.assertEqual(Tag.default_separator, ' ') - - def test_alternative_separators(self) -> None: - self.assertEqual(Tag.alternative_separators, [',']) - - -class TestTagLine(CmmTestCase): - def
test_valid_tagline(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_valid_tagline_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_invalid_tagline(self) -> None: - with self.assertRaises(TagError): - TagLine('tag1 tag2') - - def test_prefix(self) -> None: - self.assertEqual(TagLine.prefix, 'TAGS:') - - def test_from_set(self) -> None: - tags = {Tag('tag1'), Tag('tag2')} - tagline = TagLine.from_set(tags) - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_tags_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_merge(self) -> None: - tagline1 = TagLine('TAGS: tag1 tag2') - tagline2 = TagLine('TAGS: tag2 tag3') - merged_tagline = tagline1.merge({tagline2}) - self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') - - def test_delete_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag2') - - def test_add_tags(self) -> None: - tagline = TagLine('TAGS: tag1') - new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') - - def test_rename_tags(self) -> None: - tagline = TagLine('TAGS: old1 old2') - new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) - self.assertEqual(new_tagline, 'TAGS: new1 new2') - - def test_match_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - - # Test case 1: Match any tag in 'tags_or' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and: set[Tag] = set() - tags_not: set[Tag] = set() - 
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 2: Match all tags in 'tags_and' - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = {Tag('tag5')} - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 5: No matching tags in 'tags_or' - tags_or = {Tag('tag4'), Tag('tag5')} - tags_and = set() - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 6: Not all tags in 'tags_and' are present - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 7: Some tags in 'tags_not' are present - tags_or = {Tag('tag1')} - tags_and = set() - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 8: 'tags_or' and 'tags_and' are None, match all tags - tags_not = set() - self.assertTrue(tagline.match_tags(None, None, tags_not)) - - # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(None, None, tags_not)) - - -class SourceCodeTestCase(CmmTestCase): - def test_source_code_with_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - 
print(x + y) - ``` - """ - expected_result = [ - " ```python\n print(\"Hello, World!\")\n ```\n", - " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" - ] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_without_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " print(\"Hello, World!\")\n", - " x = 10\n y = 20\n print(x + y)\n" - ] - result = source_code(text, include_delims=False) - self.assertEqual(result, expected_result) - - def test_source_code_with_single_code_block(self) -> None: - text = "```python\nprint(\"Hello, World!\")\n```" - expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_with_no_code_blocks(self) -> None: - text = "Some text without any code blocks" - expected_result: list[str] = [] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - -class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") - - def test_question_without_prefix(self) -> None: - question = Question("What is your favorite color?") - self.assertIsInstance(question, Question) - self.assertEqual(question, "What is your favorite color?") - - -class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") - - def test_answer_without_prefix(self) -> None: - answer = Answer("No") - self.assertIsInstance(answer, Answer) - self.assertEqual(answer, "No") diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..220fef2 --- 
/dev/null +++ b/tests/test_message.py @@ -0,0 +1,78 @@ +from .test_main import CmmTestCase +from chatmastermind.message import source_code, MessageError, Question, Answer + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your 
favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") diff --git a/tests/test_tags.py b/tests/test_tags.py new file mode 100644 index 0000000..9ac9746 --- /dev/null +++ b/tests/test_tags.py @@ -0,0 +1,124 @@ +from .test_main import CmmTestCase +from chatmastermind.tags import Tag, TagLine, TagError + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = 
TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + 
# Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 9c2598a4b82db3b304caee342859ae3cc15bf0ed Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 19:59:38 +0200 Subject: [PATCH 070/121] tests: added testcases for Message.from/to_file() and others --- tests/test_message.py | 545 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 544 insertions(+), 1 deletion(-) diff --git a/tests/test_message.py b/tests/test_message.py index 220fef2..0e326b4 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,5 +1,9 @@ +import pathlib +import tempfile +from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, MessageError, Question, Answer +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.tags import Tag, TagLine class SourceCodeTestCase(CmmTestCase): @@ -76,3 +80,542 @@ class AnswerTestCase(CmmTestCase): answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") + + +class MessageToFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = 
Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_txt_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""" + self.assertEqual(content, expected_content) + + def test_to_file_txt_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.txt_header} +This is a question. +""" + self.assertEqual(content, expected_content) + + def test_to_file_unsupported_file_type(self) -> None: + unsupported_file_path = pathlib.Path("example.doc") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_path) + self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + + def test_to_file_no_file_path(self) -> None: + """ + Provoke an exception using an empty path. 
+ """ + with self.assertRaises(MessageError) as cm: + # clear the internal file_path + self.message_complete.file_path = None + self.message_complete.to_file(None) + self.assertEqual(str(cm.exception), "Got no valid path to write message") + # reset the internal file_path + self.message_complete.file_path = self.file_path + + +class MessageToFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_multiline = Message(Question('This is a\nmultiline question.'), + Answer('This is a\nmultiline answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_yaml_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: This is a question. +{Answer.yaml_key}: This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_multiline(self) -> None: + self.message_multiline.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: |- + This is a + multiline question. +{Answer.yaml_key}: |- + This is a + multiline answer. 
+{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"{Question.yaml_key}: This is a question.\n" + self.assertEqual(content, expected_content) + + +class MessageFromFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_min.close() + self.file_path.unlink() + self.file_path_min.unlink() + + def test_from_file_txt_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_min(self) -> None: + """ + Read a message with only required values. 
+ """ + message = Message.from_file(self.file_path_min) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_txt_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.txt") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_txt_question_match(self) -> None: + message = Message.from_file(self.file_path, + 
MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_txt_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + 
MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_txt_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class MessageFromFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + self.file_min.close() + self.file_path_min.unlink() + + def test_from_file_yaml_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_min(self) -> None: + """ + Read a message with only the required values. 
+ """ + message = Message.from_file(self.file_path_min) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.yaml") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_yaml_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_yaml_question_match(self) -> None: + message = 
Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_yaml_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_ai_doesnt_match(self) -> None: + 
message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_yaml_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class TagsFromFileTestCase(CmmTestCase): + def setUp(self) -> None: + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt = pathlib.Path(self.file_txt.name) + with open(self.file_path_txt, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml = pathlib.Path(self.file_yaml.name) + with open(self.file_path_yaml, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. 
+{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + + def tearDown(self) -> None: + self.file_txt.close() + self.file_path_txt.unlink() + self.file_yaml.close() + self.file_path_yaml.unlink() + + def test_tags_from_file_txt(self) -> None: + tags = Message.tags_from_file(self.file_path_txt) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_from_file_yaml(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + +class MessageIDTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message_no_file_path = Message(Question('This is a question.')) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_msg_id_txt(self) -> None: + self.assertEqual(self.message.msg_id(), self.file_path.name) + + def test_msg_id_txt_exception(self) -> None: + with self.assertRaises(MessageError): + self.message_no_file_path.msg_id() + + +class MessageHashTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a new question.'), + file_path=pathlib.Path('/tmp/foo/bla')) + self.message3 = Message(Question('This is a question.'), + Answer('This is an answer.'), + file_path=pathlib.Path('/tmp/foo/bla')) + # message4 is a copy of message1, because only question and + # answer are used for hashing and comparison + self.message4 = Message(Question('This is a question.'), + tags={Tag('tag1'), Tag('tag2')}, + ai='Blabla', + file_path=pathlib.Path('foobla')) + + def test_set_hashing(self) -> None: + msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4} + self.assertEqual(len(msgs), 3) 
+ for msg in [self.message1, self.message2, self.message3]: + self.assertIn(msg, msgs) -- 2.36.6 From 17f7b2fb452ccb946d6c9344d46c70d50ce86a06 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 26 Aug 2023 12:50:47 +0200 Subject: [PATCH 071/121] Added tags filtering (prefix and contained string) to TagLine and Message --- chatmastermind/message.py | 71 ++++++++++++++++++++++-- chatmastermind/tags.py | 12 +++- tests/test_message.py | 113 +++++++++++++++++++++++++++++++++++++- tests/test_tags.py | 22 +++++++- 4 files changed, 204 insertions(+), 14 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..902aaa2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -219,21 +219,57 @@ class Message(): file_path=data.get(cls.file_yaml_key, None)) @classmethod - def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + def tags_from_file(cls: Type[MessageInst], + file_path: pathlib.Path, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: """ - Return only the tags from the given Message file. + Return only the tags from the given Message file, + optionally filtered based on prefix or contained string. 
""" + tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") + # for TXT, it's enough to read the TagLine if file_path.suffix == '.txt': with open(file_path, "r") as fd: - tags = TagLine(fd.readline()).tags() + try: + tags = TagLine(fd.readline()).tags(prefix, contain) + except TagError: + pass # message without tags else: # '.yaml' - with open(file_path, "r") as fd: - data = yaml.load(fd, Loader=yaml.FullLoader) - tags = set(sorted(data[cls.tags_yaml_key])) + try: + message = cls.from_file(file_path) + if message: + tags = message.filter_tags(prefix=prefix, contain=contain) + except MessageError as e: + # unreadable message: report it and return the (still empty) set; + # assigning to 'tags' directly avoids an unbound helper variable + print(f"Error processing message in '{file_path}': {str(e)}") + return tags + + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: + + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix' + and 'contain'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix, contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod @@ -395,6 +431,29 @@ class Message(): data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) yaml.dump(data, fd, sort_keys=False) + def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Filter tags based on their prefix (i. e. the tag starts with a given string) + or some contained string.
+ """ + res_tags = self.tags + if res_tags: + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() + + def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: + """ + Returns all tags as a string with the TagLine prefix. Optionally filtered + using 'Message.filter_tags()'. + """ + if self.tags: + return str(TagLine.from_set(self.filter_tags(prefix, contain))) + else: + return str(TagLine.from_set(set())) + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 """ Matches the current Message to the given filter atttributes. diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 544270c..c438db9 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -118,9 +118,10 @@ class TagLine(str): """ return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) - def tags(self) -> set[Tag]: + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ - Returns all tags contained in this line as a set. + Returns all tags contained in this line as a set, optionally + filtered based on prefix or contained string. 
""" tagstr = self[len(self.prefix):].strip() separator = Tag.default_separator @@ -130,7 +131,12 @@ class TagLine(str): if s in tagstr: separator = s break - return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() def merge(self, taglines: set['TagLine']) -> 'TagLine': """ diff --git a/tests/test_message.py b/tests/test_message.py index 0e326b4..7b8aee9 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -543,11 +543,19 @@ class TagsFromFileTestCase(CmmTestCase): self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: - fd.write(f"""{TagLine.prefix} tag1 tag2 + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 {Question.txt_header} This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) + with open(self.file_path_txt_no_tags, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -560,6 +568,16 @@ This is an answer. {Message.tags_yaml_key}: - tag1 - tag2 + - ptag3 +""") + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) + with open(self.file_path_yaml_no_tags, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. 
+{Answer.yaml_key}: |- + This is an answer. """) def tearDown(self) -> None: @@ -570,11 +588,90 @@ This is an answer. def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_no_tags) + self.assertSetEqual(tags, set()) def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_yaml_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, contain='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, contain='R') + self.assertSetEqual(tags, set()) + + +class TagsFromDirTestCase(CmmTestCase): + def setUp(self) -> None: + 
self.temp_dir = tempfile.TemporaryDirectory() + self.temp_dir_no_tags = tempfile.TemporaryDirectory() + self.tag_sets = [ + {Tag('atag1'), Tag('atag2')}, + {Tag('btag3'), Tag('btag4')}, + {Tag('ctag5'), Tag('ctag6')} + ] + self.files = [ + pathlib.Path(self.temp_dir.name, 'file1.txt'), + pathlib.Path(self.temp_dir.name, 'file2.yaml'), + pathlib.Path(self.temp_dir.name, 'file3.txt') + ] + self.files_no_tags = [ + pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), + pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), + pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + ] + for file, tags in zip(self.files, self.tag_sets): + message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags) + message.to_file(file) + for file in self.files_no_tags: + message = Message(Question('This is a question.'), + Answer('This is an answer.')) + message.to_file(file) + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def test_tags_from_dir(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) + expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2] + self.assertEqual(all_tags, expected_tags) + + def test_tags_from_dir_prefix(self) -> None: + atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a') + expected_tags = self.tag_sets[0] + self.assertEqual(atags, expected_tags) + + def test_tags_from_dir_no_tags(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name)) + self.assertSetEqual(all_tags, set()) class MessageIDTestCase(CmmTestCase): @@ -619,3 +716,13 @@ class MessageHashTestCase(CmmTestCase): self.assertEqual(len(msgs), 3) for msg in [self.message1, self.message2, self.message3]: self.assertIn(msg, msgs) + + +class MessageTagsStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_tags_str(self) -> 
None: + self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') diff --git a/tests/test_tags.py b/tests/test_tags.py index 9ac9746..bd2b685 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -40,15 +40,33 @@ class TestTagLine(CmmTestCase): self.assertEqual(tagline, 'TAGS: tag1 tag2') def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') + tagline = TagLine('TAGS: atag1 btag2') tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_contain(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(contain='t') + self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(contain='1') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(contain='R') + self.assertSetEqual(tags, set()) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 238dbbee6061c5604a0bcf58d751bc43d517054c Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 27 Aug 2023 18:07:38 +0200 Subject: [PATCH 072/121] fixed handling empty tags in TXT file --- chatmastermind/tags.py | 2 ++ tests/test_message.py | 13 +++++++++++++ tests/test_tags.py | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index c438db9..bb45a08 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -124,6 +124,8 @@ class TagLine(str): 
filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() + if tagstr == '': + return set() # no tags, only prefix separator = Tag.default_separator # look for alternative separators and use the first one found # -> we don't support different separators in the same TagLine diff --git a/tests/test_message.py b/tests/test_message.py index 7b8aee9..9cfb30a 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -556,6 +556,15 @@ This is an answer. This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) + with open(self.file_path_txt_tags_empty, "w") as fd: + fd.write(f"""TAGS: +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -594,6 +603,10 @@ This is an answer. 
tags = Message.tags_from_file(self.file_path_txt_no_tags) self.assertSetEqual(tags, set()) + def test_tags_from_file_txt_tags_empty(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_tags_empty) + self.assertSetEqual(tags, set()) + def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) diff --git a/tests/test_tags.py b/tests/test_tags.py index bd2b685..eeab199 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -44,6 +44,10 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) + def test_tags_empty(self) -> None: + tagline = TagLine('TAGS:') + self.assertSetEqual(tagline.tags(), set()) + def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() -- 2.36.6 From fde0ae4652c604d756e5df66e4aa363cc7c427fd Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 29 Aug 2023 11:35:18 +0200 Subject: [PATCH 073/121] fixed test case file cleanup --- tests/test_message.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_message.py b/tests/test_message.py index 9cfb30a..83a73ea 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -594,6 +594,12 @@ This is an answer. 
self.file_path_txt.unlink() self.file_yaml.close() self.file_path_yaml.unlink() + self.file_txt_no_tags.close() + self.file_path_txt_no_tags.unlink() + self.file_txt_tags_empty.close() + self.file_path_txt_tags_empty.unlink() + self.file_yaml_no_tags.close() + self.file_path_yaml_no_tags.unlink() def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) @@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase): def tearDown(self) -> None: self.temp_dir.cleanup() + self.temp_dir_no_tags.cleanup() def test_tags_from_dir(self) -> None: all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) -- 2.36.6 From 74c39070d620f79c458497ab9cab6fe356d9b79c Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 30 Aug 2023 08:20:25 +0200 Subject: [PATCH 074/121] fixed Message.filter_tags --- chatmastermind/message.py | 15 ++++++++------- tests/test_message.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 902aaa2..820d104 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -436,13 +436,14 @@ class Message(): Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string.
""" - res_tags = self.tags - if res_tags: - if prefix and len(prefix) > 0: - res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} - if contain and len(contain) > 0: - res_tags -= {tag for tag in res_tags if contain not in tag} - return res_tags or set() + if not self.tags: + return set() + res_tags = self.tags.copy() + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ diff --git a/tests/test_message.py b/tests/test_message.py index 83a73ea..2a9d0ff 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -746,3 +746,18 @@ class MessageTagsStrTestCase(CmmTestCase): def test_tags_str(self) -> None: self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) -- 2.36.6 From dc3f3dc168b8b5fb19bb5b1a88c42638414a19ec Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 09:19:38 +0200 Subject: [PATCH 075/121] added 'message_in()' function and test --- chatmastermind/message.py | 16 +++++++++++++++- tests/test_message.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 820d104..3eca26e 100644 --- a/chatmastermind/message.py +++ 
b/chatmastermind/message.py @@ -3,7 +3,7 @@ Module implementing message related functions and classes. """ import pathlib import yaml -from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags @@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]: return code_sections +def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool: + """ + Searches the given message list for a message with the same file + name as the given one (i. e. it compares Message.file_path.name). + If the given message has no file_path, False is returned. + """ + if not message.file_path: + return False + for m in messages: + if m.file_path and m.file_path.name == message.file_path.name: + return True + return False + + @dataclass(kw_only=True) class MessageFilter: """ diff --git a/tests/test_message.py b/tests/test_message.py index 2a9d0ff..0d7953e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -2,7 +2,7 @@ import pathlib import tempfile from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine @@ -761,3 +761,17 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_pref, {Tag('atag1')}) tags_cont = self.message.filter_tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + + +class MessageInTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + 
self.message2 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/bla/foo')) + + def test_message_in(self) -> None: + self.assertTrue(message_in(self.message1, [self.message1])) + self.assertFalse(message_in(self.message1, [self.message2])) -- 2.36.6 From a093f9b86777067439c04de7c3dfeaa5d3a2ec68 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 15:47:29 +0200 Subject: [PATCH 076/121] tags: some clarification and new tests --- chatmastermind/tags.py | 3 ++- tests/test_tags.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bb45a08..5ea1a3a 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -77,7 +77,8 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' or all of the tags in 'tags_and' but it must never contain any of the tags in 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag - exclusion is still done if 'tags_not' is not 'None'). + exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), + they match no tags. 
""" required_tags_present = False excluded_tags_missing = False diff --git a/tests/test_tags.py b/tests/test_tags.py index eeab199..aa89a06 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -144,3 +144,20 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags + self.assertFalse(tagline.match_tags(set(), set(), None)) + + # Test case 11: 'tags_or' is empty, match no tags + self.assertFalse(tagline.match_tags(set(), None, None)) + + # Test case 12: 'tags_and' is empty, match no tags + self.assertFalse(tagline.match_tags(None, set(), None)) + + # Test case 13: 'tags_or' is empty, match 'tags_and' + tags_and = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(None, tags_and, None)) + + # Test case 14: 'tags_and' is empty, match 'tags_or' + tags_or = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(tags_or, None, None)) -- 2.36.6 From 64893949a4193cdfcd03c0d268325b1347d71c0a Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 24 Aug 2023 16:49:54 +0200 Subject: [PATCH 077/121] added new module 'chat.py' --- chatmastermind/chat.py | 278 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 chatmastermind/chat.py diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py new file mode 100644 index 0000000..c5d8bf3 --- /dev/null +++ b/chatmastermind/chat.py @@ -0,0 +1,278 @@ +""" +Module implementing various chat classes and functions for managing a chat history. 
+""" +import shutil +import pathlib +from pprint import PrettyPrinter +from pydoc import pager +from dataclasses import dataclass +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .tags import Tag + +ChatInst = TypeVar('ChatInst', bound='Chat') +ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') + + +class ChatError(Exception): + pass + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args: Any, **kwargs: Any) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def print_paged(text: str) -> None: + pager(text) + + +def read_dir(dir_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> list[Message]: + """ + Reads the messages from the given folder. + Parameters: + * 'dir_path': source directory + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages: list[Message] = [] + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + message = Message.from_file(file_path, mfilter) + if message: + messages.append(message) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return messages + + +def write_dir(dir_path: pathlib.Path, + messages: list[Message], + file_suffix: str, + next_fid: Callable[[], int]) -> None: + """ + Write all messages to the given directory. If a message has no file_path, + a new one will be created. If message.file_path exists, it will be modified + to point to the given directory. 
+ Parameters: + * 'dir_path': destination directory + * 'messages': list of messages to write + * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] + * 'next_fid': callable that returns the next file ID + """ + for message in messages: + file_path = message.file_path + # message has no file_path: create one + if not file_path: + fid = next_fid() + fname = f"{fid:04d}{file_suffix}" + file_path = dir_path / fname + # file_path does not point to given directory: modify it + elif not file_path.parent.samefile(dir_path): + file_path = dir_path / file_path.name + message.to_file(file_path) + + +@dataclass +class Chat: + """ + A class containing a complete chat history. + """ + + messages: list[Message] + + def filter(self, mfilter: MessageFilter) -> None: + """ + Use 'Message.match(mfilter) to remove all messages that + don't fulfill the filter requirements. + """ + self.messages = [m for m in self.messages if m.match(mfilter)] + + def sort(self, reverse: bool = False) -> None: + """ + Sort the messages according to 'Message.msg_id()'. + """ + try: + # the message may not have an ID if it doesn't have a file_path + self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) + except MessageError: + pass + + def clear(self) -> None: + """ + Delete all messages. + """ + self.messages = [] + + def add_msgs(self, msgs: list[Message]) -> None: + """ + Add new messages and sort them if possible. + """ + self.messages += msgs + self.sort() + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Get the tags of all messages, optionally filtered by prefix or substring. 
+ """ + tags: set[Tag] = set() + for m in self.messages: + tags |= m.filter_tags(prefix, contain) + return tags + + def print(self, dump: bool = False, source_code_only: bool = False, + with_tags: bool = False, with_file: bool = False, + paged: bool = True) -> None: + if dump: + pp(self) + return + output: list[str] = [] + for message in self.messages: + if source_code_only: + output.extend(source_code(message.question, include_delims=True)) + continue + output.append('-' * terminal_width()) + output.append(Question.txt_header) + output.append(message.question) + if message.answer: + output.append(Answer.txt_header) + output.append(message.answer) + if with_tags: + output.append(message.tags_str()) + if with_file: + output.append('FILE: ' + str(message.file_path)) + if paged: + print_paged('\n'.join(output)) + else: + print(*output, sep='\n') + + +@dataclass +class ChatDB(Chat): + """ + A 'Chat' class that is bound to a given directory structure. Supports reading + and writing messages from / to that structure. Such a structure consists of + two directories: a 'cache directory', where all messages are temporarily + stored, and a 'DB' directory, where selected messages can be stored + persistently. 
+ """ + + default_file_suffix: ClassVar[str] = '.txt' + + cache_path: pathlib.Path + db_path: pathlib.Path + # a MessageFilter that all messages must match (if given) + mfilter: Optional[MessageFilter] = None + file_suffix: str = default_file_suffix + # the glob pattern for all messages + glob: Optional[str] = None + + def __post_init__(self) -> None: + # contains the latest message ID + self.next_fname = self.db_path / '.next' + # make all paths absolute + self.cache_path = self.cache_path.absolute() + self.db_path = self.db_path.absolute() + + @classmethod + def from_dir(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a 'ChatDB' instance from the given directory structure. + Reads all messages from 'db_path' into the local message list. + Parameters: + * 'cache_path': path to the directory for temporary messages + * 'db_path': path to the directory for persistent messages + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages = read_dir(db_path, glob, mfilter) + return cls(messages, cache_path, db_path, mfilter, + cls.default_file_suffix, glob) + + @classmethod + def from_messages(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + messages: list[Message], + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a ChatDB instance from the given message list. 
+ """ + return cls(messages, cache_path, db_path, mfilter) + + def get_next_fid(self) -> int: + try: + with open(self.next_fname, 'r') as f: + next_fid = int(f.read()) + 1 + self.set_next_fid(next_fid) + return next_fid + except Exception: + self.set_next_fid(1) + return 1 + + def set_next_fid(self, fid: int) -> None: + with open(self.next_fname, 'w') as f: + f.write(f'{fid}') + + def read_db(self) -> None: + """ + Reads new messages from the DB directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.db_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def read_cache(self) -> None: + """ + Reads new messages from the cache directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.cache_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def write_db(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the DB directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point + to the DB directory. 
+ """ + write_dir(self.db_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) + + def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the cache directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point to + the cache directory. + """ + write_dir(self.cache_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) -- 2.36.6 From 815a21893c70e4bf1186dc063b74229891915746 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 28 Aug 2023 14:24:24 +0200 Subject: [PATCH 078/121] added tests for 'chat.py' --- tests/test_chat.py | 297 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 tests/test_chat.py diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..2d0ffa0 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,297 @@ +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width +from .test_main import CmmTestCase + + +class TestChat(CmmTestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_msgs([self.message2, self.message1]) + 
self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_msgs([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_msgs(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def test_tags(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_file=True) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{TagLine.prefix} atag1 +FILE: 0001.txt +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} 
+Answer 2 +{TagLine.prefix} btag2 +FILE: 0002.txt +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(CmmTestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = 
ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_filter(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messges(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 1) + self.assertEqual(chat_db.get_next_fid(), 2) + self.assertEqual(chat_db.get_next_fid(), 3) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '3') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that 
Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + 
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_read(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + + # create 2 new files in the DB directory + new_message1 = Message(Question('Question 5'), + Answer('Answer 5'), + {Tag('tag5')}) + new_message2 = Message(Question('Question 6'), + Answer('Answer 6'), + {Tag('tag6')}) + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 6) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # create 2 new files in the cache directory + new_message3 = Message(Question('Question 7'), + Answer('Answer 5'), + {Tag('tag7')}) + new_message4 = Message(Question('Question 8'), + Answer('Answer 6'), + {Tag('tag8')}) + new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) + 
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + # read and check them + chat_db.read_cache() + self.assertEqual(len(chat_db.messages), 8) + # check that the new message have the cache dir path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + # an the old ones keep their path (since they have not been replaced) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now overwrite two messages in the DB directory + new_message1.question = Question('New Question 1') + new_message2.question = Question('New Question 2') + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read from the DB dir and check if the modified messages have been updated + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + self.assertEqual(chat_db.messages[4].question, 'New Question 1') + self.assertEqual(chat_db.messages[5].question, 'New Question 2') + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now write the messages from the cache to the DB directory + new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + # check that they now have the DB path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) -- 2.36.6 From 
6737fa98c73a1db51f9ee9bf25b0765e2c193c96 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 08:57:54 +0200 Subject: [PATCH 079/121] added tokens() function to Message and Chat --- chatmastermind/chat.py | 7 +++++++ chatmastermind/message.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c5d8bf3..4a458df 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -129,6 +129,13 @@ class Chat: tags |= m.filter_tags(prefix, contain) return tags + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by all messages in this chat. + If unknown, 0 is returned. + """ + return sum(m.tokens() for m in self.messages) + def print(self, dump: bool = False, source_code_only: bool = False, with_tags: bool = False, with_file: bool = False, paged: bool = True) -> None: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 3eca26e..675ab3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -132,6 +132,7 @@ class Question(str): """ A single question with a defined header. """ + tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'question' @@ -165,6 +166,7 @@ class Answer(str): """ A single answer with a defined header. """ + tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '=== ANSWER ===' yaml_key: ClassVar[str] = 'answer' @@ -502,3 +504,13 @@ class Message(): def as_dict(self) -> dict[str, Any]: return asdict(self) + + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by this message. + If unknown, 0 is returned. 
+ """ + if self.answer: + return self.question.tokens + self.answer.tokens + else: + return self.question.tokens -- 2.36.6 From 33565d351dc575660955b32a63f5c427998ec80c Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:07:58 +0200 Subject: [PATCH 080/121] configuration: added AIConfig class --- chatmastermind/configuration.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 5ae32d6..0780604 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') @dataclass -class OpenAIConfig(): +class AIConfig: + """ + The base class of all AI configurations. + """ + name: str + + +@dataclass +class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ @@ -25,6 +33,7 @@ class OpenAIConfig(): Create OpenAIConfig from a dict. """ return cls( + name='OpenAI', api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -36,7 +45,7 @@ class OpenAIConfig(): @dataclass -class Config(): +class Config: """ The configuration file structure. """ @@ -47,7 +56,7 @@ class Config(): @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create OpenAIConfig from a dict. + Create Config from a dict. 
""" return cls( system=str(source['system']), -- 2.36.6 From b22a4b07ed99ef9e3c13159479bec9cd07b4b9f9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:35:32 +0200 Subject: [PATCH 081/121] chat: added tags_frequency() function and test --- chatmastermind/chat.py | 11 ++++++++++- tests/test_chat.py | 9 +++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4a458df..759467d 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -127,7 +127,16 @@ class Chat: tags: set[Tag] = set() for m in self.messages: tags |= m.filter_tags(prefix, contain) - return tags + return set(sorted(tags)) + + def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: + """ + Get the frequency of all tags of all messages, optionally filtered by prefix or substring. + """ + tags: list[Tag] = [] + for m in self.messages: + tags += [tag for tag in m.filter_tags(prefix, contain)] + return {tag: tags.count(tag) for tag in sorted(tags)} def tokens(self) -> int: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index 2d0ffa0..5f1fcb6 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -14,7 +14,7 @@ class TestChat(CmmTestCase): self.chat = Chat([]) self.message1 = Message(Question('Question 1'), Answer('Answer 1'), - {Tag('atag1')}, + {Tag('atag1'), Tag('btag2')}, file_path=pathlib.Path('0001.txt')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), @@ -57,6 +57,11 @@ class TestChat(CmmTestCase): tags_cont = self.chat.tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + def test_tags_frequency(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_freq = self.chat.tags_frequency() + self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) @@ 
-83,7 +88,7 @@ Answer 2 Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 +{TagLine.prefix} atag1 btag2 FILE: 0001.txt {'-'*terminal_width()} {Question.txt_header} -- 2.36.6 From 48c8e951e1d439426e8a22a89b7dc2a24fdd0898 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:44:27 +0200 Subject: [PATCH 082/121] chat: fixed handling of unsupported files in DB and chache dir --- chatmastermind/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 759467d..11f1d74 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path, messages: list[Message] = [] file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in sorted(file_iter): - if file_path.is_file(): + if file_path.is_file() and file_path.suffix in Message.file_suffixes: try: message = Message.from_file(file_path, mfilter) if message: -- 2.36.6 From c318b99671be511d7c79226d65974befd4241932 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:18:41 +0200 Subject: [PATCH 083/121] chat: improved history printing --- chatmastermind/chat.py | 15 ++++++--------- tests/test_chat.py | 10 +++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 11f1d74..e4e8ab6 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -145,27 +145,24 @@ class Chat: """ return sum(m.tokens() for m in self.messages) - def print(self, dump: bool = False, source_code_only: bool = False, - with_tags: bool = False, with_file: bool = False, + def print(self, source_code_only: bool = False, + with_tags: bool = False, with_files: bool = False, paged: bool = True) -> None: - if dump: - pp(self) - return output: list[str] = [] for message in self.messages: if source_code_only: output.extend(source_code(message.question, include_delims=True)) continue output.append('-' * terminal_width()) 
+ if with_tags: + output.append(message.tags_str()) + if with_files: + output.append('FILE: ' + str(message.file_path)) output.append(Question.txt_header) output.append(message.question) if message.answer: output.append(Answer.txt_header) output.append(message.answer) - if with_tags: - output.append(message.tags_str()) - if with_file: - output.append('FILE: ' + str(message.file_path)) if paged: print_paged('\n'.join(output)) else: diff --git a/tests/test_chat.py b/tests/test_chat.py index 5f1fcb6..8e1ad0d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -82,21 +82,21 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) - self.chat.print(paged=False, with_tags=True, with_file=True) + self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} +{TagLine.prefix} atag1 btag2 +FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 btag2 -FILE: 0001.txt {'-'*terminal_width()} +{TagLine.prefix} btag2 +FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 -{TagLine.prefix} btag2 -FILE: 0002.txt """ self.assertEqual(mock_stdout.getvalue(), expected_output) -- 2.36.6 From 8e63831701741705799fd7baee065c4fe6b420b0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 09:19:47 +0200 Subject: [PATCH 084/121] chat: added clear_cache() function and test --- chatmastermind/chat.py | 20 +++++++++++++++++++ tests/test_chat.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index e4e8ab6..9fc0a27 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -82,6 +82,17 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) +def clear_dir(dir_path: pathlib.Path, + glob: Optional[str] = None) -> None: + """ + Deletes 
all Message files in the given directory. + """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + @dataclass class Chat: """ @@ -289,3 +300,12 @@ class ChatDB(Chat): msgs if msgs else self.messages, self.file_suffix, self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e1ad0d..9e74061 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -300,3 +300,48 @@ class TestChatDB(CmmTestCase): # check that they now have the DB path self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + 
self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_msgs([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) -- 2.36.6 From aba3eb783d3ac9b1a644225ee509393673cf21ab Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 16:00:24 +0200 Subject: [PATCH 085/121] message: improved robustness of Question and Answer content checks and tests --- chatmastermind/message.py | 48 +++++++++++++++++++++------------------ tests/test_message.py | 29 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 
675ab3a..384fb96 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -128,29 +128,29 @@ class ModelLine(str): return cls(' '.join([cls.prefix, model])) -class Question(str): +class Answer(str): """ - A single question with a defined header. + A single answer with a defined header. """ - tokens: int = 0 # tokens used by this question - txt_header: ClassVar[str] = '=== QUESTION ===' - yaml_key: ClassVar[str] = 'question' + tokens: int = 0 # tokens used by this answer + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ - Make sure the question string does not contain the header. + Make sure the answer string does not contain the header as a whole line. """ - if cls.txt_header in string: - raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if cls.txt_header in string.split('\n'): + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance @@ -162,29 +162,33 @@ class Question(str): return source_code(self, include_delims) -class Answer(str): +class Question(str): """ - A single answer with a defined header. + A single question with a defined header. 
""" - tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' - yaml_key: ClassVar[str] = 'answer' + tokens: int = 0 # tokens used by this question + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ - Make sure the answer string does not contain the header. + Make sure the question string does not contain the header as a whole line + (also not that from 'Answer', so it's always clear where the answer starts). """ - if cls.txt_header in string: - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + string_lines = string.split('\n') + if cls.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if Answer.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. 
""" - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance diff --git a/tests/test_message.py b/tests/test_message.py index 0d7953e..e01de66 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: + def test_question_with_header(self) -> None: with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") + Question(f"{Question.txt_header}\nWhat is your name?") - def test_question_without_prefix(self) -> None: + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. 
+ """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: question = Question("What is your favorite color?") self.assertIsInstance(question, Question) self.assertEqual(question, "What is your favorite color?") class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: + def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") + Answer(f"{Answer.txt_header}\nno") - def test_answer_without_prefix(self) -> None: + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") -- 2.36.6 From d35de86c67fc963a3d3b97c48757ce05bd31cc04 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:00:08 +0200 Subject: [PATCH 086/121] message: fixed Answer header for TXT format --- chatmastermind/message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 384fb96..87de8e2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -96,7 +96,7 @@ class AILine(str): def __new__(cls: Type[AILineInst], string: str) -> AILineInst: if not string.startswith(cls.prefix): - raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -116,7 +116,7 @@ class ModelLine(str): def 
__new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: if not string.startswith(cls.prefix): - raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -133,7 +133,7 @@ class Answer(str): A single answer with a defined header. """ tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' + txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: @@ -355,17 +355,20 @@ class Message(): try: pos = fd.tell() ai = AILine(fd.readline()).ai() - except TagError: + except MessageError: fd.seek(pos) # ModelLine (Optional) try: pos = fd.tell() model = ModelLine(fd.readline()).model() - except TagError: + except MessageError: fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') - question_idx = text.index(Question.txt_header) + 1 + try: + question_idx = text.index(Question.txt_header) + 1 + except ValueError: + raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) -- 2.36.6 From 713b55482a61195ce7163bd839b6eb257619fb03 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:19:14 +0200 Subject: [PATCH 087/121] message: added rename_tags() function and test --- chatmastermind/message.py | 10 +++++++++- tests/test_message.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 87de8e2..0fb949c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,7 @@ import pathlib import yaml from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field 
-from .tags import Tag, TagLine, TagError, match_tags +from .tags import Tag, TagLine, TagError, match_tags, rename_tags QuestionInst = TypeVar('QuestionInst', bound='Question') AnswerInst = TypeVar('AnswerInst', bound='Answer') @@ -499,6 +499,14 @@ class Message(): return False return True + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None: + """ + Renames the given tags. The first tuple element is the old name, + the second one is the new name. + """ + if self.tags: + self.tags = rename_tags(self.tags, tags_rename) + def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. diff --git a/tests/test_message.py b/tests/test_message.py index e01de66..e860538 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -792,3 +792,15 @@ class MessageInTestCase(CmmTestCase): def test_message_in(self) -> None: self.assertTrue(message_in(self.message1, [self.message1])) self.assertFalse(message_in(self.message1, [self.message2])) + + +class MessageRenameTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_rename_tags(self) -> None: + self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) + self.assertIsNotNone(self.message.tags) + self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -- 2.36.6 From 2e2228bd60e4aa7761e64833fe517604a9784096 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH 088/121] chat: new possibilites for adding messages and better tests --- chatmastermind/chat.py | 75 ++++++++++++++++++++++++---- tests/test_chat.py | 109 ++++++++++++++++++++++++++++++++--------- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..7e6df8f 100644 --- 
a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. 
""" write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,52 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. 
+ """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..a1c020e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -5,7 +5,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from .test_main import CmmTestCase @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) 
self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, 
with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase): self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. + """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) def tearDown(self) -> None: self.db_path.cleanup() @@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase): def test_chat_db_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.get_next_fid(), 1) - self.assertEqual(chat_db.get_next_fid(), 2) - self.assertEqual(chat_db.get_next_fid(), 3) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) with open(chat_db.next_fname, 'r') as f: - self.assertEqual(f.read(), '3') + self.assertEqual(f.read(), '7') def test_chat_db_write(self) -> None: # create a new ChatDB instance @@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) @@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase): 
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) # check the timestamp of the files in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} # overwrite the messages in the db directory time.sleep(0.05) chat_db.write_db() # check if the written files are in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) # now rewrite them to the DB dir and check for modified paths chat_db.write_db() - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -333,15 +344,69 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + 
cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) # make sure that the DB messages (and the new message) are still there self.assertEqual(len(chat_db.messages), 5) - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + def test_chat_db_write_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = 
self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + # try to write a message without a valid file_path + message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.write_messages([message]) + + # write a message with a valid file_path + message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' + chat_db.write_messages([message]) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) -- 2.36.6 From abb7fdacb65a7e266f63f4c2397e76e5e5961338 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 08:49:43 +0200 Subject: [PATCH 089/121] message / chat: output improvements --- chatmastermind/chat.py | 16 ++++------------ chatmastermind/message.py | 24 ++++++++++++++++++++++++ tests/test_chat.py | 16 ++++++++++++---- tests/test_message.py | 24 ++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 7e6df8f..c631dab 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -7,7 +7,7 @@ from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, ClassVar, Any, Callable -from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -170,18 +170,10 @@ class Chat: output: list[str] = [] for message in self.messages: if source_code_only: - output.extend(source_code(message.question, include_delims=True)) + output.append(message.to_str(source_code_only=True)) continue - output.append('-' * 
terminal_width()) - if with_tags: - output.append(message.tags_str()) - if with_files: - output.append('FILE: ' + str(message.file_path)) - output.append(Question.txt_header) - output.append(message.question) - if message.answer: - output.append(Answer.txt_header) - output.append(message.answer) + output.append(message.to_str(with_tags, with_files)) + output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 0fb949c..35de3b9 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -392,6 +392,30 @@ class Message(): data[cls.file_yaml_key] = file_path return cls.from_dict(data) + def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: + """ + Return the current Message as a string. + """ + output: list[str] = [] + if source_code_only: + # use the source code from answer only + if self.answer: + output.extend(self.answer.source_code(include_delims=True)) + return '\n'.join(output) if len(output) > 0 else '' + if with_tags: + output.append(self.tags_str()) + if with_file: + output.append('FILE: ' + str(self.file_path)) + output.append(Question.txt_header) + output.append(self.question) + if self.answer: + output.append(Answer.txt_header) + output.append(self.answer) + return '\n'.join(output) + + def __str__(self) -> str: + return self.to_str(False, False, False) + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ Write a Message to the given file. Type is determined based on the suffix. 
diff --git a/tests/test_chat.py b/tests/test_chat.py index a1c020e..f8302eb 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -66,16 +66,20 @@ class TestChat(CmmTestCase): def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) - expected_output = f"""{'-'*terminal_width()} -{Question.txt_header} + expected_output = f"""{Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) @@ -83,20 +87,24 @@ Answer 2 def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) - expected_output = f"""{'-'*terminal_width()} -{TagLine.prefix} atag1 btag2 + expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {TagLine.prefix} btag2 FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) diff --git a/tests/test_message.py b/tests/test_message.py index e860538..a49c893 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase): self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) self.assertIsNotNone(self.message.tags) self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] + + +class MessageToStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_to_str(self) 
-> None: + expected_output = f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(), expected_output) + + def test_to_str_with_tags_and_file(self) -> None: + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: /tmp/foo/bla +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) -- 2.36.6 From e1414835c8c2cdc96d9c425b7d585afb3ffbb261 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 08:16:55 +0200 Subject: [PATCH 090/121] chat: added functions for finding and deleting messages --- chatmastermind/chat.py | 52 ++++++++++++++++++++++++++++++++---------- tests/test_chat.py | 22 ++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c631dab..4e8fb20 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -2,7 +2,7 @@ Module implementing various chat classes and functions for managing a chat history. """ import shutil -import pathlib +from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass @@ -30,7 +30,7 @@ def print_paged(text: str) -> None: pager(text) -def read_dir(dir_path: pathlib.Path, +def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path, return messages -def make_file_path(dir_path: pathlib.Path, +def make_file_path(dir_path: Path, file_suffix: str, - next_fid: Callable[[], int]) -> pathlib.Path: + next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. 
@@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path, return dir_path / f"{next_fid():04d}{file_suffix}" -def write_dir(dir_path: pathlib.Path, +def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: @@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) -def clear_dir(dir_path: pathlib.Path, +def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. @@ -139,6 +139,34 @@ class Chat: self.messages += messages self.sort() + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. + """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. 
@@ -192,8 +220,8 @@ class ChatDB(Chat): default_file_suffix: ClassVar[str] = '.txt' - cache_path: pathlib.Path - db_path: pathlib.Path + cache_path: Path + db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix @@ -209,8 +237,8 @@ class ChatDB(Chat): @classmethod def from_dir(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ @@ -230,8 +258,8 @@ class ChatDB(Chat): @classmethod def from_messages(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index f8302eb..d81a97a 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -62,6 +62,28 @@ class TestChat(CmmTestCase): tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + 
self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) -- 2.36.6 From 8923a13352980ec6af36e1f381f177ae5a5a1841 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:46:23 +0200 Subject: [PATCH 091/121] cmm: the 'tags' command now uses the new 'ChatDB' --- chatmastermind/main.py | 34 +++++++++++++++++++++------------- chatmastermind/utils.py | 5 ----- tests/test_main.py | 2 +- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index c30ea4e..f9eccba 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,10 +7,11 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType +from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config +from .chat import ChatDB from itertools import zip_longest from typing import Any @@ -56,12 +57,17 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: Config) -> None: +def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ - Handler for the 'tag' command. + Handler for the 'tags' command. 
""" + chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), + db_path=pathlib.Path(config.db)) if args.list: - print_tags_frequency(get_tags(config, None)) + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming def config_cmd(args: argparse.Namespace, config: Config) -> None: @@ -190,14 +196,16 @@ def create_parser() -> argparse.ArgumentParser: hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') - # 'tag' command parser - tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.", - aliases=['t']) - tag_cmd_parser.set_defaults(func=tag_cmd) - tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) - tag_group.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + # 'tags' command parser + tags_cmd_parser = cmdparser.add_parser('tags', + help="Manage tags.", + aliases=['t']) + tags_cmd_parser.set_defaults(func=tags_cmd) + tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) + tags_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") + tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 6543ce1..4135ae3 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -79,8 +79,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals print(message['content']) else: print(f"{message['role'].upper()}: {message['content']}") - - -def print_tags_frequency(tags: list[str]) -> None: - for tag in sorted(set(tags)): - print(f"- {tag}: {tags.count(tag)}") diff --git a/tests/test_main.py b/tests/test_main.py 
index db5fcdb..23c3d00 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -227,7 +227,7 @@ class TestCreateParser(CmmTestCase): mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From 4c378dde854499771f377b29d8ae40b4984c7eec Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:21:49 +0200 Subject: [PATCH 092/121] cmm: the 'hist' command now uses the new 'ChatDB' --- chatmastermind/main.py | 58 +++++++++++++++++++++++------------------- tests/test_main.py | 15 ++++++----- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index f9eccba..8aef252 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -12,6 +12,7 @@ from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB +from .message import MessageFilter from itertools import zip_longest from typing import Any @@ -31,11 +32,11 @@ def create_question_with_hist(args: argparse.Namespace, by the specified tags. 
""" tags = args.tags or [] - extags = args.extags or [] + etags = args.etags or [] otags = args.output_tags or [] - if not args.only_source_code: - print_tag_args(tags, extags, otags) + if not args.source_code_only: + print_tag_args(tags, etags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +53,10 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, extags, config, - args.match_all_tags, False, False) + chat = create_chat_hist(full_question, tags, etags, config, + match_all_tags=True if args.atags else False, # FIXME + with_tags=False, + with_file=False) return chat, full_question, tags @@ -94,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: if args.model: config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.only_source_code) + print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -106,14 +109,18 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. 
""" - tags = args.tags or [] - extags = args.extags or [] - chat = create_chat_hist(None, tags, extags, config, - args.match_all_tags, - args.with_tags, - args.with_files) - print_chat_hist(chat, args.dump, args.only_source_code) + mfilter = MessageFilter(tags_or=args.tags, + tags_and=args.atags, + tags_not=args.etags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) def print_cmd(args: argparse.Namespace, config: Config) -> None: @@ -129,7 +136,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: else: print(f"Unknown file type: {args.file}") sys.exit(1) - if args.only_source_code: + if args.source_code_only: display_source_code(data['answer']) elif args.answer: print(data['answer'].strip()) @@ -153,18 +160,17 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names', metavar='TAGS') + help='List of tag names (one must match)', metavar='TAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', - help='List of tag names to exclude', metavar='EXTAGS') - extag_arg.completer = tags_completer # type: ignore + atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', + help='List of tag names (all must match)', metavar='TAGS') + atag_arg.completer = tags_completer # type: ignore + etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', + help='List of tag names to exclude', metavar='ETAGS') + etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', help='List of output tag names, default is input', metavar='OTAGS') otag_arg.completer = tags_completer # type: 
ignore - tag_parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries", - action='store_true') - # enable autocompletion for tags # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], @@ -179,7 +185,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', + ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') # 'hist' command parser @@ -187,14 +193,14 @@ def create_parser() -> argparse.ArgumentParser: help="Print chat history.", aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) - hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", - action='store_true') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') - hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') + hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') + hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', diff --git a/tests/test_main.py b/tests/test_main.py index 23c3d00..bb9aa2a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase): self.question = "test question" self.args = 
argparse.Namespace( tags=['tag1'], - extags=['extag1'], + atags=None, + etags=['etag1'], output_tags=None, question=[self.question], source=None, - only_source_code=False, + source_code_only=False, number=3, max_tokens=None, temperature=None, @@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase): with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.extags, + self.args.etags, []) mock_create_chat_hist.assert_called_once_with(self.question, self.args.tags, - self.args.extags, + self.args.etags, self.config, - False, False, False) + match_all_tags=False, + with_tags=False, + with_file=False) mock_print_chat_hist.assert_called_once_with('test_chat', False, - self.args.only_source_code) + self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, self.args.number) -- 2.36.6 From 5e4ec70072fe77a888973746594e56df01839dfd Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:42:59 +0200 Subject: [PATCH 093/121] cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) --- chatmastermind/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 8aef252..1796f69 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,13 +6,13 @@ import yaml import sys import argcomplete import argparse -import pathlib +from pathlib import Path from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data +from .storage import save_answers, create_chat_hist, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import MessageFilter +from .message import Message, MessageFilter from itertools import zip_longest 
from typing import Any @@ -64,8 +64,8 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. """ - chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), - db_path=pathlib.Path(config.db)) + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) if args.list: tags_freq = chat.tags_frequency(args.prefix, args.contain) for tag, freq in tags_freq.items(): @@ -127,7 +127,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ - fname = pathlib.Path(args.file) + fname = Path(args.file) if fname.suffix == '.yaml': with open(args.file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) -- 2.36.6 From e186afbef046e04f5588805c6def89ee6a5c5eee Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:07:02 +0200 Subject: [PATCH 094/121] cmm: the 'print' command now uses 'Message.from_file()' --- chatmastermind/main.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1796f69..ed67f7b 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -2,17 +2,16 @@ # -*- coding: utf-8 -*- # vim: set fileencoding=utf-8 : -import yaml import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType +from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter +from .message import Message, MessageFilter, MessageError from itertools import zip_longest from typing import Any @@ -128,13 +127,12 @@ def print_cmd(args: argparse.Namespace, config: Config) 
-> None: Handler for the 'print' command. """ fname = Path(args.file) - if fname.suffix == '.yaml': - with open(args.file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif fname.suffix == '.txt': - data = read_file(fname) - else: - print(f"Unknown file type: {args.file}") + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") sys.exit(1) if args.source_code_only: display_source_code(data['answer']) @@ -227,14 +225,22 @@ def create_parser() -> argparse.ArgumentParser: # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.", + help="Print message files.", aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) +<<<<<<< HEAD print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') +||||||| parent of bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', + action='store_true') +======= + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', + action='store_true') +>>>>>>> bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') argcomplete.autocomplete(parser) return parser -- 2.36.6 From 4bd144c4d75a2892947e470a863dc87d6f4f0633 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:00:15 +0200 Subject: [PATCH 095/121] added new module 'ai.py' --- chatmastermind/ai.py | 63 
++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 chatmastermind/ai.py diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py new file mode 100644 index 0000000..4a8b914 --- /dev/null +++ b/chatmastermind/ai.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Protocol, Optional, Union +from .configuration import AIConfig +from .tags import Tag +from .message import Message +from .chat import Chat + + +class AIError(Exception): + pass + + +@dataclass +class Tokens: + prompt: int = 0 + completion: int = 0 + total: int = 0 + + +@dataclass +class AIResponse: + """ + The response to an AI request. Consists of one or more messages + (each containing the question and a single answer) and the nr. + of used tokens. + """ + messages: list[Message] + tokens: Optional[Tokens] = None + + +class AI(Protocol): + """ + The base class for AI clients. + """ + + name: str + config: AIConfig + + def request(self, + question: Message, + context: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + context (i. e. chat history). The nr. of requested answers + corresponds to the nr. of messages in the 'AIResponse'. + """ + raise NotImplementedError + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + """ + Computes the nr. of AI language tokens for the given message + or chat. Note that the computation may not be 100% accurate + and is not implemented for all AIs. 
+ """ + raise NotImplementedError -- 2.36.6 From 823d3bf7dc1ed2bc40dc3604007a21e0a69bb475 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 10:18:09 +0200 Subject: [PATCH 096/121] added new module 'openai.py' --- chatmastermind/ais/openai.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 chatmastermind/ais/openai.py diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py new file mode 100644 index 0000000..74438b8 --- /dev/null +++ b/chatmastermind/ais/openai.py @@ -0,0 +1,96 @@ +""" +Implements the OpenAI client classes and functions. +""" +import openai +from typing import Optional, Union +from ..tags import Tag +from ..message import Message, Answer +from ..chat import Chat +from ..ai import AI, AIResponse, Tokens +from ..configuration import OpenAIConfig + +ChatType = list[dict[str, str]] + + +class OpenAI(AI): + """ + The OpenAI AI client. + """ + + def __init__(self, name: str, config: OpenAIConfig) -> None: + self.name = name + self.config = config + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + chat history. The nr. of requested answers corresponds to the + nr. of messages in the 'AIResponse'. 
+ """ + # FIXME: use real 'system' message (store in OpenAIConfig) + oai_chat = self.openai_chat(chat, "system", question) + response = openai.ChatCompletion.create( + model=self.config.model, + messages=oai_chat, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, + n=num_answers, + frequency_penalty=self.config.frequency_penalty, + presence_penalty=self.config.presence_penalty) + answers: list[Message] = [] + for choice in response['choices']: # type: ignore + answers.append(Message(question=question.question, + answer=Answer(choice['message']['content']), + tags=otags, + ai=self.name, + model=self.config.model)) + return AIResponse(answers, Tokens(response['usage']['prompt'], + response['usage']['completion'], + response['usage']['total'])) + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def print_models(self) -> None: + """ + Print all models supported by the current AI. + """ + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def openai_chat(self, chat: Chat, system: str, + question: Optional[Message] = None) -> ChatType: + """ + Create a chat history with system message in OpenAI format. + Optionally append a new question. 
+ """ + oai_chat: ChatType = [] + + def append(role: str, content: str) -> None: + oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + + append('system', system) + for message in chat.messages: + if message.answer: + append('user', message.question) + append('assistant', message.answer) + if question: + append('user', question.question) + return oai_chat + + def tokens(self, data: Union[Message, Chat]) -> int: + raise NotImplementedError -- 2.36.6 From 7d154522420d112b44a94f3717726331d1ae0af7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 5 Sep 2023 23:24:20 +0200 Subject: [PATCH 097/121] added new module 'ai_factory' --- chatmastermind/ai_factory.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 chatmastermind/ai_factory.py diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py new file mode 100644 index 0000000..c90366b --- /dev/null +++ b/chatmastermind/ai_factory.py @@ -0,0 +1,20 @@ +""" +Creates different AI instances, based on the given configuration. +""" + +import argparse +from .configuration import Config +from .ai import AI, AIError +from .ais.openai import OpenAI + + +def create_ai(args: argparse.Namespace, config: Config) -> AI: + """ + Creates an AI subclass instance from the given args and configuration. 
+ """ + if args.ai == 'openai': + # FIXME: create actual 'OpenAIConfig' and set values from 'args' + # FIXME: use actual name from config + return OpenAI("openai", config.openai) + else: + raise AIError(f"AI '{args.ai}' is not supported") -- 2.36.6 From 034e4093f1ff65d352ea41b89155eb22153477f8 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:35:53 +0200 Subject: [PATCH 098/121] cmm: added 'question' command --- chatmastermind/main.py | 103 +++++++++++++++++++++++++++++++++-------- tests/test_main.py | 18 +++---- 2 files changed, 93 insertions(+), 28 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ed67f7b..67eafae 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -11,7 +11,9 @@ from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter, MessageError +from .message import Message, MessageFilter, MessageError, Question +from .ai_factory import create_ai +from .ai import AI, AIResponse from itertools import zip_longest from typing import Any @@ -30,12 +32,12 @@ def create_question_with_hist(args: argparse.Namespace, Creates the "AI request", including the question and chat history as determined by the specified tags. 
""" - tags = args.tags or [] - etags = args.etags or [] + tags = args.or_tags or [] + xtags = args.exclude_tags or [] otags = args.output_tags or [] if not args.source_code_only: - print_tag_args(tags, etags, otags) + print_tag_args(tags, xtags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +54,8 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, etags, config, - match_all_tags=True if args.atags else False, # FIXME + chat = create_chat_hist(full_question, tags, xtags, config, + match_all_tags=True if args.and_tags else False, # FIXME with_tags=False, with_file=False) return chat, full_question, tags @@ -85,6 +87,47 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: config.to_file(args.config) +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = Message(question=Question(args.question), + tags=args.ouput_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass + + def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. @@ -98,7 +141,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] - answers, usage = ai(chat, config, args.number) + answers, usage = ai(chat, config, args.num_answers) save_answers(question, answers, tags, otags, config) print("-" * terminal_width()) print(f"Usage: {usage}") @@ -109,9 +152,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'hist' command. 
""" - mfilter = MessageFilter(tags_or=args.tags, - tags_and=args.atags, - tags_not=args.etags, + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, question_contains=args.question, answer_contains=args.answer) chat = ChatDB.from_dir(Path('.'), @@ -147,7 +190,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-c', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -157,19 +200,41 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names (one must match)', metavar='TAGS') + tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', + help='List of tag names (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore - atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', - help='List of tag names (all must match)', metavar='TAGS') + atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', + help='List of tag names (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore - etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', - help='List of tag names to exclude', metavar='ETAGS') + etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', + help='List of tag names to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, 
default is input', metavar='OTAGS') + help='List of output tag names, default is input', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # 'question' command parser + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + help="ask, create and process questions.", + aliases=['q']) + question_cmd_parser.set_defaults(func=question_cmd) + question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) + question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') + question_group.add_argument('-c', '--create', nargs='+', help='Create a question') + question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') + question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') + question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', + action='store_true') + question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) + question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) + question_cmd_parser.add_argument('-A', '--AI', help='AI to use') + question_cmd_parser.add_argument('-M', '--model', help='Model to use') + question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, + default=1) + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', + action='store_true') + # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], help="Ask a question.", @@ -180,7 +245,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', 
type=float) ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, + ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', diff --git a/tests/test_main.py b/tests/test_main.py index bb9aa2a..ce9121a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -114,14 +114,14 @@ class TestHandleQuestion(CmmTestCase): def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( - tags=['tag1'], - atags=None, - etags=['etag1'], + or_tags=['tag1'], + and_tags=None, + exclude_tags=['xtag1'], output_tags=None, question=[self.question], source=None, source_code_only=False, - number=3, + num_answers=3, max_tokens=None, temperature=None, model=None, @@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.etags, + mock_print_tag_args.assert_called_once_with(self.args.or_tags, + self.args.exclude_tags, []) mock_create_chat_hist.assert_called_once_with(self.question, - self.args.tags, - self.args.etags, + self.args.or_tags, + self.args.exclude_tags, self.config, match_all_tags=False, with_tags=False, @@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase): self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, - self.args.number) + self.args.num_answers) expected_calls = [] for num, answer in enumerate(mock_ai.return_value[0], start=1): title = f'-- ANSWER {num} ' -- 2.36.6 From d6bb5800b16601a7fd9086f0b6c50991b78aed6e Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 
Sep 2023 22:12:05 +0200 Subject: [PATCH 099/121] test_main: temporarily disabled all testcases --- tests/test_chat.py | 6 +- tests/test_main.py | 468 +++++++++++++++++++++--------------------- tests/test_message.py | 34 +-- tests/test_tags.py | 6 +- 4 files changed, 257 insertions(+), 257 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index d81a97a..8e4aa8c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,3 +1,4 @@ +import unittest import pathlib import tempfile import time @@ -6,10 +7,9 @@ from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError -from .test_main import CmmTestCase -class TestChat(CmmTestCase): +class TestChat(unittest.TestCase): def setUp(self) -> None: self.chat = Chat([]) self.message1 = Message(Question('Question 1'), @@ -131,7 +131,7 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) -class TestChatDB(CmmTestCase): +class TestChatDB(unittest.TestCase): def setUp(self) -> None: self.db_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory() diff --git a/tests/test_main.py b/tests/test_main.py index ce9121a..91e6462 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,236 +1,236 @@ -import unittest -import io -import pathlib -import argparse -from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, ask_cmd -from chatmastermind.api_client import ai -from chatmastermind.configuration import Config -from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from unittest import mock -from unittest.mock import patch, MagicMock, Mock, ANY +# import unittest +# import io +# import pathlib +# import argparse +# from chatmastermind.utils import terminal_width +# from chatmastermind.main import create_parser, ask_cmd +# from 
chatmastermind.api_client import ai +# from chatmastermind.configuration import Config +# from chatmastermind.storage import create_chat_hist, save_answers, dump_data +# from unittest import mock +# from unittest.mock import patch, MagicMock, Mock, ANY -class CmmTestCase(unittest.TestCase): - """ - Base class for all cmm testcases. - """ - def dummy_config(self, db: str) -> Config: - """ - Creates a dummy configuration. - """ - return Config.from_dict( - {'system': 'dummy_system', - 'db': db, - 'openai': {'api_key': 'dummy_key', - 'model': 'dummy_model', - 'max_tokens': 4000, - 'temperature': 1.0, - 'top_p': 1, - 'frequency_penalty': 0, - 'presence_penalty': 0}} - ) - - -class TestCreateChat(CmmTestCase): - - def setUp(self) -> None: - self.config = self.dummy_config(db='test_files') - self.question = "test question" - self.tags = ['test_tag'] - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['test_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 4) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> 
None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['other_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 2) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.side_effect = ( - io.StringIO(dump_data({'question': 'test_content', - 'answer': 'some answer', - 'tags': ['test_tag']})), - io.StringIO(dump_data({'question': 'test_content2', - 'answer': 'some answer2', - 'tags': ['test_tag2']})), - ) - - test_chat = create_chat_hist(self.question, [], None, self.config) - - self.assertEqual(len(test_chat), 6) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': 'test_content2'}) - self.assertEqual(test_chat[4], - {'role': 'assistant', 'content': 'some answer2'}) - - -class TestHandleQuestion(CmmTestCase): - - def setUp(self) -> None: - self.question = "test question" - self.args = argparse.Namespace( - or_tags=['tag1'], - and_tags=None, - exclude_tags=['xtag1'], - output_tags=None, - question=[self.question], - source=None, - source_code_only=False, - 
num_answers=3, - max_tokens=None, - temperature=None, - model=None, - match_all_tags=False, - with_tags=False, - with_file=False, - ) - self.config = self.dummy_config(db='test_files') - - @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.main.print_chat_hist") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) - @patch("chatmastermind.utils.pp") - @patch("builtins.print") - def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, - mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, - mock_create_chat_hist: MagicMock) -> None: - open_mock = MagicMock() - with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.or_tags, - self.args.exclude_tags, - []) - mock_create_chat_hist.assert_called_once_with(self.question, - self.args.or_tags, - self.args.exclude_tags, - self.config, - match_all_tags=False, - with_tags=False, - with_file=False) - mock_print_chat_hist.assert_called_once_with('test_chat', - False, - self.args.source_code_only) - mock_ai.assert_called_with("test_chat", - self.config, - self.args.num_answers) - expected_calls = [] - for num, answer in enumerate(mock_ai.return_value[0], start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width() - len(title)) - expected_calls.append(((f'{title}{title_end}',),)) - expected_calls.append(((answer,),)) - expected_calls.append((("-" * terminal_width(),),)) - expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) - open_mock.assert_has_calls(open_expected_calls, any_order=True) - - -class TestSaveAnswers(CmmTestCase): - @mock.patch('builtins.open') - 
@mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: - question = "Test question?" - answers = ["Answer 1", "Answer 2"] - tags = ["tag1", "tag2"] - otags = ["otag1", "otag2"] - config = self.dummy_config(db='test_db') - - with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ - mock.patch('chatmastermind.storage.yaml.dump'), \ - mock.patch('io.StringIO') as stringio_mock: - stringio_instance = stringio_mock.return_value - stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] - save_answers(question, answers, tags, otags, config) - - open_calls = [ - mock.call(pathlib.Path('test_db/.next'), 'r'), - mock.call(pathlib.Path('test_db/.next'), 'w'), - ] - open_mock.assert_has_calls(open_calls, any_order=True) - - -class TestAI(CmmTestCase): - - @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock) -> None: - mock_create.return_value = { - 'choices': [ - {'message': {'content': 'response_text_1'}}, - {'message': {'content': 'response_text_2'}} - ], - 'usage': {'tokens': 10} - } - - chat = [{"role": "system", "content": "hello ai"}] - config = self.dummy_config(db='dummy') - config.openai.model = "text-davinci-002" - config.openai.max_tokens = 150 - config.openai.temperature = 0.5 - - result = ai(chat, config, 2) - expected_result = (['response_text_1', 'response_text_2'], - {'tokens': 10}) - self.assertEqual(result, expected_result) - - -class TestCreateParser(CmmTestCase): - def test_create_parser(self) -> None: - with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: - mock_cmdparser = Mock() - mock_add_subparsers.return_value = mock_cmdparser - parser = create_parser() - self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - 
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) - self.assertTrue('.config.yaml' in parser.get_default('config')) +# class CmmTestCase(unittest.TestCase): +# """ +# Base class for all cmm testcases. +# """ +# def dummy_config(self, db: str) -> Config: +# """ +# Creates a dummy configuration. +# """ +# return Config.from_dict( +# {'system': 'dummy_system', +# 'db': db, +# 'openai': {'api_key': 'dummy_key', +# 'model': 'dummy_model', +# 'max_tokens': 4000, +# 'temperature': 1.0, +# 'top_p': 1, +# 'frequency_penalty': 0, +# 'presence_penalty': 0}} +# ) +# +# +# class TestCreateChat(CmmTestCase): +# +# def setUp(self) -> None: +# self.config = self.dummy_config(db='test_files') +# self.question = "test question" +# self.tags = ['test_tag'] +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['test_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 4) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# 
self.assertEqual(test_chat[3], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['other_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 2) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.side_effect = ( +# io.StringIO(dump_data({'question': 'test_content', +# 'answer': 'some answer', +# 'tags': ['test_tag']})), +# io.StringIO(dump_data({'question': 'test_content2', +# 'answer': 'some answer2', +# 'tags': ['test_tag2']})), +# ) +# +# test_chat = create_chat_hist(self.question, [], None, self.config) +# +# self.assertEqual(len(test_chat), 6) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': 'test_content2'}) +# self.assertEqual(test_chat[4], +# {'role': 
'assistant', 'content': 'some answer2'}) +# +# +# class TestHandleQuestion(CmmTestCase): +# +# def setUp(self) -> None: +# self.question = "test question" +# self.args = argparse.Namespace( +# or_tags=['tag1'], +# and_tags=None, +# exclude_tags=['xtag1'], +# output_tags=None, +# question=[self.question], +# source=None, +# source_code_only=False, +# num_answers=3, +# max_tokens=None, +# temperature=None, +# model=None, +# match_all_tags=False, +# with_tags=False, +# with_file=False, +# ) +# self.config = self.dummy_config(db='test_files') +# +# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") +# @patch("chatmastermind.main.print_tag_args") +# @patch("chatmastermind.main.print_chat_hist") +# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) +# @patch("chatmastermind.utils.pp") +# @patch("builtins.print") +# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, +# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, +# mock_create_chat_hist: MagicMock) -> None: +# open_mock = MagicMock() +# with patch("chatmastermind.storage.open", open_mock): +# ask_cmd(self.args, self.config) +# mock_print_tag_args.assert_called_once_with(self.args.or_tags, +# self.args.exclude_tags, +# []) +# mock_create_chat_hist.assert_called_once_with(self.question, +# self.args.or_tags, +# self.args.exclude_tags, +# self.config, +# match_all_tags=False, +# with_tags=False, +# with_file=False) +# mock_print_chat_hist.assert_called_once_with('test_chat', +# False, +# self.args.source_code_only) +# mock_ai.assert_called_with("test_chat", +# self.config, +# self.args.num_answers) +# expected_calls = [] +# for num, answer in enumerate(mock_ai.return_value[0], start=1): +# title = f'-- ANSWER {num} ' +# title_end = '-' * (terminal_width() - len(title)) +# expected_calls.append(((f'{title}{title_end}',),)) +# expected_calls.append(((answer,),)) +# expected_calls.append((("-" * 
terminal_width(),),)) +# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) +# self.assertEqual(mock_print.call_args_list, expected_calls) +# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) +# open_mock.assert_has_calls(open_expected_calls, any_order=True) +# +# +# class TestSaveAnswers(CmmTestCase): +# @mock.patch('builtins.open') +# @mock.patch('chatmastermind.storage.print') +# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: +# question = "Test question?" +# answers = ["Answer 1", "Answer 2"] +# tags = ["tag1", "tag2"] +# otags = ["otag1", "otag2"] +# config = self.dummy_config(db='test_db') +# +# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ +# mock.patch('chatmastermind.storage.yaml.dump'), \ +# mock.patch('io.StringIO') as stringio_mock: +# stringio_instance = stringio_mock.return_value +# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] +# save_answers(question, answers, tags, otags, config) +# +# open_calls = [ +# mock.call(pathlib.Path('test_db/.next'), 'r'), +# mock.call(pathlib.Path('test_db/.next'), 'w'), +# ] +# open_mock.assert_has_calls(open_calls, any_order=True) +# +# +# class TestAI(CmmTestCase): +# +# @patch("openai.ChatCompletion.create") +# def test_ai(self, mock_create: MagicMock) -> None: +# mock_create.return_value = { +# 'choices': [ +# {'message': {'content': 'response_text_1'}}, +# {'message': {'content': 'response_text_2'}} +# ], +# 'usage': {'tokens': 10} +# } +# +# chat = [{"role": "system", "content": "hello ai"}] +# config = self.dummy_config(db='dummy') +# config.openai.model = "text-davinci-002" +# config.openai.max_tokens = 150 +# config.openai.temperature = 0.5 +# +# result = ai(chat, config, 2) +# expected_result = (['response_text_1', 'response_text_2'], +# {'tokens': 10}) +# self.assertEqual(result, expected_result) +# +# +# class TestCreateParser(CmmTestCase): +# def 
test_create_parser(self) -> None: +# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: +# mock_cmdparser = Mock() +# mock_add_subparsers.return_value = mock_cmdparser +# parser = create_parser() +# self.assertIsInstance(parser, argparse.ArgumentParser) +# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) +# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) +# self.assertTrue('.config.yaml' in parser.get_default('config')) diff --git a/tests/test_message.py b/tests/test_message.py index a49c893..57d5982 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,12 +1,12 @@ +import unittest import pathlib import tempfile from typing import cast -from .test_main import CmmTestCase from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine -class SourceCodeTestCase(CmmTestCase): +class SourceCodeTestCase(unittest.TestCase): def test_source_code_with_include_delims(self) -> None: text = """ Some text before the code block @@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase): self.assertEqual(result, expected_result) -class QuestionTestCase(CmmTestCase): +class QuestionTestCase(unittest.TestCase): def test_question_with_header(self) -> None: with self.assertRaises(MessageError): Question(f"{Question.txt_header}\nWhat is your name?") @@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase): self.assertEqual(question, "What is your favorite color?") -class AnswerTestCase(CmmTestCase): +class 
AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): Answer(f"{Answer.txt_header}\nno") @@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase): self.assertEqual(answer, "No") -class MessageToFileTxtTestCase(CmmTestCase): +class MessageToFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -160,7 +160,7 @@ This is a question. self.message_complete.file_path = self.file_path -class MessageToFileYamlTestCase(CmmTestCase): +class MessageToFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase): self.assertEqual(content, expected_content) -class MessageFromFileTxtTestCase(CmmTestCase): +class MessageFromFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -388,7 +388,7 @@ This is a question. self.assertIsNone(message) -class MessageFromFileYamlTestCase(CmmTestCase): +class MessageFromFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase): self.assertIsNone(message) -class TagsFromFileTestCase(CmmTestCase): +class TagsFromFileTestCase(unittest.TestCase): def setUp(self) -> None: self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) @@ -663,7 +663,7 @@ This is an answer. 
self.assertSetEqual(tags, set()) -class TagsFromDirTestCase(CmmTestCase): +class TagsFromDirTestCase(unittest.TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory() @@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase): self.assertSetEqual(all_tags, set()) -class MessageIDTestCase(CmmTestCase): +class MessageIDTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase): self.message_no_file_path.msg_id() -class MessageHashTestCase(CmmTestCase): +class MessageHashTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase): self.assertIn(msg, msgs) -class MessageTagsStrTestCase(CmmTestCase): +class MessageTagsStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase): self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') -class MessageFilterTagsTestCase(CmmTestCase): +class MessageFilterTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) -class MessageInTestCase(CmmTestCase): +class MessageInTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase): self.assertFalse(message_in(self.message1, [self.message2])) -class MessageRenameTagsTestCase(CmmTestCase): +class 
MessageRenameTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase): self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -class MessageToStrTestCase(CmmTestCase): +class MessageToStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), Answer('This is an answer.'), diff --git a/tests/test_tags.py b/tests/test_tags.py index aa89a06..edd3c05 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -1,8 +1,8 @@ -from .test_main import CmmTestCase +import unittest from chatmastermind.tags import Tag, TagLine, TagError -class TestTag(CmmTestCase): +class TestTag(unittest.TestCase): def test_valid_tag(self) -> None: tag = Tag('mytag') self.assertEqual(tag, 'mytag') @@ -18,7 +18,7 @@ class TestTag(CmmTestCase): self.assertEqual(Tag.alternative_separators, [',']) -class TestTagLine(CmmTestCase): +class TestTagLine(unittest.TestCase): def test_valid_tagline(self) -> None: tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') -- 2.36.6 From 6a4cc7a65d9c3cc094a78568e4cded6a92e3f63e Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:23:29 +0200 Subject: [PATCH 100/121] setup: added 'ais' subfolder --- chatmastermind/ais/__init__.py | 0 setup.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 chatmastermind/ais/__init__.py diff --git a/chatmastermind/ais/__init__.py b/chatmastermind/ais/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 02d9ab1..8484629 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages(), + packages=find_packages() + 
["chatmastermind.ais"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,7 +32,7 @@ setup( "openai", "PyYAML", "argcomplete", - "pytest" + "pytest", ], python_requires=">=3.9", test_suite="tests", -- 2.36.6 From 21d39c6c6646213caeaa595b9833dce0ffafbb33 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:43:23 +0200 Subject: [PATCH 101/121] cmm: removed all the old code and modules --- chatmastermind/api_client.py | 45 ------- chatmastermind/main.py | 104 ++------------- chatmastermind/storage.py | 121 ------------------ chatmastermind/utils.py | 81 ------------ tests/test_main.py | 236 ----------------------------------- 5 files changed, 12 insertions(+), 575 deletions(-) delete mode 100644 chatmastermind/api_client.py delete mode 100644 chatmastermind/storage.py delete mode 100644 chatmastermind/utils.py delete mode 100644 tests/test_main.py diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. 
- """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 67eafae..58ce9ed 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,61 +6,19 @@ import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType -from .storage import save_answers, create_chat_hist -from .api_client import ai, openai_api_key, print_models -from .configuration import Config +from .configuration import Config, default_config_path from .chat import ChatDB from .message import Message, MessageFilter, MessageError, Question from .ai_factory import create_ai from .ai import AI, AIResponse -from itertools import zip_longest from typing import Any -default_config = '.config.yaml' - def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) return get_tags_unique(config, prefix) -def create_question_with_hist(args: argparse.Namespace, - config: Config, - ) -> tuple[ChatType, str, list[str]]: - """ - Creates the "AI request", including the question and chat history as determined - by the specified tags. 
- """ - tags = args.or_tags or [] - xtags = args.exclude_tags or [] - otags = args.output_tags or [] - - if not args.source_code_only: - print_tag_args(tags, xtags, otags) - - question_parts = [] - question_list = args.question if args.question is not None else [] - source_list = args.source if args.source is not None else [] - - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: - question_parts.append(question) - elif source is not None: - with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") - - full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, xtags, config, - match_all_tags=True if args.and_tags else False, # FIXME - with_tags=False, - with_file=False) - return chat, full_question, tags - - def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. @@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: # TODO: add renaming -def config_cmd(args: argparse.Namespace, config: Config) -> None: +def config_cmd(args: argparse.Namespace) -> None: """ Handler for the 'config' command. 
""" - if args.list_models: - print_models() - elif args.print_model: - print(config.openai.model) - elif args.model: - config.openai.model = args.model - config.to_file(args.config) + if args.create: + Config.create_default(Path(args.create)) def question_cmd(args: argparse.Namespace, config: Config) -> None: @@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: db_path=Path(config.db)) # if it's a new question, create and store it immediately if args.ask or args.create: + # FIXME: add sources to the question message = Message(question=Question(args.question), tags=args.ouput_tags, # FIXME ai=args.ai, @@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: pass -def ask_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'ask' command. - """ - if args.max_tokens: - config.openai.max_tokens = args.max_tokens - if args.temperature: - config.openai.temperature = args.temperature - if args.model: - config.openai.model = args.model - chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.source_code_only) - otags = args.output_tags or [] - answers, usage = ai(chat, config, args.num_answers) - save_answers(question, answers, tags, otags, config) - print("-" * terminal_width()) - print(f"Usage: {usage}") - - def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. 
@@ -190,7 +125,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-C', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -235,22 +170,6 @@ def create_parser() -> argparse.ArgumentParser: question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') - # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", @@ -286,7 +205,7 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config 
file") + config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', @@ -315,11 +234,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py deleted file mode 100644 index 8b9ed97..0000000 --- a/chatmastermind/storage.py +++ /dev/null @@ -1,121 +0,0 @@ -import yaml -import io -import pathlib -from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config -from typing import Any, Optional - - -def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: - with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() - # also support tags separated by ',' (old format) - separator = ',' if ',' in tagline else ' ' - tags = [t.strip() for t in tagline.split(separator)] - if tags_only: - return {"tags": tags} - text = fd.read().strip().split('\n') - question_idx = text.index("=== QUESTION ===") + 1 - answer_idx = text.index("==== ANSWER ====") - question = "\n".join(text[question_idx:answer_idx]).strip() - answer = "\n".join(text[answer_idx + 1:]).strip() - return {"question": question, "answer": answer, "tags": tags, - "file": fname.name} - - -def dump_data(data: dict[str, Any]) -> str: - with io.StringIO() as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: 
- fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == 
'.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index 4135ae3..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. 
- """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_start = content.index('\n', content_start) + 1 - content_end = content.rindex('```') - if content_start < content_end: - print(content[content_start:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 91e6462..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,236 +0,0 @@ -# import unittest -# import io -# import pathlib -# import argparse -# from chatmastermind.utils import terminal_width -# from 
chatmastermind.main import create_parser, ask_cmd -# from chatmastermind.api_client import ai -# from chatmastermind.configuration import Config -# from chatmastermind.storage import create_chat_hist, save_answers, dump_data -# from unittest import mock -# from unittest.mock import patch, MagicMock, Mock, ANY - - -# class CmmTestCase(unittest.TestCase): -# """ -# Base class for all cmm testcases. -# """ -# def dummy_config(self, db: str) -> Config: -# """ -# Creates a dummy configuration. -# """ -# return Config.from_dict( -# {'system': 'dummy_system', -# 'db': db, -# 'openai': {'api_key': 'dummy_key', -# 'model': 'dummy_model', -# 'max_tokens': 4000, -# 'temperature': 1.0, -# 'top_p': 1, -# 'frequency_penalty': 0, -# 'presence_penalty': 0}} -# ) -# -# -# class TestCreateChat(CmmTestCase): -# -# def setUp(self) -> None: -# self.config = self.dummy_config(db='test_files') -# self.question = "test question" -# self.tags = ['test_tag'] -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['test_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 4) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def 
test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['other_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 2) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.side_effect = ( -# io.StringIO(dump_data({'question': 'test_content', -# 'answer': 'some answer', -# 'tags': ['test_tag']})), -# io.StringIO(dump_data({'question': 'test_content2', -# 'answer': 'some answer2', -# 'tags': ['test_tag2']})), -# ) -# -# test_chat = create_chat_hist(self.question, [], None, self.config) -# -# self.assertEqual(len(test_chat), 6) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': 'test_content2'}) -# self.assertEqual(test_chat[4], -# {'role': 'assistant', 'content': 'some answer2'}) -# -# -# class TestHandleQuestion(CmmTestCase): -# -# def setUp(self) -> None: -# self.question = "test question" -# self.args = 
argparse.Namespace( -# or_tags=['tag1'], -# and_tags=None, -# exclude_tags=['xtag1'], -# output_tags=None, -# question=[self.question], -# source=None, -# source_code_only=False, -# num_answers=3, -# max_tokens=None, -# temperature=None, -# model=None, -# match_all_tags=False, -# with_tags=False, -# with_file=False, -# ) -# self.config = self.dummy_config(db='test_files') -# -# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") -# @patch("chatmastermind.main.print_tag_args") -# @patch("chatmastermind.main.print_chat_hist") -# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) -# @patch("chatmastermind.utils.pp") -# @patch("builtins.print") -# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, -# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, -# mock_create_chat_hist: MagicMock) -> None: -# open_mock = MagicMock() -# with patch("chatmastermind.storage.open", open_mock): -# ask_cmd(self.args, self.config) -# mock_print_tag_args.assert_called_once_with(self.args.or_tags, -# self.args.exclude_tags, -# []) -# mock_create_chat_hist.assert_called_once_with(self.question, -# self.args.or_tags, -# self.args.exclude_tags, -# self.config, -# match_all_tags=False, -# with_tags=False, -# with_file=False) -# mock_print_chat_hist.assert_called_once_with('test_chat', -# False, -# self.args.source_code_only) -# mock_ai.assert_called_with("test_chat", -# self.config, -# self.args.num_answers) -# expected_calls = [] -# for num, answer in enumerate(mock_ai.return_value[0], start=1): -# title = f'-- ANSWER {num} ' -# title_end = '-' * (terminal_width() - len(title)) -# expected_calls.append(((f'{title}{title_end}',),)) -# expected_calls.append(((answer,),)) -# expected_calls.append((("-" * terminal_width(),),)) -# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) -# self.assertEqual(mock_print.call_args_list, expected_calls) -# open_expected_calls = 
list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) -# open_mock.assert_has_calls(open_expected_calls, any_order=True) -# -# -# class TestSaveAnswers(CmmTestCase): -# @mock.patch('builtins.open') -# @mock.patch('chatmastermind.storage.print') -# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: -# question = "Test question?" -# answers = ["Answer 1", "Answer 2"] -# tags = ["tag1", "tag2"] -# otags = ["otag1", "otag2"] -# config = self.dummy_config(db='test_db') -# -# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ -# mock.patch('chatmastermind.storage.yaml.dump'), \ -# mock.patch('io.StringIO') as stringio_mock: -# stringio_instance = stringio_mock.return_value -# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] -# save_answers(question, answers, tags, otags, config) -# -# open_calls = [ -# mock.call(pathlib.Path('test_db/.next'), 'r'), -# mock.call(pathlib.Path('test_db/.next'), 'w'), -# ] -# open_mock.assert_has_calls(open_calls, any_order=True) -# -# -# class TestAI(CmmTestCase): -# -# @patch("openai.ChatCompletion.create") -# def test_ai(self, mock_create: MagicMock) -> None: -# mock_create.return_value = { -# 'choices': [ -# {'message': {'content': 'response_text_1'}}, -# {'message': {'content': 'response_text_2'}} -# ], -# 'usage': {'tokens': 10} -# } -# -# chat = [{"role": "system", "content": "hello ai"}] -# config = self.dummy_config(db='dummy') -# config.openai.model = "text-davinci-002" -# config.openai.max_tokens = 150 -# config.openai.temperature = 0.5 -# -# result = ai(chat, config, 2) -# expected_result = (['response_text_1', 'response_text_2'], -# {'tokens': 10}) -# self.assertEqual(result, expected_result) -# -# -# class TestCreateParser(CmmTestCase): -# def test_create_parser(self) -> None: -# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: -# mock_cmdparser = Mock() -# mock_add_subparsers.return_value = 
mock_cmdparser -# parser = create_parser() -# self.assertIsInstance(parser, argparse.ArgumentParser) -# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) -# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) -# self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From 61e710a4b1d7b5570862714376ae6262b26dcb9f Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 13:31:01 +0200 Subject: [PATCH 102/121] cmm: splitted commands into separate modules (and more cleanup) --- chatmastermind/commands/config.py | 11 ++++++ chatmastermind/commands/hist.py | 23 ++++++++++++ chatmastermind/commands/print.py | 19 ++++++++++ chatmastermind/commands/question.py | 57 +++++++++++++++++++++++++++++ chatmastermind/commands/tags.py | 17 +++++++++ chatmastermind/main.py | 44 ++++++++++------------ setup.py | 2 +- tests/test_ai_factory.py | 48 ++++++++++++++++++++++++ 8 files changed, 196 insertions(+), 25 deletions(-) create mode 100644 chatmastermind/commands/config.py create mode 100644 chatmastermind/commands/hist.py create mode 100644 chatmastermind/commands/print.py create mode 100644 chatmastermind/commands/question.py create mode 100644 chatmastermind/commands/tags.py create mode 100644 tests/test_ai_factory.py diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py new file mode 100644 index 0000000..262164c --- /dev/null +++ b/chatmastermind/commands/config.py @@ -0,0 +1,11 @@ +import argparse +from pathlib import Path +from ..configuration import Config + + +def config_cmd(args: argparse.Namespace) -> 
None: + """ + Handler for the 'config' command. + """ + if args.create: + Config.create_default(Path(args.create)) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py new file mode 100644 index 0000000..88ed3be --- /dev/null +++ b/chatmastermind/commands/hist.py @@ -0,0 +1,23 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import MessageFilter + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. + """ + + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py new file mode 100644 index 0000000..51e76f8 --- /dev/null +++ b/chatmastermind/commands/print.py @@ -0,0 +1,19 @@ +import sys +import argparse +from pathlib import Path +from ..configuration import Config +from ..message import Message, MessageError + + +def print_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'print' command. 
+ """ + fname = Path(args.file) + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") + sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py new file mode 100644 index 0000000..9c56ced --- /dev/null +++ b/chatmastermind/commands/question.py @@ -0,0 +1,57 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import Message, Question +from ..ai_factory import create_ai +from ..ai import AI, AIResponse + + +def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: + """ + Creates (and writes) a new message from the given arguments. + """ + # FIXME: add sources to the question + message = Message(question=Question(args.question), + tags=args.output_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + return message + + +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = create_message(chat, args) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py new file mode 100644 index 0000000..2906a5b --- /dev/null +++ b/chatmastermind/commands/tags.py @@ -0,0 +1,17 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB + + +def tags_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'tags' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + if args.list: + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 58ce9ed..02cdffd 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,12 +6,14 @@ import sys import argcomplete import argparse from pathlib import Path -from .configuration import Config, default_config_path -from .chat import ChatDB -from .message import Message, MessageFilter, MessageError, Question -from .ai_factory import create_ai -from .ai import AI, AIResponse from typing import Any +from .configuration import Config, default_config_path +from .message import Message +from .commands.question import question_cmd +from .commands.tags import tags_cmd +from .commands.config import config_cmd +from .commands.hist import hist_cmd +from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: @@ -136,20 +138,28 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', - help='List of tag names (one must match)', metavar='OTAGS') + help='List of tags (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', - help='List of tag names (all must match)', metavar='ATAGS') + help='List of tags (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', - help='List of tag names to exclude', metavar='XTAGS') + help='List of tags to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = 
tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OUTTAGS') + help='List of output tags (default: use input tags)', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # a parent parser for all commands that support AI configuration + ai_parser = argparse.ArgumentParser(add_help=False) + ai_parser.add_argument('-A', '--AI', help='AI ID to use') + ai_parser.add_argument('-M', '--model', help='Model to use') + ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) + ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) + ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) + # 'question' command parser - question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], help="ask, create and process questions.", aliases=['q']) question_cmd_parser.set_defaults(func=question_cmd) @@ -160,12 +170,6 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - question_cmd_parser.add_argument('-A', '--AI', help='AI to use') - question_cmd_parser.add_argument('-M', '--model', help='Model to use') - question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code-only', 
help='Add pure source code to the chat history', action='store_true') @@ -213,18 +217,10 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) -<<<<<<< HEAD print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') -||||||| parent of bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', - action='store_true') -======= - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', - action='store_true') ->>>>>>> bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') argcomplete.autocomplete(parser) return parser diff --git a/setup.py b/setup.py index 8484629..a311605 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages() + ["chatmastermind.ais"], + packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py new file mode 100644 index 0000000..d63970e --- /dev/null +++ b/tests/test_ai_factory.py @@ -0,0 +1,48 @@ +import argparse +import unittest +from unittest.mock import MagicMock +from chatmastermind.ai_factory import create_ai +from chatmastermind.configuration import Config +from chatmastermind.ai import AIError +from 
chatmastermind.ais.openai import OpenAI + + +class TestCreateAI(unittest.TestCase): + def setUp(self) -> None: + self.args = MagicMock(spec=argparse.Namespace) + self.args.ai = 'default' + self.args.model = None + self.args.max_tokens = None + self.args.temperature = None + + def test_create_ai_from_args(self) -> None: + # Create an AI with the default configuration + config = Config() + self.args.ai = 'default' + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_ai_from_default(self) -> None: + self.args.ai = None + # Create an AI with the default configuration + config = Config() + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_empty_ai_error(self) -> None: + self.args.ai = None + # Create Config with empty AIs + config = Config() + config.ais = {} + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) + + def test_create_unsupported_ai_error(self) -> None: + # Mock argparse.Namespace with ai='invalid_ai' + self.args.ai = 'invalid_ai' + # Create default Config + config = Config() + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) -- 2.36.6 From ecb699478335c1c054b8dd917762c967270dac5b Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:52:03 +0200 Subject: [PATCH 103/121] configuration et al: implemented new Config format --- chatmastermind/ai.py | 13 ++-- chatmastermind/ai_factory.py | 29 ++++++-- chatmastermind/ais/openai.py | 9 +-- chatmastermind/configuration.py | 119 ++++++++++++++++++++++++++------ 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index 4a8b914..e94de8e 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -33,18 +33,23 @@ class AI(Protocol): The base class for AI clients. 
""" + ID: str name: str config: AIConfig def request(self, question: Message, - context: Chat, + chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ - Make an AI request, asking the given question with the given - context (i. e. chat history). The nr. of requested answers - corresponds to the nr. of messages in the 'AIResponse'. + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain """ raise NotImplementedError diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c90366b..c4a063a 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,35 @@ Creates different AI instances, based on the given configuration. """ import argparse -from .configuration import Config +from typing import cast +from .configuration import Config, OpenAIConfig, default_ai_ID from .ai import AI, AIError from .ais.openai import OpenAI def create_ai(args: argparse.Namespace, config: Config) -> AI: """ - Creates an AI subclass instance from the given args and configuration. + Creates an AI subclass instance from the given arguments + and configuration file. 
""" - if args.ai == 'openai': - # FIXME: create actual 'OpenAIConfig' and set values from 'args' - # FIXME: use actual name from config - return OpenAI("openai", config.openai) + if args.ai: + try: + ai_conf = config.ais[args.ai] + except KeyError: + raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + elif default_ai_ID in config.ais: + ai_conf = config.ais[default_ai_ID] + else: + raise AIError("No AI name given and no default exists") + + if ai_conf.name == 'openai': + ai = OpenAI(cast(OpenAIConfig, ai_conf)) + if args.model: + ai.config.model = args.model + if args.max_tokens: + ai.config.max_tokens = args.max_tokens + if args.temperature: + ai.config.temperature = args.temperature + return ai else: raise AIError(f"AI '{args.ai}' is not supported") diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 74438b8..14ce33f 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -17,9 +17,11 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name self.config = config + openai.api_key = config.api_key def request(self, question: Message, @@ -31,8 +33,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. 
""" - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..d82f913 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,17 +1,40 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_ai_ID: str = 'default' +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ - name: str + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) @dataclass @@ -19,21 +42,27 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. 
""" - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'default' + api_key: str = '0123456789' + system: str = 'You are an assistant' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( - name='OpenAI', + res = cls( + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -42,6 +71,30 @@ class OpenAIConfig(AIConfig): frequency_penalty=float(source['frequency_penalty']), presence_penalty=float(source['presence_penalty']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI '{name}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass @@ -49,30 +102,52 @@ class Config: """ The configuration file structure. 
""" - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create Config from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: - return asdict(self) + res = asdict(self) + for ID, conf in res['ais'].items(): + conf.update({'name': self.ais[ID].name}) + return res -- 2.36.6 From c52713c833754290d931f174e0a8aa402e0fd58b Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 10:40:22 +0200 Subject: [PATCH 104/121] configuration: added tests --- chatmastermind/configuration.py | 2 +- tests/test_configuration.py | 160 
++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_configuration.py diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d82f913..398fa03 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -87,7 +87,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> else: return OpenAIConfig.from_dict(conf_dict) else: - raise ConfigError(f"AI '{name}' is not supported") + raise ConfigError(f"Unknown AI '{name}'") def create_default_ai_configs() -> dict[str, AIConfig]: diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..f3f9a98 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def 
test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['default'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['default'].ID, 'default') + + def test_create_default_should_create_default_config(self) -> None: + 
Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'default': OpenAIConfig( + ID='default', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, + presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + 
with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name) -- 2.36.6 From c4f7bcc94e87811a5788f0eea65e0e719c29e68b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:51:17 +0200 Subject: [PATCH 105/121] question_cmd: fixes --- chatmastermind/commands/question.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 9c56ced..1709a3c 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path +from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB from ..message import Message, Question @@ -11,8 +12,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. """ - # FIXME: add sources to the question - message = Message(question=Question(args.question), + question_parts = [] + question_list = args.question if args.question is not None else [] + source_list = args.source if args.source is not None else [] + + # FIXME: don't surround all sourced files with ``` + # -> do it only if '--source-code-only' is True and no source code + # could be extracted from that file + for question, source in zip_longest(question_list, source_list, fillvalue=None): + if question is not None and source is not None: + with open(source) as r: + question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") + elif question is not None: + question_parts.append(question) + elif source is not None: + with open(source) as r: + question_parts.append(f"```\n{r.read().strip()}\n```") + + full_question = '\n\n'.join(question_parts) + + message = Message(question=Question(full_question), tags=args.output_tags, # FIXME ai=args.ai, model=args.model) -- 2.36.6 From 
3eca53998b674d0cc6a218c69170b2f60c110355 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:31:30 +0200 Subject: [PATCH 106/121] question cmd: added tests --- tests/test_question_cmd.py | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_question_cmd.py diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..96b2fdf --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,111 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source = None + self.args.source_code_only = False + self.args.ai = None + self.args.model = None + self.args.output_tags = None + # create some files for sourcing + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. 
+``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + + def tearDown(self) -> None: + os.remove(self.source_file1.name) + os.remove(self.source_file2.name) + os.remove(self.source_file3.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.question = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.question = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? 
+ +Is it good?""")) + + def test_single_question_with_text_only_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains no source code + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file2_content}""")) -- 2.36.6 From 86eebc39eafa1fcdfa66ac3eec7aa2c1049c9582 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:16:17 +0200 Subject: [PATCH 107/121] Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. 
--- chatmastermind/commands/question.py | 22 ++++++++++++---------- chatmastermind/main.py | 5 ++--- tests/test_question_cmd.py | 22 ++++++++++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 1709a3c..818b1de 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -15,19 +15,21 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: question_parts = [] question_list = args.question if args.question is not None else [] source_list = args.source if args.source is not None else [] + code_list = args.source_code if args.source_code is not None else [] - # FIXME: don't surround all sourced files with ``` - # -> do it only if '--source-code-only' is True and no source code - # could be extracted from that file - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: + for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + if question is not None and len(question.strip()) > 0: question_parts.append(question) - elif source is not None: + if source is not None and len(source) > 0: with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 02cdffd..46bad44 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -170,9 +170,8 @@ def create_parser() -> argparse.ArgumentParser: 
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 96b2fdf..06cc527 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -23,7 +23,7 @@ class TestMessageCreate(unittest.TestCase): # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source = None - self.args.source_code_only = False + self.args.source_code = None self.args.ai = None self.args.model = None self.args.output_tags = None @@ -94,11 +94,11 @@ Is it good?""")) # source file contains no source code # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? 
{self.source_file1_content}""")) - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_embedded_source_source(self) -> None: self.args.question = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) @@ -106,6 +106,20 @@ Is it good?""")) # source file contains 1 source code block # -> expect it in the question self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file2_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question(f"""What is this? + +``` +{self.source_file2_content} +```""")) -- 2.36.6 From 54ece6efeb23f36fa6ffc156ed5dc3d97ec83752 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:38:40 +0200 Subject: [PATCH 108/121] Port print arguments -q/-a/-S from main to restructuring. 
--- chatmastermind/commands/print.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py index 51e76f8..3d2b990 100644 --- a/chatmastermind/commands/print.py +++ b/chatmastermind/commands/print.py @@ -13,7 +13,15 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: try: message = Message.from_file(fname) if message: - print(message.to_str(source_code_only=args.source_code_only)) + if args.question: + print(message.question) + elif args.answer: + print(message.answer) + elif message.answer and args.only_source_code: + for code in message.answer.source_code(): + print(code) + else: + print(message.to_str()) except MessageError: print(f"File is not a valid message: {args.file}") sys.exit(1) -- 2.36.6 From 6f3ea9842564f26b86eb7235179962c97c9999b0 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 16:05:27 +0200 Subject: [PATCH 109/121] Small fixes. --- chatmastermind/ai_factory.py | 8 ++++---- chatmastermind/commands/question.py | 6 +++--- tests/test_ai_factory.py | 10 +++++----- tests/test_question_cmd.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c4a063a..bc4583c 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -14,11 +14,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: Creates an AI subclass instance from the given arguments and configuration file. 
""" - if args.ai: + if args.AI: try: - ai_conf = config.ais[args.ai] + ai_conf = config.ais[args.AI] except KeyError: - raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") elif default_ai_ID in config.ais: ai_conf = config.ais[default_ai_ID] else: @@ -34,4 +34,4 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: ai.config.temperature = args.temperature return ai else: - raise AIError(f"AI '{args.ai}' is not supported") + raise AIError(f"AI '{args.AI}' is not supported") diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 818b1de..90b782b 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -13,7 +13,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: Creates (and writes) a new message from the given arguments. """ question_parts = [] - question_list = args.question if args.question is not None else [] + question_list = args.ask if args.ask is not None else [] source_list = args.source if args.source is not None else [] code_list = args.source_code if args.source_code is not None else [] @@ -35,7 +35,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: message = Message(question=Question(full_question), tags=args.output_tags, # FIXME - ai=args.ai, + ai=args.AI, model=args.model) chat.add_to_cache([message]) return message @@ -59,7 +59,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME - args.otags) # FIXME + args.output_tags) # FIXME assert response # TODO: # * add answer to the message above (and create diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d63970e..d00b319 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class 
TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.ai = 'default' + self.args.AI = 'default' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,19 +18,19 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.ai = 'default' + self.args.AI = 'default' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_ai_from_default(self) -> None: - self.args.ai = None + self.args.AI = None # Create an AI with the default configuration config = Config() ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_empty_ai_error(self) -> None: - self.args.ai = None + self.args.AI = None # Create Config with empty AIs config = Config() config.ais = {} @@ -40,7 +40,7 @@ class TestCreateAI(unittest.TestCase): def test_create_unsupported_ai_error(self) -> None: # Mock argparse.Namespace with ai='invalid_ai' - self.args.ai = 'invalid_ai' + self.args.AI = 'invalid_ai' # Create default Config config = Config() # Call create_ai function and assert that it raises AIError diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 06cc527..aa0dc25 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -24,7 +24,7 @@ class TestMessageCreate(unittest.TestCase): self.args = MagicMock(spec=argparse.Namespace) self.args.source = None self.args.source_code = None - self.args.ai = None + self.args.AI = None self.args.model = None self.args.output_tags = None # create some files for sourcing @@ -59,7 +59,7 @@ Language is called 'brainfart'.""" return list(Path(tmp_dir.name).glob('*.[ty]*')) def test_message_file_created(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) 
create_message(self.chat, self.args) @@ -70,14 +70,14 @@ Language is called 'brainfart'.""" self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] def test_single_question(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("What is this?")) self.assertEqual(len(message.question.source_code()), 0) def test_multipart_question(self) -> None: - self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + self.args.ask = ["What is this", "'bard' thing?", "Is it good?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("""What is this @@ -87,7 +87,7 @@ Language is called 'brainfart'.""" Is it good?""")) def test_single_question_with_text_only_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -99,7 +99,7 @@ Is it good?""")) {self.source_file1_content}""")) def test_single_question_with_embedded_source_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -111,7 +111,7 @@ Is it good?""")) {self.source_file2_content}""")) def test_single_question_with_embedded_source_code_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) -- 2.36.6 From f99cd3ed41b404d6e6197cc0239825e5177a2d10 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 
2023 18:28:10 +0200 Subject: [PATCH 110/121] question_cmd: fixed source code extraction and added a testcase --- chatmastermind/commands/question.py | 17 +++++-- chatmastermind/main.py | 2 +- chatmastermind/message.py | 2 +- tests/test_question_cmd.py | 79 +++++++++++++++++++++-------- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 90b782b..756a051 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question +from ..message import Message, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ question_parts = [] question_list = args.ask if args.ask is not None else [] - source_list = args.source if args.source is not None else [] - code_list = args.source_code if args.source_code is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) if source is not None and len(source) > 0: @@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code is not None and len(code) > 0: with open(code) as r: content = r.read().strip() - if len(content) > 0: + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += 
code_parts + # if there's none, add the whole file + else: question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 46bad44..1a375d0 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -170,7 +170,7 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 35de3b9..7107c13 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -414,7 +414,7 @@ class Message(): return '\n'.join(output) def __str__(self) -> str: - return self.to_str(False, False, False) + return self.to_str(True, True, False) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index aa0dc25..40ea4d8 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase): db_path=Path(self.db_path.name)) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) - self.args.source = None + self.args.source_text = None self.args.source_code = None self.args.AI = None self.args.model = None self.args.output_tags = None - # create some files for sourcing + # File 1 : no source code block, only text self.source_file1 = 
tempfile.NamedTemporaryFile(delete=False) self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!""" with open(self.source_file1.name, 'w') as f: f.write(self.source_file1_content) + # File 2 : one embedded source code block self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2_content = """This is just text. ``` @@ -42,12 +43,26 @@ This is embedded source code. And some text again.""" with open(self.source_file2.name, 'w') as f: f.write(self.source_file2_content) + # File 3 : all source code self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3_content = """This is all source code. Yes, really. Language is called 'brainfart'.""" with open(self.source_file3.name, 'w') as f: f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) def tearDown(self) -> None: os.remove(self.source_file1.name) @@ -86,40 +101,62 @@ Language is called 'brainfart'.""" Is it good?""")) - def test_single_question_with_text_only_source(self) -> None: + def test_single_question_with_text_only_file(self) -> None: self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file1.name}"] + self.args.source_text = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains no source code + # file contains no source code (only text) # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(message.question, Question(f"""What is this? 
{self.source_file1_content}""")) - def test_single_question_with_embedded_source_source(self) -> None: - self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file2.name}"] - message = create_message(self.chat, self.args) - self.assertIsInstance(message, Message) - # source file contains 1 source code block - # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question(f"""What is this? - -{self.source_file2_content}""")) - - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_text_file_and_embedded_code(self) -> None: self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains 1 source code block + # file contains 1 source code block # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) self.assertEqual(message.question, Question(f"""What is this? 
``` -{self.source_file2_content} +{self.source_file3_content} ```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. +``` +""")) -- 2.36.6 From cc76da2ab36ae3cef44bd203018656d3a39501d0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:39:00 +0200 Subject: [PATCH 111/121] chat: added 'update_messages()' function and test --- chatmastermind/chat.py | 16 ++++++++++++++++ tests/test_chat.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4e8fb20..ddabb56 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -386,3 +386,19 @@ class ChatDB(Chat): msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. 
+ """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e4aa8c..ed630a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -440,3 +440,31 @@ class TestChatDB(unittest.TestCase): cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1]) -- 2.36.6 From 864ab7aeb1c2980145b258edb4b8baf76dbcd3bf Mon Sep 17 00:00:00 2001 From: juk0de Date: 
Sun, 10 Sep 2023 19:18:14 +0200 Subject: [PATCH 112/121] chat: added check for existing files when creating new filenames --- chatmastermind/chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ddabb56..7c4dd35 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -62,7 +62,10 @@ def make_file_path(dir_path: Path, Create a file_path for the given directory using the given file_suffix and ID generator function. """ - return dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + while file_path.exists(): + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + return file_path def write_dir(dir_path: Path, -- 2.36.6 From faac42d3c277b06af0b5f36bac18189f31a410cb Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:52:07 +0200 Subject: [PATCH 113/121] question_cmd: fixed '--ask' command --- chatmastermind/ai.py | 6 ++++++ chatmastermind/ais/openai.py | 19 ++++++++++++++----- chatmastermind/commands/question.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index e94de8e..b97b5f1 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -66,3 +66,9 @@ class AI(Protocol): and is not implemented for all AIs. """ raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. 
+ """ + pass diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 14ce33f..1db4d20 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -43,16 +43,20 @@ class OpenAI(AI): n=num_answers, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - answers: list[Message] = [] - for choice in response['choices']: # type: ignore + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.name + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, ai=self.name, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt'], - response['usage']['completion'], - response['usage']['total'])) + return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) def models(self) -> list[str]: """ @@ -95,3 +99,8 @@ class OpenAI(AI): def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 756a051..fdabd62 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # create the correct AI instance ai: AI = create_ai(args, config) if args.ask: + ai.print() + chat.print(paged=False) response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME args.output_tags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional 
answers) - pass + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) elif args.repeat: lmessage = chat.latest_message() assert lmessage -- 2.36.6 From 595ff8e294c945db38effda65d6445668db99e74 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:54:17 +0200 Subject: [PATCH 114/121] question_cmd: added message filtering by tags --- chatmastermind/commands/question.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index fdabd62..f439447 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question, source_code +from ..message import Message, MessageFilter, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -52,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. 
""" + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags) chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) + db_path=Path(config.db), + mfilter=mfilter) # if it's a new question, create and store it immediately if args.ask or args.create: message = create_message(chat, args) @@ -77,14 +81,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: if response.tokens: print("===============") print(response.tokens) - elif args.repeat: + elif args.repeat is not None: lmessage = chat.latest_message() assert lmessage # TODO: repeat either the last question or the # one(s) given in 'args.repeat' (overwrite # existing ones if 'args.overwrite' is True) pass - elif args.process: + elif args.process is not None: # TODO: process either all questions without an # answer or the one(s) given in 'args.process' pass -- 2.36.6 From 2e08ccf6060ebac3f66cc2d3e0d0c45ee9e5c3e2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:55:47 +0200 Subject: [PATCH 115/121] openai: stores AI.ID instead of AI.name in message --- chatmastermind/ais/openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 1db4d20..a388a7a 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -45,14 +45,14 @@ class OpenAI(AI): presence_penalty=self.config.presence_penalty) question.answer = Answer(response['choices'][0]['message']['content']) question.tags = otags - question.ai = self.name + question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, - ai=self.name, + ai=self.ID, model=self.config.model)) return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], response['usage']['completion_tokens'], 
-- 2.36.6 From 66908f5fed330f625250c35ae12c4ba970d83daf Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:24:20 +0200 Subject: [PATCH 116/121] message: fixed matching with empty tag sets --- chatmastermind/message.py | 4 ++-- tests/test_chat.py | 22 ++++++++++++++++++++-- tests/test_message.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 7107c13..df59ed6 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -312,7 +312,7 @@ class Message(): mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) - if message and (not mfilter or (mfilter and message.match(mfilter))): + if message and (mfilter is None or message.match(mfilter)): return message else: return None @@ -508,7 +508,7 @@ class Message(): Return True if all attributes match, else False. """ mytags = self.tags or set() - if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 diff --git a/tests/test_chat.py b/tests/test_chat.py index ed630a4..1916a2b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_filter(self) -> None: + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, 
pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messges(self) -> None: + def test_chat_db_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, diff --git a/tests/test_message.py b/tests/test_message.py index 57d5982..1f440df 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -300,6 +300,12 @@ This is a question. 
MessageFilter(tags_or={Tag('tag1')})) self.assertIsNone(message) + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + def test_from_file_txt_no_tags_match_tags_not(self) -> None: message = Message.from_file(self.file_path_min, MessageFilter(tags_not={Tag('tag1')})) -- 2.36.6 From b840ebd7923b6e31f2f7070c0d10d74bd2343a51 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:56:50 +0200 Subject: [PATCH 117/121] message: to_file() now uses intermediate temporary file --- chatmastermind/message.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index df59ed6..64929a3 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,6 +3,8 @@ Module implementing message related functions and classes. """ import pathlib import yaml +import tempfile +import shutil from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -445,16 +447,18 @@ class Message(): * Answer.txt_header * Answer """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) if self.tags: - fd.write(f'{TagLine.from_set(self.tags)}\n') + temp_fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: - fd.write(f'{AILine.from_ai(self.ai)}\n') + temp_fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: - fd.write(f'{ModelLine.from_model(self.model)}\n') - fd.write(f'{Question.txt_header}\n{self.question}\n') + temp_fd.write(f'{ModelLine.from_model(self.model)}\n') + temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - 
fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ @@ -466,7 +470,8 @@ class Message(): * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) @@ -476,7 +481,8 @@ class Message(): data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) - yaml.dump(data, fd, sort_keys=False) + yaml.dump(data, temp_fd, sort_keys=False) + shutil.move(temp_file_path, file_path) def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ -- 2.36.6 From 22fa187e5f8f886bddcf61fd4ccbb0825cedf044 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:25:33 +0200 Subject: [PATCH 118/121] question_cmd: when no tags are specified, no tags are selected --- chatmastermind/commands/question.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index f439447..4936d8f 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. 
""" - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags) + mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), + tags_and=args.and_tags if args.and_tags is not None else set(), + tags_not=args.exclude_tags if args.exclude_tags is not None else set()) chat = ChatDB.from_dir(cache_path=Path('.'), db_path=Path(config.db), mfilter=mfilter) -- 2.36.6 From 481f9ecf7cf178fce8dd55ff8af854b0db0835b6 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:37:06 +0200 Subject: [PATCH 119/121] configuration: improved config file format --- chatmastermind/configuration.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 398fa03..08f6cbe 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -17,6 +17,18 @@ class ConfigError(Exception): pass +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass class AIConfig: """ @@ -48,13 +60,13 @@ class OpenAIConfig(AIConfig): # a default configuration ID: str = 'default' api_key: str = '0123456789' - system: str = 'You are an assistant' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 max_tokens: int = 4000 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -62,14 +74,14 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. 
""" res = cls( - system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) # overwrite default ID if provided if 'ID' in source: @@ -148,6 +160,8 @@ class Config: def as_dict(self) -> dict[str, Any]: res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) for ID, conf in res['ais'].items(): - conf.update({'name': self.ais[ID].name}) + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} return res -- 2.36.6 From 33023d29f9de4fde3e12bc49d34aee88c89dca2f Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 11 Sep 2023 07:38:49 +0200 Subject: [PATCH 120/121] configuration: made 'default' AI ID optional --- chatmastermind/ai_factory.py | 18 ++++++++++++------ chatmastermind/configuration.py | 3 +-- tests/test_ai_factory.py | 4 ++-- tests/test_configuration.py | 14 +++++++------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index bc4583c..420b287 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration. import argparse from typing import cast -from .configuration import Config, OpenAIConfig, default_ai_ID +from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 """ Creates an AI subclass instance from the given arguments - and configuration file. + and configuration file. 
If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. """ + ai_conf: AIConfig if args.AI: try: ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") - elif default_ai_ID in config.ais: - ai_conf = config.ais[default_ai_ID] + elif 'default' in config.ais: + ai_conf = config.ais['default'] else: - raise AIError("No AI name given and no default exists") + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 08f6cbe..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') supported_ais: list[str] = ['openai'] -default_ai_ID: str = 'default' default_config_path = '.config.yaml' @@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig): # all members have default values, so we can easily create # a default configuration - ID: str = 'default' + ID: str = 'myopenai' api_key: str = '0123456789' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d00b319..9cb94d3 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.AI = 'default' + self.args.AI = 'myopenai' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.AI = 'default' + 
self.args.AI = 'myopenai' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f3f9a98..ba8a5aa 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase): source_dict = { 'db': './test_db/', 'ais': { - 'default': { + 'myopenai': { 'name': 'openai', 'system': 'Custom system', 'api_key': '9876543210', @@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase): config = Config.from_dict(source_dict) self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) - self.assertEqual(config.ais['default'].name, 'openai') - self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') # check that 'ID' has been added - self.assertEqual(config.ais['default'].ID, 'default') + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') def test_create_default_should_create_default_config(self) -> None: Config.create_default(Path(self.test_file.name)) @@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase): config = Config( db='./test_db/', ais={ - 'default': OpenAIConfig( - ID='default', + 'myopenai': OpenAIConfig( + ID='myopenai', system='Custom system', api_key='9876543210', model='custom_model', @@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase): saved_config = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) - self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = { -- 2.36.6 From 17de0b99678381fa7e0fac9285d71bc26c649a67 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Mon, 11 Sep 
2023 13:17:59 +0200 Subject: [PATCH 121/121] Remove old code. --- chatmastermind/main.py | 105 +---------------------------------------- 1 file changed, 1 insertion(+), 104 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1a375d0..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -18,110 +18,7 @@ from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) - return get_tags_unique(config, prefix) - - -def tags_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'tags' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - if args.list: - tags_freq = chat.tags_frequency(args.prefix, args.contain) - for tag, freq in tags_freq.items(): - print(f"- {tag}: {freq}") - # TODO: add renaming - - -def config_cmd(args: argparse.Namespace) -> None: - """ - Handler for the 'config' command. - """ - if args.create: - Config.create_default(Path(args.create)) - - -def question_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'question' command. 
- """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - # if it's a new question, create and store it immediately - if args.ask or args.create: - # FIXME: add sources to the question - message = Message(question=Question(args.question), - tags=args.ouput_tags, # FIXME - ai=args.ai, - model=args.model) - chat.add_to_cache([message]) - if args.create: - return - - # create the correct AI instance - ai: AI = create_ai(args, config) - if args.ask: - response: AIResponse = ai.request(message, - chat, - args.num_answers, # FIXME - args.otags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass - elif args.repeat: - lmessage = chat.latest_message() - assert lmessage - # TODO: repeat either the last question or the - # one(s) given in 'args.repeat' (overwrite - # existing ones if 'args.overwrite' is True) - pass - elif args.process: - # TODO: process either all questions without an - # answer or the one(s) given in 'args.process' - pass - - -def hist_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'hist' command. - """ - - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags, - question_contains=args.question, - answer_contains=args.answer) - chat = ChatDB.from_dir(Path('.'), - Path(config.db), - mfilter=mfilter) - chat.print(args.source_code_only, - args.with_tags, - args.with_files) - - -def print_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'print' command. 
- """ - fname = Path(args.file) - try: - message = Message.from_file(fname) - if message: - print(message.to_str(source_code_only=args.source_code_only)) - except MessageError: - print(f"File is not a valid message: {args.file}") - sys.exit(1) - if args.source_code_only: - display_source_code(data['answer']) - elif args.answer: - print(data['answer'].strip()) - elif args.question: - print(data['question'].strip()) - else: - print(dump_data(data).strip()) + return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) def create_parser() -> argparse.ArgumentParser: -- 2.36.6