diff --git a/.gitignore b/.gitignore index 4ade1df..89bf5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .config.yaml db -noweb \ No newline at end of file +noweb +Session.vim diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py new file mode 100644 index 0000000..b97b5f1 --- /dev/null +++ b/chatmastermind/ai.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from typing import Protocol, Optional, Union +from .configuration import AIConfig +from .tags import Tag +from .message import Message +from .chat import Chat + + +class AIError(Exception): + pass + + +@dataclass +class Tokens: + prompt: int = 0 + completion: int = 0 + total: int = 0 + + +@dataclass +class AIResponse: + """ + The response to an AI request. Consists of one or more messages + (each containing the question and a single answer) and the nr. + of used tokens. + """ + messages: list[Message] + tokens: Optional[Tokens] = None + + +class AI(Protocol): + """ + The base class for AI clients. + """ + + ID: str + name: str + config: AIConfig + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain + """ + raise NotImplementedError + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + """ + Computes the nr. of AI language tokens for the given message + or chat. Note that the computation may not be 100% accurate + and is not implemented for all AIs. 
+ """ + raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. + """ + pass diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py new file mode 100644 index 0000000..420b287 --- /dev/null +++ b/chatmastermind/ai_factory.py @@ -0,0 +1,43 @@ +""" +Creates different AI instances, based on the given configuration. +""" + +import argparse +from typing import cast +from .configuration import Config, AIConfig, OpenAIConfig +from .ai import AI, AIError +from .ais.openai import OpenAI + + +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 + """ + Creates an AI subclass instance from the given arguments + and configuration file. If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. + """ + ai_conf: AIConfig + if args.AI: + try: + ai_conf = config.ais[args.AI] + except KeyError: + raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") + elif 'default' in config.ais: + ai_conf = config.ais['default'] + else: + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") + + if ai_conf.name == 'openai': + ai = OpenAI(cast(OpenAIConfig, ai_conf)) + if args.model: + ai.config.model = args.model + if args.max_tokens: + ai.config.max_tokens = args.max_tokens + if args.temperature: + ai.config.temperature = args.temperature + return ai + else: + raise AIError(f"AI '{args.AI}' is not supported") diff --git a/chatmastermind/ais/__init__.py b/chatmastermind/ais/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py new file mode 100644 index 0000000..a388a7a --- /dev/null +++ b/chatmastermind/ais/openai.py @@ -0,0 +1,106 @@ +""" +Implements the OpenAI client classes and functions. 
+""" +import openai +from typing import Optional, Union +from ..tags import Tag +from ..message import Message, Answer +from ..chat import Chat +from ..ai import AI, AIResponse, Tokens +from ..configuration import OpenAIConfig + +ChatType = list[dict[str, str]] + + +class OpenAI(AI): + """ + The OpenAI AI client. + """ + + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name + self.config = config + openai.api_key = config.api_key + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + chat history. The nr. of requested answers corresponds to the + nr. of messages in the 'AIResponse'. + """ + oai_chat = self.openai_chat(chat, self.config.system, question) + response = openai.ChatCompletion.create( + model=self.config.model, + messages=oai_chat, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, + n=num_answers, + frequency_penalty=self.config.frequency_penalty, + presence_penalty=self.config.presence_penalty) + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.ID + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore + answers.append(Message(question=question.question, + answer=Answer(choice['message']['content']), + tags=otags, + ai=self.ID, + model=self.config.model)) + return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def print_models(self) -> None: + """ + Print all models supported by the current AI. 
+ """ + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def openai_chat(self, chat: Chat, system: str, + question: Optional[Message] = None) -> ChatType: + """ + Create a chat history with system message in OpenAI format. + Optionally append a new question. + """ + oai_chat: ChatType = [] + + def append(role: str, content: str) -> None: + oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + + append('system', system) + for message in chat.messages: + if message.answer: + append('user', message.question) + append('assistant', message.answer) + if question: + append('user', question.question) + return oai_chat + + def tokens(self, data: Union[Message, Chat]) -> int: + raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. 
- """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py new file mode 100644 index 0000000..7c4dd35 --- /dev/null +++ b/chatmastermind/chat.py @@ -0,0 +1,407 @@ +""" +Module implementing various chat classes and functions for managing a chat history. +""" +import shutil +from pathlib import Path +from pprint import PrettyPrinter +from pydoc import pager +from dataclasses import dataclass +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from .message import Message, MessageFilter, MessageError, message_in +from .tags import Tag + +ChatInst = TypeVar('ChatInst', bound='Chat') +ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') + + +class ChatError(Exception): + pass + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args: Any, **kwargs: Any) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def print_paged(text: str) -> None: + pager(text) + + +def read_dir(dir_path: Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> list[Message]: + """ + Reads the messages from the given folder. + Parameters: + * 'dir_path': source directory + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. 
+ """ + messages: list[Message] = [] + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + try: + message = Message.from_file(file_path, mfilter) + if message: + messages.append(message) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return messages + + +def make_file_path(dir_path: Path, + file_suffix: str, + next_fid: Callable[[], int]) -> Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + while file_path.exists(): + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + return file_path + + +def write_dir(dir_path: Path, + messages: list[Message], + file_suffix: str, + next_fid: Callable[[], int]) -> None: + """ + Write all messages to the given directory. If a message has no file_path, + a new one will be created. If message.file_path exists, it will be modified + to point to the given directory. + Parameters: + * 'dir_path': destination directory + * 'messages': list of messages to write + * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] + * 'next_fid': callable that returns the next file ID + """ + for message in messages: + file_path = message.file_path + # message has no file_path: create one + if not file_path: + file_path = make_file_path(dir_path, file_suffix, next_fid) + # file_path does not point to given directory: modify it + elif not file_path.parent.samefile(dir_path): + file_path = dir_path / file_path.name + message.to_file(file_path) + + +def clear_dir(dir_path: Path, + glob: Optional[str] = None) -> None: + """ + Deletes all Message files in the given directory. 
+ """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + +@dataclass +class Chat: + """ + A class containing a complete chat history. + """ + + messages: list[Message] + + def filter(self, mfilter: MessageFilter) -> None: + """ + Use 'Message.match(mfilter) to remove all messages that + don't fulfill the filter requirements. + """ + self.messages = [m for m in self.messages if m.match(mfilter)] + + def sort(self, reverse: bool = False) -> None: + """ + Sort the messages according to 'Message.msg_id()'. + """ + try: + # the message may not have an ID if it doesn't have a file_path + self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) + except MessageError: + pass + + def clear(self) -> None: + """ + Delete all messages. + """ + self.messages = [] + + def add_messages(self, messages: list[Message]) -> None: + """ + Add new messages and sort them if possible. + """ + self.messages += messages + self.sort() + + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. 
+ """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Get the tags of all messages, optionally filtered by prefix or substring. + """ + tags: set[Tag] = set() + for m in self.messages: + tags |= m.filter_tags(prefix, contain) + return set(sorted(tags)) + + def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: + """ + Get the frequency of all tags of all messages, optionally filtered by prefix or substring. + """ + tags: list[Tag] = [] + for m in self.messages: + tags += [tag for tag in m.filter_tags(prefix, contain)] + return {tag: tags.count(tag) for tag in sorted(tags)} + + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by all messages in this chat. + If unknown, 0 is returned. + """ + return sum(m.tokens() for m in self.messages) + + def print(self, source_code_only: bool = False, + with_tags: bool = False, with_files: bool = False, + paged: bool = True) -> None: + output: list[str] = [] + for message in self.messages: + if source_code_only: + output.append(message.to_str(source_code_only=True)) + continue + output.append(message.to_str(with_tags, with_files)) + output.append('\n' + ('-' * terminal_width()) + '\n') + if paged: + print_paged('\n'.join(output)) + else: + print(*output, sep='\n') + + +@dataclass +class ChatDB(Chat): + """ + A 'Chat' class that is bound to a given directory structure. Supports reading + and writing messages from / to that structure. Such a structure consists of + two directories: a 'cache directory', where all messages are temporarily + stored, and a 'DB' directory, where selected messages can be stored + persistently. 
+ """ + + default_file_suffix: ClassVar[str] = '.txt' + + cache_path: Path + db_path: Path + # a MessageFilter that all messages must match (if given) + mfilter: Optional[MessageFilter] = None + file_suffix: str = default_file_suffix + # the glob pattern for all messages + glob: Optional[str] = None + + def __post_init__(self) -> None: + # contains the latest message ID + self.next_fname = self.db_path / '.next' + # make all paths absolute + self.cache_path = self.cache_path.absolute() + self.db_path = self.db_path.absolute() + + @classmethod + def from_dir(cls: Type[ChatDBInst], + cache_path: Path, + db_path: Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a 'ChatDB' instance from the given directory structure. + Reads all messages from 'db_path' into the local message list. + Parameters: + * 'cache_path': path to the directory for temporary messages + * 'db_path': path to the directory for persistent messages + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages = read_dir(db_path, glob, mfilter) + return cls(messages, cache_path, db_path, mfilter, + cls.default_file_suffix, glob) + + @classmethod + def from_messages(cls: Type[ChatDBInst], + cache_path: Path, + db_path: Path, + messages: list[Message], + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a ChatDB instance from the given message list. 
+ """ + return cls(messages, cache_path, db_path, mfilter) + + def get_next_fid(self) -> int: + try: + with open(self.next_fname, 'r') as f: + next_fid = int(f.read()) + 1 + self.set_next_fid(next_fid) + return next_fid + except Exception: + self.set_next_fid(1) + return 1 + + def set_next_fid(self, fid: int) -> None: + with open(self.next_fname, 'w') as f: + f.write(f'{fid}') + + def read_db(self) -> None: + """ + Reads new messages from the DB directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.db_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def read_cache(self) -> None: + """ + Reads new messages from the cache directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.cache_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def write_db(self, messages: Optional[list[Message]] = None) -> None: + """ + Write messages to the DB directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point + to the DB directory. 
+ """ + write_dir(self.db_path, + messages if messages else self.messages, + self.file_suffix, + self.get_next_fid) + + def write_cache(self, messages: Optional[list[Message]] = None) -> None: + """ + Write messages to the cache directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point to + the cache directory. + """ + write_dir(self.cache_path, + messages if messages else self.messages, + self.file_suffix, + self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. 
+ """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. 
+ """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py new file mode 100644 index 0000000..262164c --- /dev/null +++ b/chatmastermind/commands/config.py @@ -0,0 +1,11 @@ +import argparse +from pathlib import Path +from ..configuration import Config + + +def config_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'config' command. + """ + if args.create: + Config.create_default(Path(args.create)) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py new file mode 100644 index 0000000..88ed3be --- /dev/null +++ b/chatmastermind/commands/hist.py @@ -0,0 +1,23 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import MessageFilter + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. 
+ """ + + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py new file mode 100644 index 0000000..3d2b990 --- /dev/null +++ b/chatmastermind/commands/print.py @@ -0,0 +1,27 @@ +import sys +import argparse +from pathlib import Path +from ..configuration import Config +from ..message import Message, MessageError + + +def print_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'print' command. + """ + fname = Path(args.file) + try: + message = Message.from_file(fname) + if message: + if args.question: + print(message.question) + elif args.answer: + print(message.answer) + elif message.answer and args.only_source_code: + for code in message.answer.source_code(): + print(code) + else: + print(message.to_str()) + except MessageError: + print(f"File is not a valid message: {args.file}") + sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py new file mode 100644 index 0000000..4936d8f --- /dev/null +++ b/chatmastermind/commands/question.py @@ -0,0 +1,94 @@ +import argparse +from pathlib import Path +from itertools import zip_longest +from ..configuration import Config +from ..chat import ChatDB +from ..message import Message, MessageFilter, Question, source_code +from ..ai_factory import create_ai +from ..ai import AI, AIResponse + + +def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: + """ + Creates (and writes) a new message from the given arguments. 
+ """ + question_parts = [] + question_list = args.ask if args.ask is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] + + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): + if question is not None and len(question.strip()) > 0: + question_parts.append(question) + if source is not None and len(source) > 0: + with open(source) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + # if there's none, add the whole file + else: + question_parts.append(f"```\n{content}\n```") + + full_question = '\n\n'.join(question_parts) + + message = Message(question=Question(full_question), + tags=args.output_tags, # FIXME + ai=args.AI, + model=args.model) + chat.add_to_cache([message]) + return message + + +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. 
+ """ + mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), + tags_and=args.and_tags if args.and_tags is not None else set(), + tags_not=args.exclude_tags if args.exclude_tags is not None else set()) + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db), + mfilter=mfilter) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = create_message(chat, args) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + ai.print() + chat.print(paged=False) + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.output_tags) # FIXME + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) + elif args.repeat is not None: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process is not None: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py new file mode 100644 index 0000000..2906a5b --- /dev/null +++ b/chatmastermind/commands/tags.py @@ -0,0 +1,17 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB + + +def tags_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'tags' command. 
+ """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + if args.list: + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0037916..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,66 +1,166 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass -class OpenAIConfig(): +class AIConfig: + """ + The base class of all AI configurations. + """ + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) + + +@dataclass +class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. 
""" - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'myopenai' + api_key: str = '0123456789' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( + res = cls( api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"Unknown AI '{name}'") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass -class Config(): +class Config: """ The configuration file structure. 
""" - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create OpenAIConfig from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) + + def as_dict(self) -> dict[str, Any]: + res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) + for ID, conf in res['ais'].items(): + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} + return res diff --git a/chatmastermind/main.py b/chatmastermind/main.py index c30ea4e..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -2,141 +2,29 @@ # -*- coding: 
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
    """
    Completer for tag arguments. Reads the configuration file given on the
    command line and returns the tags of all message files in its DB
    directory, filtered by 'prefix'.
    """
    db_dir = Path(Config.from_file(parsed_args.config).db)
    matching_tags = Message.tags_from_dir(db_dir, prefix=prefix)
    return list(matching_tags)
argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names', metavar='TAGS') + tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', + help='List of tags (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', - help='List of tag names to exclude', metavar='EXTAGS') - extag_arg.completer = tags_completer # type: ignore + atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', + help='List of tags (all must match)', metavar='ATAGS') + atag_arg.completer = tags_completer # type: ignore + etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', + help='List of tags to exclude', metavar='XTAGS') + etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OTAGS') + help='List of output tags (default: use input tags)', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore - tag_parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries", - action='store_true') - # enable autocompletion for tags - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of 
answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', - action='store_true') + # a parent parser for all commands that support AI configuration + ai_parser = argparse.ArgumentParser(add_help=False) + ai_parser.add_argument('-A', '--AI', help='AI ID to use') + ai_parser.add_argument('-M', '--model', help='Model to use') + ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) + ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) + ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) + + # 'question' command parser + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], + help="ask, create and process questions.", + aliases=['q']) + question_cmd_parser.set_defaults(func=question_cmd) + question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) + question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') + question_group.add_argument('-c', '--create', nargs='+', help='Create a question') + question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') + question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') + question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', + action='store_true') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", aliases=['h']) 
hist_cmd_parser.set_defaults(func=hist_cmd) - hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", - action='store_true') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') - hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') + hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') + hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') - # 'tag' command parser - tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.", - aliases=['t']) - tag_cmd_parser.set_defaults(func=tag_cmd) - tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) - tag_group.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + # 'tags' command parser + tags_cmd_parser = cmdparser.add_parser('tags', + help="Manage tags.", + aliases=['t']) + tags_cmd_parser.set_defaults(func=tags_cmd) + tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) + tags_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") + tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', @@ -209,11 +105,11 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config file") + 
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.", + help="Print message files.", aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) @@ -230,11 +126,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..64929a3 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,561 @@ +""" +Module implementing message related functions and classes. +""" +import pathlib +import yaml +import tempfile +import shutil +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable +from dataclasses import dataclass, asdict, field +from .tags import Tag, TagLine, TagError, match_tags, rename_tags + +QuestionInst = TypeVar('QuestionInst', bound='Question') +AnswerInst = TypeVar('AnswerInst', bound='Answer') +MessageInst = TypeVar('MessageInst', bound='Message') +AILineInst = TypeVar('AILineInst', bound='AILine') +ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') +YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] + + +class MessageError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. 
def source_code(text: str, include_delims: bool = False) -> list[str]:
    """
    Collect all source code sections of the given text, i. e. all lines
    enclosed by lines starting with '```' (leading whitespace is ignored).
    With 'include_delims' set, the enclosing '```' lines are part of the
    section, otherwise they are dropped. Each section is returned as one
    string (terminated by a newline), in the order of appearance in the text.
    A section that is never closed is not returned.
    """
    sections: list[str] = []
    current: list[str] = []
    inside = False

    for raw_line in text.split('\n'):
        if not raw_line.strip().startswith('```'):
            # ordinary line: only relevant while inside a section
            if inside:
                current.append(raw_line)
            continue
        # delimiter line: either opens or closes a section
        if include_delims:
            current.append(raw_line)
        if inside:
            sections.append('\n'.join(current) + '\n')
            current = []
        inside = not inside

    return sections
class Answer(str):
    """
    A single answer with a defined header.
    """
    tokens: int = 0  # tokens used by this answer
    txt_header: ClassVar[str] = '==== ANSWER ===='
    yaml_key: ClassVar[str] = 'answer'

    def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
        """
        Make sure the answer string does not contain the header as a whole line.
        Raises a MessageError if it does.
        """
        if cls.txt_header in string.split('\n'):
            raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, string)
        return instance

    @classmethod
    def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
        """
        Build Answer from a list of strings (one per line). The lines are joined
        with newlines and surrounding whitespace is stripped. Raises a MessageError
        if one of the strings is the header.
        """
        if cls.txt_header in strings:
            # this is the Answer class, so the error must not talk about a question
            raise MessageError(f"Answer contains the header '{cls.txt_header}'")
        instance = super().__new__(cls, '\n'.join(strings).strip())
        return instance

    def source_code(self, include_delims: bool = False) -> list[str]:
        """
        Extract and return all source code sections.
        """
        return source_code(self, include_delims)
@classmethod
def tags_from_file(cls: Type[MessageInst],
                   file_path: pathlib.Path,
                   prefix: Optional[str] = None,
                   contain: Optional[str] = None) -> set[Tag]:
    """
    Return only the tags from the given Message file,
    optionally filtered based on prefix or contained string.

    Raises a MessageError if the file does not exist or has an unsupported
    suffix. A YAML file that cannot be read as a message is reported on
    stdout and yields an empty set.
    """
    if not file_path.exists():
        raise MessageError(f"Message file '{file_path}' does not exist")
    if file_path.suffix not in cls.file_suffixes:
        raise MessageError(f"File type '{file_path.suffix}' is not supported")
    # for TXT, it's enough to read the TagLine
    if file_path.suffix == '.txt':
        with open(file_path, "r") as fd:
            try:
                return TagLine(fd.readline()).tags(prefix, contain)
            except TagError:
                return set()  # message without tags
    # '.yaml': the whole message has to be read
    try:
        message = cls.from_file(file_path)
    except MessageError as e:
        print(f"Error processing message in '{file_path}': {str(e)}")
        # no message means no tags (previously this path hit an unbound variable)
        return set()
    if not message:
        return set()
    return message.filter_tags(prefix=prefix, contain=contain)
+ """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + + if file_path.suffix == '.txt': + message = cls.__from_file_txt(file_path, + mfilter.tags_or if mfilter else None, + mfilter.tags_and if mfilter else None, + mfilter.tags_not if mfilter else None) + else: + message = cls.__from_file_yaml(file_path) + if message and (mfilter is None or message.match(mfilter)): + return message + else: + return None + + @classmethod + def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 + tags_or: Optional[set[Tag]] = None, + tags_and: Optional[set[Tag]] = None, + tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: + """ + Create a Message from the given TXT file. Expects the following file structures: + For '.txt': + * TagLine [Optional] + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header [Optional] + * Answer [Optional] + + Returns 'None' if the message does not fulfill the tag requirements. 
@classmethod
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
    """
    Create a Message from the given YAML file. Expects the following file structures:
    * Question.yaml_key: single or multiline string
    * Answer.yaml_key: single or multiline string [Optional]
    * Message.tags_yaml_key: list of strings [Optional]
    * Message.ai_yaml_key: str [Optional]
    * Message.model_yaml_key: str [Optional]

    Raises a MessageError if the file content is not a YAML mapping
    (e. g. if the file is empty).
    """
    with open(file_path, "r") as fd:
        # NOTE(review): 'FullLoader' can construct more than plain data types;
        # if message files may come from untrusted sources, consider 'yaml.safe_load'
        data = yaml.load(fd, Loader=yaml.FullLoader)
    if not isinstance(data, dict):
        # an empty file yields 'None', which can't take the file path below
        raise MessageError(f"Message file '{file_path}' does not contain a YAML mapping")
    # remember where the message came from
    data[cls.file_yaml_key] = file_path
    return cls.from_dict(data)
+ """ + output: list[str] = [] + if source_code_only: + # use the source code from answer only + if self.answer: + output.extend(self.answer.source_code(include_delims=True)) + return '\n'.join(output) if len(output) > 0 else '' + if with_tags: + output.append(self.tags_str()) + if with_file: + output.append('FILE: ' + str(self.file_path)) + output.append(Question.txt_header) + output.append(self.question) + if self.answer: + output.append(Answer.txt_header) + output.append(self.answer) + return '\n'.join(output) + + def __str__(self) -> str: + return self.to_str(True, True, False) + + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + """ + Write a Message to the given file. Type is determined based on the suffix. + Currently supported suffixes: ['.txt', '.yaml'] + """ + if file_path: + self.file_path = file_path + if not self.file_path: + raise MessageError("Got no valid path to write message") + if self.file_path.suffix not in self.file_suffixes: + raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + # TXT + if self.file_path.suffix == '.txt': + return self.__to_file_txt(self.file_path) + elif self.file_path.suffix == '.yaml': + return self.__to_file_yaml(self.file_path) + + def __to_file_txt(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in TXT format. 
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
    """
    Return a new set with those tags of this message that start with 'prefix'
    and contain the string 'contain'. A filter that is 'None' or empty is
    not applied. A message without tags yields an empty set.
    """
    if not self.tags:
        return set()
    return {tag for tag in self.tags
            if (not prefix or tag.startswith(prefix))
            and (not contain or contain in tag)}
def tokens(self) -> int:
    """
    Returns the sum of the token counts stored in the question and
    (if the message has an answer) in the answer.
    A count that is unknown is 0.
    """
    total = self.question.tokens
    if self.answer:
        total += self.answer.tokens
    return total
QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - 
append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py new file mode 100644 index 0000000..5ea1a3a --- /dev/null +++ b/chatmastermind/tags.py @@ -0,0 +1,184 @@ +""" +Module implementing tag related functions and classes. +""" +from typing import Type, TypeVar, Optional, Final + +TagInst = TypeVar('TagInst', bound='Tag') +TagLineInst = TypeVar('TagLineInst', bound='TagLine') + + +class TagError(Exception): + pass + + +class Tag(str): + """ + A single tag. A string that can contain anything but the default separator (' '). + """ + # default separator + default_separator: Final[str] = ' ' + # alternative separators (e. g. for backwards compatibility) + alternative_separators: Final[list[str]] = [','] + + def __new__(cls: Type[TagInst], string: str) -> TagInst: + """ + Make sure the tag string does not contain the default separator. + """ + if cls.default_separator in string: + raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'") + instance = super().__new__(cls, string) + return instance + + +def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]: + """ + Deletes the given tags and returns a new set. 
+ """ + return tags.difference(tags_delete) + + +def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]: + """ + Adds the given tags and returns a new set. + """ + return set(sorted(tags | tags_add)) + + +def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]: + """ + Merges the tags in 'tags_merge' into the current one and returns a new set. + """ + for ts in tags_merge: + tags |= ts + return tags + + +def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]: + """ + Renames the given tags and returns a new set. The first tuple element + is the old name, the second one is the new name. + """ + for t in tags_rename: + if t[0] in tags: + tags.remove(t[0]) + tags.add(t[1]) + return set(sorted(tags)) + + +def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the given set 'tags' matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), + they match no tags. 
+ """ + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tags for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing + + +class TagLine(str): + """ + A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by + a list of tags, separated by the defaut separator (' '). Any operations on a + TagLine will sort the tags. + """ + # the prefix + prefix: Final[str] = 'TAGS:' + + def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: + """ + Make sure the tagline string starts with the prefix. Also replace newlines + and multiple spaces with ' ', in order to support multiline TagLines. + """ + if not string.startswith(cls.prefix): + raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + string = ' '.join(string.split()) + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst: + """ + Create a new TagLine from a set of tags. + """ + return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Returns all tags contained in this line as a set, optionally + filtered based on prefix or contained string. 
+ """ + tagstr = self[len(self.prefix):].strip() + if tagstr == '': + return set() # no tags, only prefix + separator = Tag.default_separator + # look for alternative separators and use the first one found + # -> we don't support different separators in the same TagLine + for s in Tag.alternative_separators: + if s in tagstr: + separator = s + break + res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() + + def merge(self, taglines: set['TagLine']) -> 'TagLine': + """ + Merges the tags of all given taglines into the current one and returns a new TagLine. + """ + tags_merge = [tl.tags() for tl in taglines] + return self.from_set(merge_tags(self.tags(), tags_merge)) + + def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine': + """ + Deletes the given tags and returns a new TagLine. + """ + return self.from_set(delete_tags(self.tags(), tags_delete)) + + def add_tags(self, tags_add: set[Tag]) -> 'TagLine': + """ + Adds the given tags and returns a new TagLine. + """ + return self.from_set(add_tags(self.tags(), tags_add)) + + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine': + """ + Renames the given tags and returns a new TagLine. The first + tuple element is the old name, the second one is the new name. 
+ """ + return self.from_set(rename_tags(self.tags(), tags_rename)) + + def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the current TagLine matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + return match_tags(self.tags(), tags_or, tags_and, tags_not) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index 6543ce1..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. 
- """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_start = content.index('\n', content_start) + 1 - content_end = content.rindex('```') - if content_start < content_end: - print(content[content_start:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") - - -def print_tags_frequency(tags: list[str]) -> None: - for tag in sorted(set(tags)): - print(f"- {tag}: {tags.count(tag)}") diff --git a/setup.py b/setup.py index 02d9ab1..a311605 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, 
long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages(), + packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,7 +32,7 @@ setup( "openai", "PyYAML", "argcomplete", - "pytest" + "pytest", ], python_requires=">=3.9", test_suite="tests", diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py new file mode 100644 index 0000000..9cb94d3 --- /dev/null +++ b/tests/test_ai_factory.py @@ -0,0 +1,48 @@ +import argparse +import unittest +from unittest.mock import MagicMock +from chatmastermind.ai_factory import create_ai +from chatmastermind.configuration import Config +from chatmastermind.ai import AIError +from chatmastermind.ais.openai import OpenAI + + +class TestCreateAI(unittest.TestCase): + def setUp(self) -> None: + self.args = MagicMock(spec=argparse.Namespace) + self.args.AI = 'myopenai' + self.args.model = None + self.args.max_tokens = None + self.args.temperature = None + + def test_create_ai_from_args(self) -> None: + # Create an AI with the default configuration + config = Config() + self.args.AI = 'myopenai' + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_ai_from_default(self) -> None: + self.args.AI = None + # Create an AI with the default configuration + config = Config() + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_empty_ai_error(self) -> None: + self.args.AI = None + # Create Config with empty AIs + config = Config() + config.ais = {} + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) + + def test_create_unsupported_ai_error(self) -> None: + # Mock argparse.Namespace with ai='invalid_ai' + self.args.AI = 'invalid_ai' + # Create default Config + config = Config() + # Call create_ai function and assert that it raises 
AIError + with self.assertRaises(AIError): + create_ai(self.args, config) diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..1916a2b --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,488 @@ +import unittest +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError + + +class TestChat(unittest.TestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_messages([self.message2, self.message1]) + self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_messages([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def 
test_tags(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + def test_tags_frequency(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + tags_freq = self.chat.tags_frequency() + self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_messages([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 + +{'-'*terminal_width()} + +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 + +{'-'*terminal_width()} + +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', 
new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_messages([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_files=True) + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: 0001.txt +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 + +{'-'*terminal_width()} + +{TagLine.prefix} btag2 +FILE: 0002.txt +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 + +{'-'*terminal_width()} + +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(unittest.TestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. 
+ """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def 
test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messages(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '7') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, 
'0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the 
DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_read(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + + # create 2 new files in the DB directory + new_message1 = Message(Question('Question 5'), + Answer('Answer 5'), + {Tag('tag5')}) + new_message2 = Message(Question('Question 6'), + Answer('Answer 6'), + {Tag('tag6')}) + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 6) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # create 2 new files in the cache directory + new_message3 = Message(Question('Question 7'), + Answer('Answer 5'), + {Tag('tag7')}) + new_message4 = Message(Question('Question 8'), + Answer('Answer 6'), + {Tag('tag8')}) + new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + # read and check them + chat_db.read_cache() + self.assertEqual(len(chat_db.messages), 8) + # check that the new message have the cache dir path + 
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + # an the old ones keep their path (since they have not been replaced) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now overwrite two messages in the DB directory + new_message1.question = Question('New Question 1') + new_message2.question = Question('New Question 2') + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read from the DB dir and check if the modified messages have been updated + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + self.assertEqual(chat_db.messages[4].question, 'New Question 1') + self.assertEqual(chat_db.messages[5].question, 'New Question 2') + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now write the messages from the cache to the DB directory + new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + # check that they now have the DB path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + 
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_messages([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == 
message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + def test_chat_db_write_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + # try to write a message without a valid file_path + message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.write_messages([message]) + + # write a message with a valid file_path + message.file_path = 
pathlib.Path(self.cache_path.name) / '123456.txt' + chat_db.write_messages([message]) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1]) diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..ba8a5aa --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = 
cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': 
'./test_db/', + 'ais': { + 'myopenai': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') + + def test_create_default_should_create_default_config(self) -> None: + Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'myopenai': OpenAIConfig( + ID='myopenai', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, 
+ presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name) diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index db5fcdb..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,233 +0,0 @@ -import unittest -import io -import pathlib -import argparse -from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, ask_cmd -from chatmastermind.api_client import ai -from chatmastermind.configuration import Config -from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from unittest import mock -from unittest.mock import patch, MagicMock, Mock, ANY - - -class CmmTestCase(unittest.TestCase): - """ - Base class for all cmm testcases. - """ - def dummy_config(self, db: str) -> Config: - """ - Creates a dummy configuration. 
- """ - return Config.from_dict( - {'system': 'dummy_system', - 'db': db, - 'openai': {'api_key': 'dummy_key', - 'model': 'dummy_model', - 'max_tokens': 4000, - 'temperature': 1.0, - 'top_p': 1, - 'frequency_penalty': 0, - 'presence_penalty': 0}} - ) - - -class TestCreateChat(CmmTestCase): - - def setUp(self) -> None: - self.config = self.dummy_config(db='test_files') - self.question = "test question" - self.tags = ['test_tag'] - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['test_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 4) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['other_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 2) - 
self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.side_effect = ( - io.StringIO(dump_data({'question': 'test_content', - 'answer': 'some answer', - 'tags': ['test_tag']})), - io.StringIO(dump_data({'question': 'test_content2', - 'answer': 'some answer2', - 'tags': ['test_tag2']})), - ) - - test_chat = create_chat_hist(self.question, [], None, self.config) - - self.assertEqual(len(test_chat), 6) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': 'test_content2'}) - self.assertEqual(test_chat[4], - {'role': 'assistant', 'content': 'some answer2'}) - - -class TestHandleQuestion(CmmTestCase): - - def setUp(self) -> None: - self.question = "test question" - self.args = argparse.Namespace( - tags=['tag1'], - extags=['extag1'], - output_tags=None, - question=[self.question], - source=None, - only_source_code=False, - number=3, - max_tokens=None, - temperature=None, - model=None, - match_all_tags=False, - with_tags=False, - with_file=False, - ) - self.config = self.dummy_config(db='test_files') - - @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.main.print_chat_hist") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], 
"test_usage")) - @patch("chatmastermind.utils.pp") - @patch("builtins.print") - def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, - mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, - mock_create_chat_hist: MagicMock) -> None: - open_mock = MagicMock() - with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.extags, - []) - mock_create_chat_hist.assert_called_once_with(self.question, - self.args.tags, - self.args.extags, - self.config, - False, False, False) - mock_print_chat_hist.assert_called_once_with('test_chat', - False, - self.args.only_source_code) - mock_ai.assert_called_with("test_chat", - self.config, - self.args.number) - expected_calls = [] - for num, answer in enumerate(mock_ai.return_value[0], start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width() - len(title)) - expected_calls.append(((f'{title}{title_end}',),)) - expected_calls.append(((answer,),)) - expected_calls.append((("-" * terminal_width(),),)) - expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) - open_mock.assert_has_calls(open_expected_calls, any_order=True) - - -class TestSaveAnswers(CmmTestCase): - @mock.patch('builtins.open') - @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: - question = "Test question?" 
- answers = ["Answer 1", "Answer 2"] - tags = ["tag1", "tag2"] - otags = ["otag1", "otag2"] - config = self.dummy_config(db='test_db') - - with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ - mock.patch('chatmastermind.storage.yaml.dump'), \ - mock.patch('io.StringIO') as stringio_mock: - stringio_instance = stringio_mock.return_value - stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] - save_answers(question, answers, tags, otags, config) - - open_calls = [ - mock.call(pathlib.Path('test_db/.next'), 'r'), - mock.call(pathlib.Path('test_db/.next'), 'w'), - ] - open_mock.assert_has_calls(open_calls, any_order=True) - - -class TestAI(CmmTestCase): - - @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock) -> None: - mock_create.return_value = { - 'choices': [ - {'message': {'content': 'response_text_1'}}, - {'message': {'content': 'response_text_2'}} - ], - 'usage': {'tokens': 10} - } - - chat = [{"role": "system", "content": "hello ai"}] - config = self.dummy_config(db='dummy') - config.openai.model = "text-davinci-002" - config.openai.max_tokens = 150 - config.openai.temperature = 0.5 - - result = ai(chat, config, 2) - expected_result = (['response_text_1', 'response_text_2'], - {'tokens': 10}) - self.assertEqual(result, expected_result) - - -class TestCreateParser(CmmTestCase): - def test_create_parser(self) -> None: - with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: - mock_cmdparser = Mock() - mock_add_subparsers.return_value = mock_cmdparser - parser = create_parser() - self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - 
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) - self.assertTrue('.config.yaml' in parser.get_default('config')) diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..1f440df --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,836 @@ +import unittest +import pathlib +import tempfile +from typing import cast +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in +from chatmastermind.tags import Tag, TagLine + + +class SourceCodeTestCase(unittest.TestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + 
text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(unittest.TestCase): + def test_question_with_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Question.txt_header}\nWhat is your name?") + + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. + """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(unittest.TestCase): + def test_answer_with_header(self) -> None: + with self.assertRaises(MessageError): + Answer(f"{Answer.txt_header}\nno") + + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") + + +class MessageToFileTxtTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, 
+ ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_txt_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""" + self.assertEqual(content, expected_content) + + def test_to_file_txt_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.txt_header} +This is a question. +""" + self.assertEqual(content, expected_content) + + def test_to_file_unsupported_file_type(self) -> None: + unsupported_file_path = pathlib.Path("example.doc") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_path) + self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + + def test_to_file_no_file_path(self) -> None: + """ + Provoke an exception using an empty path. 
+ """ + with self.assertRaises(MessageError) as cm: + # clear the internal file_path + self.message_complete.file_path = None + self.message_complete.to_file(None) + self.assertEqual(str(cm.exception), "Got no valid path to write message") + # reset the internal file_path + self.message_complete.file_path = self.file_path + + +class MessageToFileYamlTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_multiline = Message(Question('This is a\nmultiline question.'), + Answer('This is a\nmultiline answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_yaml_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: This is a question. +{Answer.yaml_key}: This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_multiline(self) -> None: + self.message_multiline.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: |- + This is a + multiline question. +{Answer.yaml_key}: |- + This is a + multiline answer. 
+{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"{Question.yaml_key}: This is a question.\n" + self.assertEqual(content, expected_content) + + +class MessageFromFileTxtTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_min.close() + self.file_path.unlink() + self.file_path_min.unlink() + + def test_from_file_txt_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_min(self) -> None: + """ + Read a message with only required values. 
+ """ + message = Message.from_file(self.file_path_min) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_txt_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.txt") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), 
f"Message file '{file_not_exists}' does not exist") + + def test_from_file_txt_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_txt_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + 
self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_txt_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class MessageFromFileYamlTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + self.file_min.close() + self.file_path_min.unlink() + + def test_from_file_yaml_complete(self) -> None: + """ + Read a complete message (with all optional values). 
+ """ + message = Message.from_file(self.file_path) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_min(self) -> None: + """ + Read a message with only the required values. + """ + message = Message.from_file(self.file_path_min) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.yaml") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_yaml_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_dont_match(self) -> None: + message = 
Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_yaml_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def 
test_from_file_yaml_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_yaml_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_yaml_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class TagsFromFileTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt = pathlib.Path(self.file_txt.name) + with open(self.file_path_txt, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) + with open(self.file_path_txt_no_tags, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. 
+""") + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) + with open(self.file_path_txt_tags_empty, "w") as fd: # fixture: tag line without any tags + fd.write(f"""{TagLine.prefix} +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml = pathlib.Path(self.file_yaml.name) + with open(self.file_path_yaml, "w") as fd: # fixture: YAML message with three tags + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.tags_yaml_key}: + - tag1 + - tag2 + - ptag3 +""") + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) + with open(self.file_path_yaml_no_tags, "w") as fd: # fixture: YAML message without a tags key + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. 
+""") + + def tearDown(self) -> None: # close each fixture handle, then delete its file + self.file_txt.close() + self.file_path_txt.unlink() + self.file_yaml.close() + self.file_path_yaml.unlink() + self.file_txt_no_tags.close() + self.file_path_txt_no_tags.unlink() + self.file_txt_tags_empty.close() + self.file_path_txt_tags_empty.unlink() + self.file_yaml_no_tags.close() + self.file_path_yaml_no_tags.unlink() + + def test_tags_from_file_txt(self) -> None: + tags = Message.tags_from_file(self.file_path_txt) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_tags_empty(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_tags_empty) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_yaml_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, contain='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_contain(self) -> 
None: + tags = Message.tags_from_file(self.file_path_yaml, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, contain='R') + self.assertSetEqual(tags, set()) + + +class TagsFromDirTestCase(unittest.TestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_dir_no_tags = tempfile.TemporaryDirectory() + self.tag_sets = [ + {Tag('atag1'), Tag('atag2')}, + {Tag('btag3'), Tag('btag4')}, + {Tag('ctag5'), Tag('ctag6')} + ] + self.files = [ + pathlib.Path(self.temp_dir.name, 'file1.txt'), + pathlib.Path(self.temp_dir.name, 'file2.yaml'), + pathlib.Path(self.temp_dir.name, 'file3.txt') + ] + self.files_no_tags = [ + pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), + pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), + pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + ] + for file, tags in zip(self.files, self.tag_sets): + message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags) + message.to_file(file) + for file in self.files_no_tags: + message = Message(Question('This is a question.'), + Answer('This is an answer.')) + message.to_file(file) + + def tearDown(self) -> None: + self.temp_dir.cleanup() + self.temp_dir_no_tags.cleanup() + + def test_tags_from_dir(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) + expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2] + self.assertEqual(all_tags, expected_tags) + + def test_tags_from_dir_prefix(self) -> None: + atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a') + expected_tags = self.tag_sets[0] + self.assertEqual(atags, expected_tags) + + def test_tags_from_dir_no_tags(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name)) + self.assertSetEqual(all_tags, set()) + + +class MessageIDTestCase(unittest.TestCase): + def setUp(self) -> None: + self.file = 
tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message_no_file_path = Message(Question('This is a question.')) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_msg_id_txt(self) -> None: + self.assertEqual(self.message.msg_id(), self.file_path.name) + + def test_msg_id_txt_exception(self) -> None: + with self.assertRaises(MessageError): + self.message_no_file_path.msg_id() + + +class MessageHashTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a new question.'), + file_path=pathlib.Path('/tmp/foo/bla')) + self.message3 = Message(Question('This is a question.'), + Answer('This is an answer.'), + file_path=pathlib.Path('/tmp/foo/bla')) + # message4 is a copy of message1, because only question and + # answer are used for hashing and comparison + self.message4 = Message(Question('This is a question.'), + tags={Tag('tag1'), Tag('tag2')}, + ai='Blabla', + file_path=pathlib.Path('foobla')) + + def test_set_hashing(self) -> None: + msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4} + self.assertEqual(len(msgs), 3) + for msg in [self.message1, self.message2, self.message3]: + self.assertIn(msg, msgs) + + +class MessageTagsStrTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_tags_str(self) -> None: + self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + 
file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + +class MessageInTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/bla/foo')) + + def test_message_in(self) -> None: + self.assertTrue(message_in(self.message1, [self.message1])) + self.assertFalse(message_in(self.message1, [self.message2])) + + +class MessageRenameTagsTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_rename_tags(self) -> None: + self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) + self.assertIsNotNone(self.message.tags) + self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] + + +class MessageToStrTestCase(unittest.TestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_to_str(self) -> None: + expected_output = f"""{Question.txt_header} +This is a question. 
+{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(), expected_output) + + def test_to_str_with_tags_and_file(self) -> None: + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: /tmp/foo/bla +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..40ea4d8 --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,162 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source_text = None + self.args.source_code = None + self.args.AI = None + self.args.model = None + self.args.output_tags = None + # File 1 : no source code block, only text + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + # File 2 : one embedded source code block + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. 
+``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + # File 3 : all source code + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) + + def tearDown(self) -> None: # close and delete every source file created in setUp + for src in (self.source_file1, self.source_file2, self.source_file3, self.source_file4): + src.close() + os.remove(src.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.ask = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.ask = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.ask = ["What is this", "'bard' thing?", "Is it good?"] + message = 
create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? + +Is it good?""")) + + def test_single_question_with_text_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_text = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains no source code (only text) + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question(f"""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_text_file_and_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question(f"""What is this? 
+ +``` +{self.source_file3_content} +```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. +``` +""")) diff --git a/tests/test_tags.py b/tests/test_tags.py new file mode 100644 index 0000000..edd3c05 --- /dev/null +++ b/tests/test_tags.py @@ -0,0 +1,163 @@ +import unittest +from chatmastermind.tags import Tag, TagLine, TagError + + +class TestTag(unittest.TestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(unittest.TestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: atag1 btag2') + tags = 
tagline.tags() + self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) + + def test_tags_empty(self) -> None: + tagline = TagLine('TAGS:') + self.assertSetEqual(tagline.tags(), set()) + + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_contain(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(contain='t') + self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(contain='1') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(contain='R') + self.assertSetEqual(tags, set()) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), 
Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) + + # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags + self.assertFalse(tagline.match_tags(set(), set(), None)) + + # Test case 11: 'tags_or' is empty, match no tags + 
self.assertFalse(tagline.match_tags(set(), None, None)) + + # Test case 12: 'tags_and' is empty, match no tags + self.assertFalse(tagline.match_tags(None, set(), None)) + + # Test case 13: 'tags_or' is None (not empty), match 'tags_and' + tags_and = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(None, tags_and, None)) + + # Test case 14: 'tags_and' is None (not empty), match 'tags_or' + tags_or = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(tags_or, None, None))