Compare commits

..

No commits in common. "a895c1fc6a3a5f8099792454844a19b1dd132fbe" and "3ef1339cc0d9b58103599c068b464b2b93183641" have entirely different histories.

26 changed files with 674 additions and 3591 deletions

3
.gitignore vendored
View File

@ -130,5 +130,4 @@ dmypy.json
.config.yaml
db
noweb
Session.vim
noweb

View File

@ -1,74 +0,0 @@
from dataclasses import dataclass
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .tags import Tag
from .message import Message
from .chat import Chat
class AIError(Exception):
pass
@dataclass
class Tokens:
prompt: int = 0
completion: int = 0
total: int = 0
@dataclass
class AIResponse:
"""
The response to an AI request. Consists of one or more messages
(each containing the question and a single answer) and the nr.
of used tokens.
"""
messages: list[Message]
tokens: Optional[Tokens] = None
class AI(Protocol):
"""
The base class for AI clients.
"""
ID: str
name: str
config: AIConfig
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request. Parameters:
* question: the question to ask
* chat: the chat history to be added as context
* num_answers: nr. of requested answers (corresponds
to the nr. of messages in the 'AIResponse')
* otags: the output tags, i. e. the tags that all
returned messages should contain
"""
raise NotImplementedError
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
or chat. Note that the computation may not be 100% accurate
and is not implemented for all AIs.
"""
raise NotImplementedError
def print(self) -> None:
"""
Print some info about the current AI, like system message.
"""
pass

View File

@ -1,43 +0,0 @@
"""
Creates different AI instances, based on the given configuration.
"""
import argparse
from typing import cast
from .configuration import Config, AIConfig, OpenAIConfig
from .ai import AI, AIError
from .ais.openai import OpenAI
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
"""
Creates an AI subclass instance from the given arguments
and configuration file. If AI has not been set in the
arguments, it searches for the ID 'default'. If that
is not found, it uses the first AI in the list.
"""
ai_conf: AIConfig
if args.AI:
try:
ai_conf = config.ais[args.AI]
except KeyError:
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
elif 'default' in config.ais:
ai_conf = config.ais['default']
else:
try:
ai_conf = next(iter(config.ais.values()))
except StopIteration:
raise AIError("No AI found in this configuration")
if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.model:
ai.config.model = args.model
if args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.temperature:
ai.config.temperature = args.temperature
return ai
else:
raise AIError(f"AI '{args.AI}' is not supported")

View File

@ -1,106 +0,0 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional, Union
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..configuration import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI):
"""
The OpenAI AI client.
"""
def __init__(self, config: OpenAIConfig) -> None:
self.ID = config.ID
self.name = config.name
self.config = config
openai.api_key = config.api_key
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request, asking the given question with the given
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
question.answer = Answer(response['choices'][0]['message']['content'])
question.tags = otags
question.ai = self.ID
question.model = self.config.model
answers: list[Message] = [question]
for choice in response['choices'][1:]: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.ID,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
response['usage']['completion_tokens'],
response['usage']['total_tokens']))
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat
def tokens(self, data: Union[Message, Chat]) -> int:
raise NotImplementedError
def print(self) -> None:
print(f"MODEL: {self.config.model}")
print("=== SYSTEM ===")
print(self.config.system)

View File

@ -0,0 +1,45 @@
import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None:
openai.api_key = api_key
def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: ChatType,
config: Config,
number: int
) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
response = openai.ChatCompletion.create(
model=config.openai.model,
messages=chat,
temperature=config.openai.temperature,
max_tokens=config.openai.max_tokens,
top_p=config.openai.top_p,
n=number,
frequency_penalty=config.openai.frequency_penalty,
presence_penalty=config.openai.presence_penalty)
result = []
for choice in response['choices']: # type: ignore
result.append(choice['message']['content'].strip())
return result, dict(response['usage']) # type: ignore

View File

@ -1,407 +0,0 @@
"""
Module implementing various chat classes and functions for managing a chat history.
"""
import shutil
from pathlib import Path
from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
class ChatError(Exception):
pass
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_paged(text: str) -> None:
pager(text)
def read_dir(dir_path: Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> list[Message]:
"""
Reads the messages from the given folder.
Parameters:
* 'dir_path': source directory
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
try:
message = Message.from_file(file_path, mfilter)
if message:
messages.append(message)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return messages
def make_file_path(dir_path: Path,
file_suffix: str,
next_fid: Callable[[], int]) -> Path:
"""
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
while file_path.exists():
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
return file_path
def write_dir(dir_path: Path,
messages: list[Message],
file_suffix: str,
next_fid: Callable[[], int]) -> None:
"""
Write all messages to the given directory. If a message has no file_path,
a new one will be created. If message.file_path exists, it will be modified
to point to the given directory.
Parameters:
* 'dir_path': destination directory
* 'messages': list of messages to write
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
* 'next_fid': callable that returns the next file ID
"""
for message in messages:
file_path = message.file_path
# message has no file_path: create one
if not file_path:
file_path = make_file_path(dir_path, file_suffix, next_fid)
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
message.to_file(file_path)
def clear_dir(dir_path: Path,
glob: Optional[str] = None) -> None:
"""
Deletes all Message files in the given directory.
"""
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in file_iter:
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
file_path.unlink(missing_ok=True)
@dataclass
class Chat:
"""
A class containing a complete chat history.
"""
messages: list[Message]
def filter(self, mfilter: MessageFilter) -> None:
"""
Use 'Message.match(mfilter) to remove all messages that
don't fulfill the filter requirements.
"""
self.messages = [m for m in self.messages if m.match(mfilter)]
def sort(self, reverse: bool = False) -> None:
"""
Sort the messages according to 'Message.msg_id()'.
"""
try:
# the message may not have an ID if it doesn't have a file_path
self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
except MessageError:
pass
def clear(self) -> None:
"""
Delete all messages.
"""
self.messages = []
def add_messages(self, messages: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += messages
self.sort()
def latest_message(self) -> Optional[Message]:
"""
Returns the last added message (according to the file ID).
"""
if len(self.messages) > 0:
self.sort()
return self.messages[-1]
else:
return None
def find_messages(self, msg_names: list[str]) -> list[Message]:
"""
Search and return the messages with the given names. Names can either be filenames
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
caller should check the result if he requires all messages).
"""
return [m for m in self.messages
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
def remove_messages(self, msg_names: list[str]) -> None:
"""
Remove the messages with the given names. Names can either be filenames
(incl. the suffix) or full paths.
"""
self.messages = [m for m in self.messages
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Get the tags of all messages, optionally filtered by prefix or substring.
"""
tags: set[Tag] = set()
for m in self.messages:
tags |= m.filter_tags(prefix, contain)
return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
"""
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
"""
tags: list[Tag] = []
for m in self.messages:
tags += [tag for tag in m.filter_tags(prefix, contain)]
return {tag: tags.count(tag) for tag in sorted(tags)}
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by all messages in this chat.
If unknown, 0 is returned.
"""
return sum(m.tokens() for m in self.messages)
def print(self, source_code_only: bool = False,
with_tags: bool = False, with_files: bool = False,
paged: bool = True) -> None:
output: list[str] = []
for message in self.messages:
if source_code_only:
output.append(message.to_str(source_code_only=True))
continue
output.append(message.to_str(with_tags, with_files))
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
print(*output, sep='\n')
@dataclass
class ChatDB(Chat):
"""
A 'Chat' class that is bound to a given directory structure. Supports reading
and writing messages from / to that structure. Such a structure consists of
two directories: a 'cache directory', where all messages are temporarily
stored, and a 'DB' directory, where selected messages can be stored
persistently.
"""
default_file_suffix: ClassVar[str] = '.txt'
cache_path: Path
db_path: Path
# a MessageFilter that all messages must match (if given)
mfilter: Optional[MessageFilter] = None
file_suffix: str = default_file_suffix
# the glob pattern for all messages
glob: Optional[str] = None
def __post_init__(self) -> None:
# contains the latest message ID
self.next_fname = self.db_path / '.next'
# make all paths absolute
self.cache_path = self.cache_path.absolute()
self.db_path = self.db_path.absolute()
@classmethod
def from_dir(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
glob: Optional[str] = None,
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a 'ChatDB' instance from the given directory structure.
Reads all messages from 'db_path' into the local message list.
Parameters:
* 'cache_path': path to the directory for temporary messages
* 'db_path': path to the directory for persistent messages
* 'glob': if specified, files will be filtered using 'path.glob()',
otherwise it uses 'path.iterdir()'.
* 'mfilter': use with 'Message.from_file()' to filter messages
when reading them.
"""
messages = read_dir(db_path, glob, mfilter)
return cls(messages, cache_path, db_path, mfilter,
cls.default_file_suffix, glob)
@classmethod
def from_messages(cls: Type[ChatDBInst],
cache_path: Path,
db_path: Path,
messages: list[Message],
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
"""
Create a ChatDB instance from the given message list.
"""
return cls(messages, cache_path, db_path, mfilter)
def get_next_fid(self) -> int:
try:
with open(self.next_fname, 'r') as f:
next_fid = int(f.read()) + 1
self.set_next_fid(next_fid)
return next_fid
except Exception:
self.set_next_fid(1)
return 1
def set_next_fid(self, fid: int) -> None:
with open(self.next_fname, 'w') as f:
f.write(f'{fid}')
def read_db(self) -> None:
"""
Reads new messages from the DB directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def read_cache(self) -> None:
"""
Reads new messages from the cache directory. New ones are added to the internal list,
existing ones are replaced. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list.
"""
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
# remove all messages from self.messages that are in the new list
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
# copy the messages from the temporary list to self.messages and sort them
self.messages += new_messages
self.sort()
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
"""
write_dir(self.cache_path,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def clear_cache(self) -> None:
"""
Deletes all Message files from the cache dir and removes those messages from
the internal list.
"""
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()
def update_messages(self, messages: list[Message], write: bool = True) -> None:
"""
Update existing messages. A message is determined as 'existing' if a message with
the same base filename (i. e. 'file_path.name') is already in the list. Only accepts
existing messages.
"""
if any(not message_in(m, self.messages) for m in messages):
raise ChatError("Can't update messages that are not in the internal list")
# remove old versions and add new ones
self.messages = [m for m in self.messages if not message_in(m, messages)]
self.messages += messages
self.sort()
# write the UPDATED messages if requested
if write:
self.write_messages(messages)

View File

@ -1,11 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
def config_cmd(args: argparse.Namespace) -> None:
"""
Handler for the 'config' command.
"""
if args.create:
Config.create_default(Path(args.create))

View File

@ -1,23 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
from ..message import MessageFilter
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags,
tags_and=args.and_tags,
tags_not=args.exclude_tags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)

View File

@ -1,27 +0,0 @@
import sys
import argparse
from pathlib import Path
from ..configuration import Config
from ..message import Message, MessageError
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = Path(args.file)
try:
message = Message.from_file(fname)
if message:
if args.question:
print(message.question)
elif args.answer:
print(message.answer)
elif message.answer and args.only_source_code:
for code in message.answer.source_code():
print(code)
else:
print(message.to_str())
except MessageError:
print(f"File is not a valid message: {args.file}")
sys.exit(1)

View File

@ -1,94 +0,0 @@
import argparse
from pathlib import Path
from itertools import zip_longest
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
"""
Creates (and writes) a new message from the given arguments.
"""
question_parts = []
question_list = args.ask if args.ask is not None else []
text_files = args.source_text if args.source_text is not None else []
code_files = args.source_code if args.source_code is not None else []
for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None):
if question is not None and len(question.strip()) > 0:
question_parts.append(question)
if source is not None and len(source) > 0:
with open(source) as r:
content = r.read().strip()
if len(content) > 0:
question_parts.append(content)
if code is not None and len(code) > 0:
with open(code) as r:
content = r.read().strip()
if len(content) == 0:
continue
# try to extract and add source code
code_parts = source_code(content, include_delims=True)
if len(code_parts) > 0:
question_parts += code_parts
# if there's none, add the whole file
else:
question_parts.append(f"```\n{content}\n```")
full_question = '\n\n'.join(question_parts)
message = Message(question=Question(full_question),
tags=args.output_tags, # FIXME
ai=args.AI,
model=args.model)
chat.add_to_cache([message])
return message
def question_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'question' command.
"""
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
tags_and=args.and_tags if args.and_tags is not None else set(),
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db),
mfilter=mfilter)
# if it's a new question, create and store it immediately
if args.ask or args.create:
message = create_message(chat, args)
if args.create:
return
# create the correct AI instance
ai: AI = create_ai(args, config)
if args.ask:
ai.print()
chat.print(paged=False)
response: AIResponse = ai.request(message,
chat,
args.num_answers, # FIXME
args.output_tags) # FIXME
chat.update_messages([response.messages[0]])
chat.add_to_cache(response.messages[1:])
for idx, msg in enumerate(response.messages):
print(f"=== ANSWER {idx+1} ===")
print(msg.answer)
if response.tokens:
print("===============")
print(response.tokens)
elif args.repeat is not None:
lmessage = chat.latest_message()
assert lmessage
# TODO: repeat either the last question or the
# one(s) given in 'args.repeat' (overwrite
# existing ones if 'args.overwrite' is True)
pass
elif args.process is not None:
# TODO: process either all questions without an
# answer or the one(s) given in 'args.process'
pass

View File

@ -1,17 +0,0 @@
import argparse
from pathlib import Path
from ..configuration import Config
from ..chat import ChatDB
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming

View File

@ -1,166 +1,66 @@
import yaml
from pathlib import Path
from typing import Type, TypeVar, Any, Optional, ClassVar
from dataclasses import dataclass, asdict, field
from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict
ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
default_config_path = '.config.yaml'
class ConfigError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
@dataclass
class AIConfig:
"""
The base class of all AI configurations.
"""
# the name of the AI the config class represents
# -> it's a class variable and thus not part of the
# dataclass constructor
name: ClassVar[str]
# a user-defined ID for an AI configuration entry
ID: str
# the name must not be changed
def __setattr__(self, name: str, value: Any) -> None:
if name == 'name':
raise AttributeError("'{name}' is not allowed to be changed")
else:
super().__setattr__(name, value)
@dataclass
class OpenAIConfig(AIConfig):
class OpenAIConfig():
"""
The OpenAI section of the configuration file.
"""
name: ClassVar[str] = 'openai'
# all members have default values, so we can easily create
# a default configuration
ID: str = 'myopenai'
api_key: str = '0123456789'
model: str = 'gpt-3.5-turbo-16k'
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
system: str = 'You are an assistant'
api_key: str
model: str
temperature: float
max_tokens: int
top_p: float
frequency_penalty: float
presence_penalty: float
@classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
"""
Create OpenAIConfig from a dict.
"""
res = cls(
return cls(
api_key=str(source['api_key']),
model=str(source['model']),
max_tokens=int(source['max_tokens']),
temperature=float(source['temperature']),
top_p=float(source['top_p']),
frequency_penalty=float(source['frequency_penalty']),
presence_penalty=float(source['presence_penalty']),
system=str(source['system'])
presence_penalty=float(source['presence_penalty'])
)
# overwrite default ID if provided
if 'ID' in source:
res.ID = source['ID']
return res
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given name.
"""
if name.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"Unknown AI '{name}'")
def create_default_ai_configs() -> dict[str, AIConfig]:
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais}
@dataclass
class Config:
class Config():
"""
The configuration file structure.
"""
# all members have default values, so we can easily create
# a default configuration
db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
system: str
db: str
openai: OpenAIConfig
@classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
"""
Create Config from a dict (with the same format as the config file).
Create OpenAIConfig from a dict.
"""
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for ID, conf in source['ais'].items():
# add the AI ID to the config (for easy internal access)
conf['ID'] = ID
ai_conf = ai_config_instance(conf['name'], conf)
ais[ID] = ai_conf
return cls(
system=str(source['system']),
db=str(source['db']),
ais=ais
openai=OpenAIConfig.from_dict(source['openai'])
)
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader)
return cls.from_dict(source)
def to_file(self, file_path: Path) -> None:
# remove the AI name from the config (for a cleaner format)
data = self.as_dict()
for conf in data['ais'].values():
del (conf['ID'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]:
res = asdict(self)
# add the AI name manually (as first element)
# (not done by 'asdict' because it's a class variable)
for ID, conf in res['ais'].items():
res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf}
return res
def to_file(self, path: str) -> None:
with open(path, 'w') as f:
yaml.dump(asdict(self), f)

View File

@ -2,29 +2,141 @@
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
import yaml
import sys
import argcomplete
import argparse
from pathlib import Path
import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from itertools import zip_longest
from typing import Any
from .configuration import Config, default_config_path
from .message import Message
from .commands.question import question_cmd
from .commands.tags import tags_cmd
from .commands.config import config_cmd
from .commands.hist import hist_cmd
from .commands.print import print_cmd
default_config = '.config.yaml'
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
config = Config.from_file(parsed_args.config)
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
return get_tags_unique(config, prefix)
def create_question_with_hist(args: argparse.Namespace,
config: Config,
) -> tuple[ChatType, str, list[str]]:
"""
Creates the "AI request", including the question and chat history as determined
by the specified tags.
"""
tags = args.tags or []
extags = args.extags or []
otags = args.output_tags or []
if not args.only_source_code:
print_tag_args(tags, extags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
source_list = args.source if args.source is not None else []
for question, source in zip_longest(question_list, source_list, fillvalue=None):
if question is not None and source is not None:
with open(source) as r:
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
elif question is not None:
question_parts.append(question)
elif source is not None:
with open(source) as r:
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config,
args.match_all_tags, False, False)
return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tag' command.
"""
if args.list:
print_tags_frequency(get_tags(config, None))
def config_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'config' command.
"""
if args.list_models:
print_models()
elif args.print_model:
print(config.openai.model)
elif args.model:
config.openai.model = args.model
config.to_file(args.config)
def ask_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'ask' command.
"""
if args.max_tokens:
config.openai.max_tokens = args.max_tokens
if args.temperature:
config.openai.temperature = args.temperature
if args.model:
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config)
print("-" * terminal_width())
print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
tags = args.tags or []
extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = pathlib.Path(args.file)
if fname.suffix == '.yaml':
with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif fname.suffix == '.txt':
data = read_file(fname)
else:
print(f"Unknown file type: {args.file}")
sys.exit(1)
if args.only_source_code:
display_source_code(data['answer'])
elif args.answer:
print(data['answer'].strip())
elif args.question:
print(data['question'].strip())
else:
print(dump_data(data).strip())
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="ChatMastermind is a Python application that automates conversation with AI")
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
# subcommand-parser
cmdparser = parser.add_subparsers(dest='command',
@ -34,66 +146,58 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
help='List of tags (one must match)', metavar='OTAGS')
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tag names', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
help='List of tags (all must match)', metavar='ATAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
help='List of tags to exclude', metavar='XTAGS')
etag_arg.completer = tags_completer # type: ignore
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
help='List of tag names to exclude', metavar='EXTAGS')
extag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tags (default: use input tags)', metavar='OUTTAGS')
help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore
tag_parser.add_argument('-a', '--match-all-tags',
help="All given tags must match when selecting chat history entries",
action='store_true')
# enable autocompletion for tags
# a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI ID to use')
ai_parser.add_argument('-M', '--model', help='Model to use')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
# 'question' command parser
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser],
help="ask, create and process questions.",
aliases=['q'])
question_cmd_parser.set_defaults(func=question_cmd)
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
action='store_true')
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
# 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
help="Ask a question.",
aliases=['a'])
ask_cmd_parser.set_defaults(func=ask_cmd)
ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask',
required=True)
ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
ask_cmd_parser.add_argument('-M', '--model', help='Model to use')
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
help="Print chat history.",
aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
action='store_true')
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true')
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.",
aliases=['t'])
tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'tag' command parser
tag_cmd_parser = cmdparser.add_parser('tag',
help="Manage tags.",
aliases=['t'])
tag_cmd_parser.set_defaults(func=tag_cmd)
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
# 'config' command parser
config_cmd_parser = cmdparser.add_parser('config',
@ -105,11 +209,11 @@ def create_parser() -> argparse.ArgumentParser:
action='store_true')
config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
action='store_true')
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file")
config_group.add_argument('-M', '--model', help="Set model in the config file")
# 'print' command parser
print_cmd_parser = cmdparser.add_parser('print',
help="Print message files.",
help="Print files.",
aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
@ -126,12 +230,11 @@ def main() -> int:
parser = create_parser()
args = parser.parse_args()
command = parser.parse_args()
config = Config.from_file(args.config)
if command.func == config_cmd:
command.func(command)
else:
config = Config.from_file(args.config)
command.func(command, config)
openai_api_key(config.openai.api_key)
command.func(command, config)
return 0

View File

@ -1,561 +0,0 @@
"""
Module implementing message related functions and classes.
"""
import pathlib
import yaml
import tempfile
import shutil
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
from dataclasses import dataclass, asdict, field
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
QuestionInst = TypeVar('QuestionInst', bound='Question')
AnswerInst = TypeVar('AnswerInst', bound='Answer')
MessageInst = TypeVar('MessageInst', bound='Message')
AILineInst = TypeVar('AILineInst', bound='AILine')
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
class MessageError(Exception):
pass
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
"""
Changes the YAML dump style to multiline syntax for multiline strings.
"""
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
yaml.add_representer(str, str_presenter)
def source_code(text: str, include_delims: bool = False) -> list[str]:
"""
Extract all source code sections from the given text, i. e. all lines
surrounded by lines tarting with '```'. If 'include_delims' is True,
the surrounding lines are included, otherwise they are omitted. The
result list contains every source code section as a single string.
The order in the list represents the order of the sections in the text.
"""
code_sections: list[str] = []
code_lines: list[str] = []
in_code_block = False
for line in text.split('\n'):
if line.strip().startswith('```'):
if include_delims:
code_lines.append(line)
if in_code_block:
code_sections.append('\n'.join(code_lines) + '\n')
code_lines.clear()
in_code_block = not in_code_block
elif in_code_block:
code_lines.append(line)
return code_sections
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
"""
Searches the given message list for a message with the same file
name as the given one (i. e. it compares Message.file_path.name).
If the given message has no file_path, False is returned.
"""
if not message.file_path:
return False
for m in messages:
if m.file_path and m.file_path.name == message.file_path.name:
return True
return False
@dataclass(kw_only=True)
class MessageFilter:
"""
Various filters for a Message.
"""
tags_or: Optional[set[Tag]] = None
tags_and: Optional[set[Tag]] = None
tags_not: Optional[set[Tag]] = None
ai: Optional[str] = None
model: Optional[str] = None
question_contains: Optional[str] = None
answer_contains: Optional[str] = None
answer_state: Optional[Literal['available', 'missing']] = None
ai_state: Optional[Literal['available', 'missing']] = None
model_state: Optional[Literal['available', 'missing']] = None
class AILine(str):
"""
A line that represents the AI name in a '.txt' file..
"""
prefix: Final[str] = 'AI:'
def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
if not string.startswith(cls.prefix):
raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'")
instance = super().__new__(cls, string)
return instance
def ai(self) -> str:
return self[len(self.prefix):].strip()
@classmethod
def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
return cls(' '.join([cls.prefix, ai]))
class ModelLine(str):
"""
A line that represents the model name in a '.txt' file..
"""
prefix: Final[str] = 'MODEL:'
def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
if not string.startswith(cls.prefix):
raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")
instance = super().__new__(cls, string)
return instance
def model(self) -> str:
return self[len(self.prefix):].strip()
@classmethod
def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
return cls(' '.join([cls.prefix, model]))
class Answer(str):
"""
A single answer with a defined header.
"""
tokens: int = 0 # tokens used by this answer
txt_header: ClassVar[str] = '==== ANSWER ===='
yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
"""
Make sure the answer string does not contain the header as a whole line.
"""
if cls.txt_header in string.split('\n'):
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
class Question(str):
"""
A single question with a defined header.
"""
tokens: int = 0 # tokens used by this question
txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'question'
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
"""
Make sure the question string does not contain the header as a whole line
(also not that from 'Answer', so it's always clear where the answer starts).
"""
string_lines = string.split('\n')
if cls.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
if Answer.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
instance = super().__new__(cls, string)
return instance
@classmethod
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
"""
Build Question from a list of strings. Make sure strings do not contain the header.
"""
if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip())
return instance
def source_code(self, include_delims: bool = False) -> list[str]:
"""
Extract and return all source code sections.
"""
return source_code(self, include_delims)
@dataclass
class Message():
"""
Single message. Consists of a question and optionally an answer, a set of tags
and a file path.
"""
question: Question
answer: Optional[Answer] = None
# metadata, ignored when comparing messages
tags: Optional[set[Tag]] = field(default=None, compare=False)
ai: Optional[str] = field(default=None, compare=False)
model: Optional[str] = field(default=None, compare=False)
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
# class variables
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
tags_yaml_key: ClassVar[str] = 'tags'
file_yaml_key: ClassVar[str] = 'file_path'
ai_yaml_key: ClassVar[str] = 'ai'
model_yaml_key: ClassVar[str] = 'model'
def __hash__(self) -> int:
"""
The hash value is computed based on immutable members.
"""
return hash((self.question, self.answer))
@classmethod
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
"""
Create a Message from the given dict.
"""
return cls(question=data[Question.yaml_key],
answer=data.get(Answer.yaml_key, None),
tags=set(data.get(cls.tags_yaml_key, [])),
ai=data.get(cls.ai_yaml_key, None),
model=data.get(cls.model_yaml_key, None),
file_path=data.get(cls.file_yaml_key, None))
@classmethod
def tags_from_file(cls: Type[MessageInst],
file_path: pathlib.Path,
prefix: Optional[str] = None,
contain: Optional[str] = None) -> set[Tag]:
"""
Return only the tags from the given Message file,
optionally filtered based on prefix or contained string.
"""
tags: set[Tag] = set()
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
# for TXT, it's enough to read the TagLine
if file_path.suffix == '.txt':
with open(file_path, "r") as fd:
try:
tags = TagLine(fd.readline()).tags(prefix, contain)
except TagError:
pass # message without tags
else: # '.yaml'
try:
message = cls.from_file(file_path)
if message:
msg_tags = message.filter_tags(prefix=prefix, contain=contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
if msg_tags:
tags = msg_tags
return tags
@classmethod
def tags_from_dir(cls: Type[MessageInst],
path: pathlib.Path,
glob: Optional[str] = None,
prefix: Optional[str] = None,
contain: Optional[str] = None) -> set[Tag]:
"""
Return only the tags from message files in the given directory.
The files can be filtered using 'glob', the tags by using 'prefix'
and 'contain'.
"""
tags: set[Tag] = set()
file_iter = path.glob(glob) if glob else path.iterdir()
for file_path in sorted(file_iter):
if file_path.is_file():
try:
tags |= cls.tags_from_file(file_path, prefix, contain)
except MessageError as e:
print(f"Error processing message in '{file_path}': {str(e)}")
return tags
@classmethod
def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
"""
Create a Message from the given file. Returns 'None' if the message does
not fulfill the filter requirements. For TXT files, the tags are matched
before building the whole message. The other filters are applied afterwards.
"""
if not file_path.exists():
raise MessageError(f"Message file '{file_path}' does not exist")
if file_path.suffix not in cls.file_suffixes:
raise MessageError(f"File type '{file_path.suffix}' is not supported")
if file_path.suffix == '.txt':
message = cls.__from_file_txt(file_path,
mfilter.tags_or if mfilter else None,
mfilter.tags_and if mfilter else None,
mfilter.tags_not if mfilter else None)
else:
message = cls.__from_file_yaml(file_path)
if message and (mfilter is None or message.match(mfilter)):
return message
else:
return None
@classmethod
def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11
tags_or: Optional[set[Tag]] = None,
tags_and: Optional[set[Tag]] = None,
tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
"""
Create a Message from the given TXT file. Expects the following file structures:
For '.txt':
* TagLine [Optional]
* AI [Optional]
* Model [Optional]
* Question.txt_header
* Question
* Answer.txt_header [Optional]
* Answer [Optional]
Returns 'None' if the message does not fulfill the tag requirements.
"""
tags: set[Tag] = set()
question: Question
answer: Optional[Answer] = None
ai: Optional[str] = None
model: Optional[str] = None
with open(file_path, "r") as fd:
# TagLine (Optional)
try:
pos = fd.tell()
tags = TagLine(fd.readline()).tags()
except TagError:
fd.seek(pos)
if tags_or or tags_and or tags_not:
# match with an empty set if the file has no tags
if not match_tags(tags, tags_or, tags_and, tags_not):
return None
# AILine (Optional)
try:
pos = fd.tell()
ai = AILine(fd.readline()).ai()
except MessageError:
fd.seek(pos)
# ModelLine (Optional)
try:
pos = fd.tell()
model = ModelLine(fd.readline()).model()
except MessageError:
fd.seek(pos)
# Question and Answer
text = fd.read().strip().split('\n')
try:
question_idx = text.index(Question.txt_header) + 1
except ValueError:
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
try:
answer_idx = text.index(Answer.txt_header)
question = Question.from_list(text[question_idx:answer_idx])
answer = Answer.from_list(text[answer_idx + 1:])
except ValueError:
question = Question.from_list(text[question_idx:])
return cls(question, answer, tags, ai, model, file_path)
@classmethod
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
"""
Create a Message from the given YAML file. Expects the following file structures:
* Question.yaml_key: single or multiline string
* Answer.yaml_key: single or multiline string [Optional]
* Message.tags_yaml_key: list of strings [Optional]
* Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional]
"""
with open(file_path, "r") as fd:
data = yaml.load(fd, Loader=yaml.FullLoader)
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
output: list[str] = []
if source_code_only:
# use the source code from answer only
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output) if len(output) > 0 else ''
if with_tags:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(self.answer)
return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(True, True, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
"""
Write a Message to the given file. Type is determined based on the suffix.
Currently supported suffixes: ['.txt', '.yaml']
"""
if file_path:
self.file_path = file_path
if not self.file_path:
raise MessageError("Got no valid path to write message")
if self.file_path.suffix not in self.file_suffixes:
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
# TXT
if self.file_path.suffix == '.txt':
return self.__to_file_txt(self.file_path)
elif self.file_path.suffix == '.yaml':
return self.__to_file_yaml(self.file_path)
def __to_file_txt(self, file_path: pathlib.Path) -> None:
"""
Write a Message to the given file in TXT format.
Creates the following file structures:
* TagLine
* AI [Optional]
* Model [Optional]
* Question.txt_header
* Question
* Answer.txt_header
* Answer
"""
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
if self.tags:
temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
if self.ai:
temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
if self.model:
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
if self.answer:
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
shutil.move(temp_file_path, file_path)
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
"""
Write a Message to the given file in YAML format.
Creates the following file structures:
* Question.yaml_key: single or multiline string
* Answer.yaml_key: single or multiline string
* Message.tags_yaml_key: list of strings
* Message.ai_yaml_key: str [Optional]
* Message.model_yaml_key: str [Optional]
"""
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
temp_file_path = pathlib.Path(temp_fd.name)
data: YamlDict = {Question.yaml_key: str(self.question)}
if self.answer:
data[Answer.yaml_key] = str(self.answer)
if self.ai:
data[self.ai_yaml_key] = self.ai
if self.model:
data[self.model_yaml_key] = self.model
if self.tags:
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
yaml.dump(data, temp_fd, sort_keys=False)
shutil.move(temp_file_path, file_path)
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Filter tags based on their prefix (i. e. the tag starts with a given string)
or some contained string.
"""
if not self.tags:
return set()
res_tags = self.tags.copy()
if prefix and len(prefix) > 0:
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
if contain and len(contain) > 0:
res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
"""
Returns all tags as a string with the TagLine prefix. Optionally filtered
using 'Message.filter_tags()'.
"""
if self.tags:
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
else:
return str(TagLine.from_set(set()))
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
"""
Matches the current Message to the given filter atttributes.
Return True if all attributes match, else False.
"""
mytags = self.tags or set()
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
or (mfilter.model_state == 'available' and not self.model) # noqa: W503
or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503
or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503
or (mfilter.model_state == 'missing' and self.model)): # noqa: W503
return False
return True
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
"""
Renames the given tags. The first tuple element is the old name,
the second one is the new name.
"""
if self.tags:
self.tags = rename_tags(self.tags, tags_rename)
def msg_id(self) -> str:
"""
Returns an ID that is unique throughout all messages in the same (DB) directory.
Currently this is the file name. The ID is also used for sorting messages.
"""
if self.file_path:
return self.file_path.name
else:
raise MessageError("Can't create file ID without a file path")
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def tokens(self) -> int:
"""
Returns the nr. of AI language tokens used by this message.
If unknown, 0 is returned.
"""
if self.answer:
return self.question.tokens + self.answer.tokens
else:
return self.question.tokens

121
chatmastermind/storage.py Normal file
View File

@ -0,0 +1,121 @@
import yaml
import io
import pathlib
from .utils import terminal_width, append_message, message_to_chat, ChatType
from .configuration import Config
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format)
separator = ',' if ',' in tagline else ' '
tags = [t.strip() for t in tagline.split(separator)]
if tags_only:
return {"tags": tags}
text = fd.read().strip().split('\n')
question_idx = text.index("=== QUESTION ===") + 1
answer_idx = text.index("==== ANSWER ====")
question = "\n".join(text[question_idx:answer_idx]).strip()
answer = "\n".join(text[answer_idx + 1:]).strip()
return {"question": question, "answer": answer, "tags": tags,
"file": fname.name}
def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
return fd.getvalue()
def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
def save_answers(question: str,
answers: list[str],
tags: list[str],
otags: Optional[list[str]],
config: Config
) -> None:
wtags = otags or tags
num, inum = 0, 0
next_fname = pathlib.Path(str(config.db)) / '.next'
try:
with open(next_fname, 'r') as f:
num = int(f.read())
except Exception:
pass
for answer in answers:
num += 1
inum += 1
title = f'-- ANSWER {inum} '
title_end = '-' * (terminal_width() - len(title))
print(f'{title}{title_end}')
print(answer)
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
with open(next_fname, 'w') as f:
f.write(f'{num}')
def create_chat_hist(question: Optional[str],
tags: Optional[list[str]],
extags: Optional[list[str]],
config: Config,
match_all_tags: bool = False,
with_tags: bool = False,
with_file: bool = False
) -> ChatType:
chat: ChatType = []
append_message(chat, 'system', str(config.system).strip())
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
data['file'] = file.name
elif file.suffix == '.txt':
data = read_file(file)
else:
continue
data_tags = set(data.get('tags', []))
tags_match: bool
if match_all_tags:
tags_match = not tags or set(tags).issubset(data_tags)
else:
tags_match = not tags or bool(data_tags.intersection(tags))
extags_do_not_match = \
not extags or not data_tags.intersection(extags)
if tags_match and extags_do_not_match:
message_to_chat(data, chat, with_tags, with_file)
if question:
append_message(chat, 'user', question)
return chat
def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = []
for file in sorted(pathlib.Path(str(config.db)).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
elif file.suffix == '.txt':
data = read_file(file, tags_only=True)
else:
continue
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return result
def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix)))

View File

@ -1,184 +0,0 @@
"""
Module implementing tag related functions and classes.
"""
from typing import Type, TypeVar, Optional, Final
TagInst = TypeVar('TagInst', bound='Tag')
TagLineInst = TypeVar('TagLineInst', bound='TagLine')
class TagError(Exception):
pass
class Tag(str):
"""
A single tag. A string that can contain anything but the default separator (' ').
"""
# default separator
default_separator: Final[str] = ' '
# alternative separators (e. g. for backwards compatibility)
alternative_separators: Final[list[str]] = [',']
def __new__(cls: Type[TagInst], string: str) -> TagInst:
"""
Make sure the tag string does not contain the default separator.
"""
if cls.default_separator in string:
raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'")
instance = super().__new__(cls, string)
return instance
def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]:
"""
Deletes the given tags and returns a new set.
"""
return tags.difference(tags_delete)
def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]:
"""
Adds the given tags and returns a new set.
"""
return set(sorted(tags | tags_add))
def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]:
"""
Merges the tags in 'tags_merge' into the current one and returns a new set.
"""
for ts in tags_merge:
tags |= ts
return tags
def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]:
"""
Renames the given tags and returns a new set. The first tuple element
is the old name, the second one is the new name.
"""
for t in tags_rename:
if t[0] in tags:
tags.remove(t[0])
tags.add(t[1])
return set(sorted(tags))
def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
tags_not: Optional[set[Tag]]) -> bool:
"""
Checks if the given set 'tags' matches the given tag requirements:
- 'tags_or' : matches if this TagLine contains ANY of those tags
- 'tags_and': matches if this TagLine contains ALL of those tags
- 'tags_not': matches if this TagLine contains NONE of those tags
Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and',
i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
or all of the tags in 'tags_and' but it must never contain any of the tags in
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()),
they match no tags.
"""
required_tags_present = False
excluded_tags_missing = False
if ((tags_or is None and tags_and is None)
or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503
or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503
required_tags_present = True
if ((tags_not is None)
or (not any(tag in tags for tag in tags_not))): # noqa: W503
excluded_tags_missing = True
return required_tags_present and excluded_tags_missing
class TagLine(str):
"""
A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by
a list of tags, separated by the defaut separator (' '). Any operations on a
TagLine will sort the tags.
"""
# the prefix
prefix: Final[str] = 'TAGS:'
def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst:
"""
Make sure the tagline string starts with the prefix. Also replace newlines
and multiple spaces with ' ', in order to support multiline TagLines.
"""
if not string.startswith(cls.prefix):
raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'")
string = ' '.join(string.split())
instance = super().__new__(cls, string)
return instance
@classmethod
def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst:
"""
Create a new TagLine from a set of tags.
"""
return cls(' '.join([cls.prefix] + sorted([t for t in tags])))
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
"""
Returns all tags contained in this line as a set, optionally
filtered based on prefix or contained string.
"""
tagstr = self[len(self.prefix):].strip()
if tagstr == '':
return set() # no tags, only prefix
separator = Tag.default_separator
# look for alternative separators and use the first one found
# -> we don't support different separators in the same TagLine
for s in Tag.alternative_separators:
if s in tagstr:
separator = s
break
res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)]))
if prefix and len(prefix) > 0:
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
if contain and len(contain) > 0:
res_tags -= {tag for tag in res_tags if contain not in tag}
return res_tags or set()
def merge(self, taglines: set['TagLine']) -> 'TagLine':
"""
Merges the tags of all given taglines into the current one and returns a new TagLine.
"""
tags_merge = [tl.tags() for tl in taglines]
return self.from_set(merge_tags(self.tags(), tags_merge))
def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine':
"""
Deletes the given tags and returns a new TagLine.
"""
return self.from_set(delete_tags(self.tags(), tags_delete))
def add_tags(self, tags_add: set[Tag]) -> 'TagLine':
"""
Adds the given tags and returns a new TagLine.
"""
return self.from_set(add_tags(self.tags(), tags_add))
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine':
"""
Renames the given tags and returns a new TagLine. The first
tuple element is the old name, the second one is the new name.
"""
return self.from_set(rename_tags(self.tags(), tags_rename))
def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
tags_not: Optional[set[Tag]]) -> bool:
"""
Checks if the current TagLine matches the given tag requirements:
- 'tags_or' : matches if this TagLine contains ANY of those tags
- 'tags_and': matches if this TagLine contains ALL of those tags
- 'tags_not': matches if this TagLine contains NONE of those tags
Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and',
i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
or all of the tags in 'tags_and' but it must never contain any of the tags in
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
exclusion is still done if 'tags_not' is not 'None').
"""
return match_tags(self.tags(), tags_or, tags_and, tags_not)

86
chatmastermind/utils.py Normal file
View File

@ -0,0 +1,86 @@
import shutil
from pprint import PrettyPrinter
from typing import Any
ChatType = list[dict[str, str]]
def terminal_width() -> int:
return shutil.get_terminal_size().columns
def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
"""
Prints the tags specified in the given args.
"""
printed_messages = []
if tags:
printed_messages.append(f"Tags: {' '.join(tags)}")
if extags:
printed_messages.append(f"Excluding tags: {' '.join(extags)}")
if otags:
printed_messages.append(f"Output tags: {' '.join(otags)}")
if printed_messages:
print("\n".join(printed_messages))
print()
def append_message(chat: ChatType,
role: str,
content: str
) -> None:
chat.append({'role': role, 'content': content.replace("''", "'")})
def message_to_chat(message: dict[str, str],
chat: ChatType,
with_tags: bool = False,
with_file: bool = False
) -> None:
append_message(chat, 'user', message['question'])
append_message(chat, 'assistant', message['answer'])
if with_tags:
tags = " ".join(message['tags'])
append_message(chat, 'tags', tags)
if with_file:
append_message(chat, 'file', message['file'])
def display_source_code(content: str) -> None:
try:
content_start = content.index('```')
content_start = content.index('\n', content_start) + 1
content_end = content.rindex('```')
if content_start < content_end:
print(content[content_start:content_end].strip())
except ValueError:
pass
def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
if dump:
pp(chat)
return
for message in chat:
text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
if source_code:
display_source_code(message['content'])
continue
if message['role'] == 'user':
print('-' * terminal_width())
if text_too_long:
print(f"{message['role'].upper()}:")
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")

View File

@ -12,7 +12,7 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/ok2/ChatMastermind",
packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"],
packages=find_packages(),
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Console",
@ -32,7 +32,7 @@ setup(
"openai",
"PyYAML",
"argcomplete",
"pytest",
"pytest"
],
python_requires=">=3.9",
test_suite="tests",

View File

@ -1,48 +0,0 @@
import argparse
import unittest
from unittest.mock import MagicMock
from chatmastermind.ai_factory import create_ai
from chatmastermind.configuration import Config
from chatmastermind.ai import AIError
from chatmastermind.ais.openai import OpenAI
class TestCreateAI(unittest.TestCase):
def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace)
self.args.AI = 'myopenai'
self.args.model = None
self.args.max_tokens = None
self.args.temperature = None
def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration
config = Config()
self.args.AI = 'myopenai'
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_ai_from_default(self) -> None:
self.args.AI = None
# Create an AI with the default configuration
config = Config()
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_empty_ai_error(self) -> None:
self.args.AI = None
# Create Config with empty AIs
config = Config()
config.ais = {}
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)
def test_create_unsupported_ai_error(self) -> None:
# Mock argparse.Namespace with ai='invalid_ai'
self.args.AI = 'invalid_ai'
# Create default Config
config = Config()
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)

View File

@ -1,488 +0,0 @@
import unittest
import pathlib
import tempfile
import time
from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
class TestChat(unittest.TestCase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_messages([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
self.chat.sort(reverse=True)
self.assertEqual(self.chat.messages[0].question, 'Question 2')
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_messages([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.add_messages([self.message1, self.message2])
tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
def test_find_remove_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
msgs = self.chat.find_messages(['0001.txt'])
self.assertListEqual(msgs, [self.message1])
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
self.assertListEqual(msgs, [self.message1, self.message2])
# add new Message with full path
message3 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('btag2')},
file_path=pathlib.Path('/foo/bla/0003.txt'))
self.chat.add_messages([message3])
# find new Message by full path
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
self.assertListEqual(msgs, [message3])
# find Message with full path only by filename
msgs = self.chat.find_messages(['0003.txt'])
self.assertListEqual(msgs, [message3])
# remove last message
self.chat.remove_messages(['0003.txt'])
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(unittest.TestCase):
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.message1 = Message(Question('Question 1'),
Answer('Answer 1'),
{Tag('tag1')},
file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'),
Answer('Answer 2'),
{Tag('tag2')},
file_path=pathlib.Path('0002.yaml'))
self.message3 = Message(Question('Question 3'),
Answer('Answer 3'),
{Tag('tag3')},
file_path=pathlib.Path('0003.txt'))
self.message4 = Message(Question('Question 4'),
Answer('Answer 4'),
{Tag('tag4')},
file_path=pathlib.Path('0004.yaml'))
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None:
self.db_path.cleanup()
self.cache_path.cleanup()
pass
def test_chat_db_from_dir(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
# check that the files are sorted
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path,
pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_from_dir_glob(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
glob='*.txt')
self.assertEqual(len(chat_db.messages), 2)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path,
pathlib.Path(self.db_path.name, '0003.txt'))
def test_chat_db_from_dir_filter_tags(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or={Tag('tag1')}))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0001.txt'))
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(tags_or=set(),
tags_and=set(),
tags_not=set()))
self.assertEqual(len(chat_db.messages), 0)
def test_chat_db_from_dir_filter_answer(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
mfilter=MessageFilter(answer_contains='Answer 2'))
self.assertEqual(len(chat_db.messages), 1)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.messages[0].file_path,
pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
def test_chat_db_from_messages(self) -> None:
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name),
messages=[self.message1, self.message2,
self.message3, self.message4])
self.assertEqual(len(chat_db.messages), 4)
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
# check that Message.file_path has been correctly updated
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# check if all files in the DB dir have actually been overwritten
for file in db_dir_files:
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
# check that Message.file_path has been correctly updated (again)
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
def test_chat_db_read(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(len(chat_db.messages), 4)
# create 2 new files in the DB directory
new_message1 = Message(Question('Question 5'),
Answer('Answer 5'),
{Tag('tag5')})
new_message2 = Message(Question('Question 6'),
Answer('Answer 6'),
{Tag('tag6')})
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read and check them
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 6)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# create 2 new files in the cache directory
new_message3 = Message(Question('Question 7'),
Answer('Answer 5'),
{Tag('tag7')})
new_message4 = Message(Question('Question 8'),
Answer('Answer 6'),
{Tag('tag8')})
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
# read and check them
chat_db.read_cache()
self.assertEqual(len(chat_db.messages), 8)
# check that the new message have the cache dir path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
# an the old ones keep their path (since they have not been replaced)
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now overwrite two messages in the DB directory
new_message1.question = Question('New Question 1')
new_message2.question = Question('New Question 2')
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
# read from the DB dir and check if the modified messages have been updated
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
# now write the messages from the cache to the DB directory
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
# read and check them
chat_db.read_db()
self.assertEqual(len(chat_db.messages), 8)
# check that they now have the DB path
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
def test_chat_db_clear(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
# check that Message.file_path is correct
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.write_db()
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
# add a new message with empty file_path
message_empty = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You don't belong here!"))
# and one for the cache dir
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_messages([message_empty, message_cache])
# clear the cache and check the cache dir
chat_db.clear_cache()
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
self.assertEqual(len(chat_db.messages), 5)
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.add_to_cache([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
# add new messages to the DB dir
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.add_to_db([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# try to write a message without a valid file_path
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.write_messages([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)
def test_chat_db_update_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
message = chat_db.messages[0]
message.answer = Answer("New answer")
# update message without writing
chat_db.update_messages([message], write=False)
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# re-read the message and check for old content
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
# now check with writing (message should be overwritten)
chat_db.update_messages([message], write=True)
chat_db.read_db()
self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
# test without file_path -> expect error
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.update_messages([message1])

View File

@ -1,160 +0,0 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config
class TestAIConfigInstance(unittest.TestCase):
def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
ai_reference = OpenAIConfig()
self.assertEqual(ai_config.ID, ai_reference.ID)
self.assertEqual(ai_config.name, ai_reference.name)
self.assertEqual(ai_config.api_key, ai_reference.api_key)
self.assertEqual(ai_config.system, ai_reference.system)
self.assertEqual(ai_config.model, ai_reference.model)
self.assertEqual(ai_config.temperature, ai_reference.temperature)
self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
self.assertEqual(ai_config.top_p, ai_reference.top_p)
self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)
def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
conf_dict = {
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
self.assertEqual(ai_config.system, 'Custom system')
self.assertEqual(ai_config.api_key, '9876543210')
self.assertEqual(ai_config.model, 'custom_model')
self.assertEqual(ai_config.max_tokens, 5000)
self.assertAlmostEqual(ai_config.temperature, 0.5)
self.assertAlmostEqual(ai_config.top_p, 0.8)
self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
self.assertAlmostEqual(ai_config.presence_penalty, 0.2)
def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
with self.assertRaises(ConfigError):
ai_config_instance('invalid_name')
class TestConfig(unittest.TestCase):
def setUp(self) -> None:
self.test_file = NamedTemporaryFile(delete=False)
def tearDown(self) -> None:
os.remove(self.test_file.name)
def test_from_dict_should_create_config_from_dict(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'myopenai': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
config = Config.from_dict(source_dict)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertEqual(config.ais['myopenai'].name, 'openai')
self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
# check that 'ID' has been added
self.assertEqual(config.ais['myopenai'].ID, 'myopenai')
def test_create_default_should_create_default_config(self) -> None:
Config.create_default(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
default_config = yaml.load(f, Loader=yaml.FullLoader)
config_reference = Config()
self.assertEqual(default_config['db'], config_reference.db)
def test_from_file_should_load_config_from_file(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'openai',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
config = Config.from_file(self.test_file.name)
self.assertIsInstance(config, Config)
self.assertEqual(config.db, './test_db/')
self.assertEqual(len(config.ais), 1)
self.assertIsInstance(config.ais['default'], AIConfig)
self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')
def test_to_file_should_save_config_to_file(self) -> None:
config = Config(
db='./test_db/',
ais={
'myopenai': OpenAIConfig(
ID='myopenai',
system='Custom system',
api_key='9876543210',
model='custom_model',
max_tokens=5000,
temperature=0.5,
top_p=0.8,
frequency_penalty=0.7,
presence_penalty=0.2
)
}
)
config.to_file(Path(self.test_file.name))
with open(self.test_file.name, 'r') as f:
saved_config = yaml.load(f, Loader=yaml.FullLoader)
self.assertEqual(saved_config['db'], './test_db/')
self.assertEqual(len(saved_config['ais']), 1)
self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')
def test_from_file_error_unknown_ai(self) -> None:
source_dict = {
'db': './test_db/',
'ais': {
'default': {
'name': 'foobla',
'system': 'Custom system',
'api_key': '9876543210',
'model': 'custom_model',
'max_tokens': 5000,
'temperature': 0.5,
'top_p': 0.8,
'frequency_penalty': 0.7,
'presence_penalty': 0.2
}
}
}
with open(self.test_file.name, 'w') as f:
yaml.dump(source_dict, f)
with self.assertRaises(ConfigError):
Config.from_file(self.test_file.name)

233
tests/test_main.py Normal file
View File

@ -0,0 +1,233 @@
import unittest
import io
import pathlib
import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase):
"""
Base class for all cmm testcases.
"""
def dummy_config(self, db: str) -> Config:
"""
Creates a dummy configuration.
"""
return Config.from_dict(
{'system': 'dummy_system',
'db': db,
'openai': {'api_key': 'dummy_key',
'model': 'dummy_model',
'max_tokens': 4000,
'temperature': 1.0,
'top_p': 1,
'frequency_penalty': 0,
'presence_penalty': 0}}
)
class TestCreateChat(CmmTestCase):
def setUp(self) -> None:
self.config = self.dummy_config(db='test_files')
self.question = "test question"
self.tags = ['test_tag']
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']}))
test_chat = create_chat_hist(self.question, self.tags, None, self.config)
self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question})
@patch('os.listdir')
@patch('pathlib.Path.iterdir')
@patch('builtins.open')
def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = (
io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer',
'tags': ['test_tag']})),
io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2',
'tags': ['test_tag2']})),
)
test_chat = create_chat_hist(self.question, [], None, self.config)
self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3],
{'role': 'user', 'content': 'test_content2'})
self.assertEqual(test_chat[4],
{'role': 'assistant', 'content': 'some answer2'})
class TestHandleQuestion(CmmTestCase):
def setUp(self) -> None:
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
extags=['extag1'],
output_tags=None,
question=[self.question],
source=None,
only_source_code=False,
number=3,
max_tokens=None,
temperature=None,
model=None,
match_all_tags=False,
with_tags=False,
with_file=False,
)
self.config = self.dummy_config(db='test_files')
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.main.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp")
@patch("builtins.print")
def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.extags,
[])
mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags,
self.args.extags,
self.config,
False, False, False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.only_source_code)
mock_ai.assert_called_with("test_chat",
self.config,
self.args.number)
expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),))
expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True)
class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?"
answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"]
config = self.dummy_config(db='test_db')
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \
mock.patch('io.StringIO') as stringio_mock:
stringio_instance = stringio_mock.return_value
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
save_answers(question, answers, tags, otags, config)
open_calls = [
mock.call(pathlib.Path('test_db/.next'), 'r'),
mock.call(pathlib.Path('test_db/.next'), 'w'),
]
open_mock.assert_has_calls(open_calls, any_order=True)
class TestAI(CmmTestCase):
@patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = {
'choices': [
{'message': {'content': 'response_text_1'}},
{'message': {'content': 'response_text_2'}}
],
'usage': {'tokens': 10}
}
chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy')
config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150
config.openai.temperature = 0.5
result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10})
self.assertEqual(result, expected_result)
class TestCreateParser(CmmTestCase):
def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config'))

View File

@ -1,836 +0,0 @@
import unittest
import pathlib
import tempfile
from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" ```python\n print(\"Hello, World!\")\n ```\n",
" ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_without_include_delims(self) -> None:
text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
expected_result = [
" print(\"Hello, World!\")\n",
" x = 10\n y = 20\n print(x + y)\n"
]
result = source_code(text, include_delims=False)
self.assertEqual(result, expected_result)
def test_source_code_with_single_code_block(self) -> None:
text = "```python\nprint(\"Hello, World!\")\n```"
expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
def test_source_code_with_no_code_blocks(self) -> None:
text = "Some text without any code blocks"
expected_result: list[str] = []
result = source_code(text, include_delims=True)
self.assertEqual(result, expected_result)
class QuestionTestCase(unittest.TestCase):
def test_question_with_header(self) -> None:
with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?")
def test_question_with_answer_header(self) -> None:
with self.assertRaises(MessageError):
Question(f"{Answer.txt_header}\nBob")
def test_question_with_legal_header(self) -> None:
"""
If the header is just a part of a line, it's fine.
"""
question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
self.assertIsInstance(question, Question)
self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
def test_question_without_header(self) -> None:
question = Question("What is your favorite color?")
self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno")
def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
def test_answer_without_header(self) -> None:
answer = Answer("No")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_min = Message(Question('This is a question.'),
file_path=self.file_path)
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_to_file_txt_complete(self) -> None:
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
"""
self.assertEqual(content, expected_content)
def test_to_file_txt_min(self) -> None:
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.txt_header}
This is a question.
"""
self.assertEqual(content, expected_content)
def test_to_file_unsupported_file_type(self) -> None:
unsupported_file_path = pathlib.Path("example.doc")
with self.assertRaises(MessageError) as cm:
self.message_complete.to_file(unsupported_file_path)
self.assertEqual(str(cm.exception), "File type '.doc' is not supported")
def test_to_file_no_file_path(self) -> None:
"""
Provoke an exception using an empty path.
"""
with self.assertRaises(MessageError) as cm:
# clear the internal file_path
self.message_complete.file_path = None
self.message_complete.to_file(None)
self.assertEqual(str(cm.exception), "Got no valid path to write message")
# reset the internal file_path
self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
self.message_complete = Message(Question('This is a question.'),
Answer('This is an answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_multiline = Message(Question('This is a\nmultiline question.'),
Answer('This is a\nmultiline answer.'),
{Tag('tag1'), Tag('tag2')},
ai='ChatGPT',
model='gpt-3.5-turbo',
file_path=self.file_path)
self.message_min = Message(Question('This is a question.'),
file_path=self.file_path)
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_to_file_yaml_complete(self) -> None:
self.message_complete.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.yaml_key}: This is a question.
{Answer.yaml_key}: This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
self.assertEqual(content, expected_content)
def test_to_file_yaml_multiline(self) -> None:
self.message_multiline.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"""{Question.yaml_key}: |-
This is a
multiline question.
{Answer.yaml_key}: |-
This is a
multiline answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
self.assertEqual(content, expected_content)
def test_to_file_yaml_min(self) -> None:
self.message_min.to_file(self.file_path)
with open(self.file_path, "r") as fd:
content = fd.read()
expected_content = f"{Question.yaml_key}: This is a question.\n"
self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""{Question.txt_header}
This is a question.
""")
def tearDown(self) -> None:
self.file.close()
self.file_min.close()
self.file_path.unlink()
self.file_path_min.unlink()
def test_from_file_txt_complete(self) -> None:
"""
Read a complete message (with all optional values).
"""
message = Message.from_file(self.file_path)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_min(self) -> None:
"""
Read a message with only required values.
"""
message = Message.from_file(self.file_path_min)
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_txt_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_txt_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message)
def test_from_file_txt_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message)
def test_from_file_txt_empty_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or=set(),
tags_and=set()))
self.assertIsNone(message)
def test_from_file_txt_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.txt")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_txt_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_txt_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_txt_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_txt_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_txt_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
with open(self.file_path, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
""")
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_min = pathlib.Path(self.file_min.name)
with open(self.file_path_min, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
""")
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
self.file_min.close()
self.file_path_min.unlink()
def test_from_file_yaml_complete(self) -> None:
"""
Read a complete message (with all optional values).
"""
message = Message.from_file(self.file_path)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.ai, 'ChatGPT')
self.assertEqual(message.model, 'gpt-3.5-turbo')
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_min(self) -> None:
"""
Read a message with only the required values.
"""
message = Message.from_file(self.file_path_min)
self.assertIsInstance(message, Message)
self.assertIsNotNone(message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
self.assertIsNone(message.answer)
def test_from_file_not_exists(self) -> None:
file_not_exists = pathlib.Path("example.yaml")
with self.assertRaises(MessageError) as cm:
Message.from_file(file_not_exists)
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
def test_from_file_yaml_tags_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertEqual(message.answer, 'This is an answer.')
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
self.assertEqual(message.file_path, self.file_path)
def test_from_file_yaml_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(tags_or={Tag('tag3')}))
self.assertIsNone(message)
def test_from_file_yaml_no_tags_dont_match(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_or={Tag('tag1')}))
self.assertIsNone(message)
def test_from_file_yaml_no_tags_match_tags_not(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(tags_not={Tag('tag1')}))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
if message: # mypy bug
self.assertEqual(message.question, 'This is a question.')
self.assertSetEqual(cast(set[Tag], message.tags), set())
self.assertEqual(message.file_path, self.file_path_min)
def test_from_file_yaml_question_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='question'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='answer'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_available(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='available'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_answer_missing(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='missing'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_question_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(question_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_contains='question'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_exists(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_contains='answer'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_available(self) -> None:
message = Message.from_file(self.file_path_min,
MessageFilter(answer_state='available'))
self.assertIsNone(message)
def test_from_file_yaml_answer_not_missing(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(answer_state='missing'))
self.assertIsNone(message)
def test_from_file_yaml_ai_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='ChatGPT'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_ai_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(ai='Foo'))
self.assertIsNone(message)
def test_from_file_yaml_model_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='gpt-3.5-turbo'))
self.assertIsNotNone(message)
self.assertIsInstance(message, Message)
def test_from_file_yaml_model_doesnt_match(self) -> None:
message = Message.from_file(self.file_path,
MessageFilter(model='Bar'))
self.assertIsNone(message)
class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name)
with open(self.file_path_txt, "w") as fd:
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
with open(self.file_path_txt_no_tags, "w") as fd:
fd.write(f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
with open(self.file_path_txt_tags_empty, "w") as fd:
fd.write(f"""TAGS:
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
with open(self.file_path_yaml, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
{Message.tags_yaml_key}:
- tag1
- tag2
- ptag3
""")
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
with open(self.file_path_yaml_no_tags, "w") as fd:
fd.write(f"""
{Question.yaml_key}: |-
This is a question.
{Answer.yaml_key}: |-
This is an answer.
""")
def tearDown(self) -> None:
self.file_txt.close()
self.file_path_txt.unlink()
self.file_yaml.close()
self.file_path_yaml.unlink()
self.file_txt_no_tags.close
self.file_path_txt_no_tags.unlink()
self.file_txt_tags_empty.close
self.file_path_txt_tags_empty.unlink()
self.file_yaml_no_tags.close()
self.file_path_yaml_no_tags.unlink()
def test_tags_from_file_txt(self) -> None:
tags = Message.tags_from_file(self.file_path_txt)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_txt_no_tags(self) -> None:
tags = Message.tags_from_file(self.file_path_txt_no_tags)
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_tags_empty(self) -> None:
tags = Message.tags_from_file(self.file_path_txt_tags_empty)
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml)
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
def test_tags_from_file_yaml_no_tags(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml_no_tags)
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_txt, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_txt, prefix='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml_prefix(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_yaml, prefix='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_txt_contain(self) -> None:
tags = Message.tags_from_file(self.file_path_txt, contain='3')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_txt, contain='R')
self.assertSetEqual(tags, set())
def test_tags_from_file_yaml_contain(self) -> None:
tags = Message.tags_from_file(self.file_path_yaml, contain='3')
self.assertSetEqual(tags, {Tag('ptag3')})
tags = Message.tags_from_file(self.file_path_yaml, contain='R')
self.assertSetEqual(tags, set())
class TagsFromDirTestCase(unittest.TestCase):
def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
self.tag_sets = [
{Tag('atag1'), Tag('atag2')},
{Tag('btag3'), Tag('btag4')},
{Tag('ctag5'), Tag('ctag6')}
]
self.files = [
pathlib.Path(self.temp_dir.name, 'file1.txt'),
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
pathlib.Path(self.temp_dir.name, 'file3.txt')
]
self.files_no_tags = [
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
]
for file, tags in zip(self.files, self.tag_sets):
message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags)
message.to_file(file)
for file in self.files_no_tags:
message = Message(Question('This is a question.'),
Answer('This is an answer.'))
message.to_file(file)
def tearDown(self) -> None:
self.temp_dir.cleanup()
self.temp_dir_no_tags.cleanup()
def test_tags_from_dir(self) -> None:
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2]
self.assertEqual(all_tags, expected_tags)
def test_tags_from_dir_prefix(self) -> None:
atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a')
expected_tags = self.tag_sets[0]
self.assertEqual(atags, expected_tags)
def test_tags_from_dir_no_tags(self) -> None:
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name))
self.assertSetEqual(all_tags, set())
class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
self.message = Message(Question('This is a question.'),
file_path=self.file_path)
self.message_no_file_path = Message(Question('This is a question.'))
def tearDown(self) -> None:
self.file.close()
self.file_path.unlink()
def test_msg_id_txt(self) -> None:
self.assertEqual(self.message.msg_id(), self.file_path.name)
def test_msg_id_txt_exception(self) -> None:
with self.assertRaises(MessageError):
self.message_no_file_path.msg_id()
class MessageHashTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')},
file_path=pathlib.Path('/tmp/foo/bla'))
self.message2 = Message(Question('This is a new question.'),
file_path=pathlib.Path('/tmp/foo/bla'))
self.message3 = Message(Question('This is a question.'),
Answer('This is an answer.'),
file_path=pathlib.Path('/tmp/foo/bla'))
# message4 is a copy of message1, because only question and
# answer are used for hashing and comparison
self.message4 = Message(Question('This is a question.'),
tags={Tag('tag1'), Tag('tag2')},
ai='Blabla',
file_path=pathlib.Path('foobla'))
def test_set_hashing(self) -> None:
msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4}
self.assertEqual(len(msgs), 3)
for msg in [self.message1, self.message2, self.message3]:
self.assertIn(msg, msgs)
class MessageTagsStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('tag1')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_tags_str(self) -> None:
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_filter_tags(self) -> None:
tags_all = self.message.filter_tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.message.filter_tags(prefix='a')
self.assertSetEqual(tags_pref, {Tag('atag1')})
tags_cont = self.message.filter_tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
self.message2 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/bla/foo'))
def test_message_in(self) -> None:
self.assertTrue(message_in(self.message1, [self.message1]))
self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_rename_tags(self) -> None:
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
self.assertIsNotNone(self.message.tags)
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_to_str(self) -> None:
expected_output = f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(), expected_output)
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)

View File

@ -1,162 +0,0 @@
import os
import unittest
import argparse
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
from chatmastermind.commands.question import create_message
from chatmastermind.message import Message, Question
from chatmastermind.chat import ChatDB
class TestMessageCreate(unittest.TestCase):
"""
Test if messages created by the 'question' command have
the correct format.
"""
def setUp(self) -> None:
# create ChatDB structure
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
db_path=Path(self.db_path.name))
# create arguments mock
self.args = MagicMock(spec=argparse.Namespace)
self.args.source_text = None
self.args.source_code = None
self.args.AI = None
self.args.model = None
self.args.output_tags = None
# File 1 : no source code block, only text
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
self.source_file1_content = """This is just text.
No source code.
Nope. Go look elsewhere!"""
with open(self.source_file1.name, 'w') as f:
f.write(self.source_file1_content)
# File 2 : one embedded source code block
self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
self.source_file2_content = """This is just text.
```
This is embedded source code.
```
And some text again."""
with open(self.source_file2.name, 'w') as f:
f.write(self.source_file2_content)
# File 3 : all source code
self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
self.source_file3_content = """This is all source code.
Yes, really.
Language is called 'brainfart'."""
with open(self.source_file3.name, 'w') as f:
f.write(self.source_file3_content)
# File 4 : two source code blocks
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
self.source_file4_content = """This is just text.
```
This is embedded source code.
```
And some text again.
```
This is embedded source code.
```
Aaaand again some text."""
with open(self.source_file4.name, 'w') as f:
f.write(self.source_file4_content)
def tearDown(self) -> None:
os.remove(self.source_file1.name)
os.remove(self.source_file2.name)
os.remove(self.source_file3.name)
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
# exclude '.next'
return list(Path(tmp_dir.name).glob('*.[ty]*'))
def test_message_file_created(self) -> None:
self.args.ask = ["What is this?"]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
create_message(self.chat, self.args)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
message = Message.from_file(cache_dir_files[0])
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
def test_single_question(self) -> None:
self.args.ask = ["What is this?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("What is this?"))
self.assertEqual(len(message.question.source_code()), 0)
def test_multipart_question(self) -> None:
self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
self.assertEqual(message.question, Question("""What is this
'bard' thing?
Is it good?"""))
def test_single_question_with_text_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_text = [f"{self.source_file1.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains no source code (only text)
# -> don't expect any in the question
self.assertEqual(len(message.question.source_code()), 0)
self.assertEqual(message.question, Question(f"""What is this?
{self.source_file1_content}"""))
def test_single_question_with_text_file_and_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file2.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 1 source code block
# -> expect it in the question
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
"""))
def test_single_question_with_code_only_file(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file3.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file is complete source code
self.assertEqual(len(message.question.source_code()), 1)
self.assertEqual(message.question, Question(f"""What is this?
```
{self.source_file3_content}
```"""))
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
self.args.ask = ["What is this?"]
self.args.source_code = [f"{self.source_file4.name}"]
message = create_message(self.chat, self.args)
self.assertIsInstance(message, Message)
# file contains 2 source code blocks
# -> expect them in the question
self.assertEqual(len(message.question.source_code()), 2)
self.assertEqual(message.question, Question("""What is this?
```
This is embedded source code.
```
```
This is embedded source code.
```
"""))

View File

@ -1,163 +0,0 @@
import unittest
from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(unittest.TestCase):
def test_valid_tag(self) -> None:
tag = Tag('mytag')
self.assertEqual(tag, 'mytag')
def test_invalid_tag(self) -> None:
with self.assertRaises(TagError):
Tag('tag with space')
def test_default_separator(self) -> None:
self.assertEqual(Tag.default_separator, ' ')
def test_alternative_separators(self) -> None:
self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(unittest.TestCase):
def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_valid_tagline_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_invalid_tagline(self) -> None:
with self.assertRaises(TagError):
TagLine('tag1 tag2')
def test_prefix(self) -> None:
self.assertEqual(TagLine.prefix, 'TAGS:')
def test_from_set(self) -> None:
tags = {Tag('tag1'), Tag('tag2')}
tagline = TagLine.from_set(tags)
self.assertEqual(tagline, 'TAGS: tag1 tag2')
def test_tags(self) -> None:
tagline = TagLine('TAGS: atag1 btag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('atag1'), Tag('btag2')})
def test_tags_empty(self) -> None:
tagline = TagLine('TAGS:')
self.assertSetEqual(tagline.tags(), set())
def test_tags_with_newline(self) -> None:
tagline = TagLine('TAGS: tag1\n tag2')
tags = tagline.tags()
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
def test_tags_prefix(self) -> None:
tagline = TagLine('TAGS: atag1 stag2 stag3')
tags = tagline.tags(prefix='a')
self.assertSetEqual(tags, {Tag('atag1')})
tags = tagline.tags(prefix='s')
self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')})
tags = tagline.tags(prefix='R')
self.assertSetEqual(tags, set())
def test_tags_contain(self) -> None:
tagline = TagLine('TAGS: atag1 stag2 stag3')
tags = tagline.tags(contain='t')
self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')})
tags = tagline.tags(contain='1')
self.assertSetEqual(tags, {Tag('atag1')})
tags = tagline.tags(contain='R')
self.assertSetEqual(tags, set())
def test_merge(self) -> None:
tagline1 = TagLine('TAGS: tag1 tag2')
tagline2 = TagLine('TAGS: tag2 tag3')
merged_tagline = tagline1.merge({tagline2})
self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3')
def test_delete_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag2')
def test_add_tags(self) -> None:
tagline = TagLine('TAGS: tag1')
new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')})
self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3')
def test_rename_tags(self) -> None:
tagline = TagLine('TAGS: old1 old2')
new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))})
self.assertEqual(new_tagline, 'TAGS: new1 new2')
def test_match_tags(self) -> None:
tagline = TagLine('TAGS: tag1 tag2 tag3')
# Test case 1: Match any tag in 'tags_or'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and: set[Tag] = set()
tags_not: set[Tag] = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 2: Match all tags in 'tags_and'
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = set()
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not'
tags_or = {Tag('tag1'), Tag('tag4')}
tags_and = {Tag('tag1'), Tag('tag2')}
tags_not = {Tag('tag5')}
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 5: No matching tags in 'tags_or'
tags_or = {Tag('tag4'), Tag('tag5')}
tags_and = set()
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 6: Not all tags in 'tags_and' are present
tags_or = set()
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')}
tags_not = set()
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 7: Some tags in 'tags_not' are present
tags_or = {Tag('tag1')}
tags_and = set()
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
# Test case 8: 'tags_or' and 'tags_and' are None, match all tags
tags_not = set()
self.assertTrue(tagline.match_tags(None, None, tags_not))
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
tags_not = {Tag('tag2')}
self.assertFalse(tagline.match_tags(None, None, tags_not))
# Test case 10: 'tags_or' and 'tags_and' are empty, match no tags
self.assertFalse(tagline.match_tags(set(), set(), None))
# Test case 11: 'tags_or' is empty, match no tags
self.assertFalse(tagline.match_tags(set(), None, None))
# Test case 12: 'tags_and' is empty, match no tags
self.assertFalse(tagline.match_tags(None, set(), None))
# Test case 13: 'tags_or' is empty, match 'tags_and'
tags_and = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(None, tags_and, None))
# Test case 14: 'tags_and' is empty, match 'tags_or'
tags_or = {Tag('tag1'), Tag('tag2')}
self.assertTrue(tagline.match_tags(tags_or, None, None))