Compare commits

...

7 Commits

9 changed files with 439 additions and 104 deletions

64
chatmastermind/ai.py Normal file
View File

@ -0,0 +1,64 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Protocol, Optional, Union
from .configuration import AIConfig
from .message import Message
from .chat import Chat
class AIError(Exception):
pass
@dataclass
class Tokens:
prompt: int = 0
completion: int = 0
total: int = 0
@dataclass
class AIResponse:
"""
The response to an AI request. Consists of one or more messages
(each containing the question and a single answer) and the nr.
of used tokens.
"""
messages: list[Message]
tokens: Optional[Tokens] = None
class AI(Protocol):
"""
The base class for AI clients.
"""
name: str
config: AIConfig
@abstractmethod
def request(self,
question: Message,
context: Chat,
num_answers: int = 1) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
corresponds to the nr. of messages in the 'AIResponse'.
"""
raise NotImplementedError
@abstractmethod
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def tokens(self, data: Union[Message, Chat]) -> int:
"""
Computes the nr. of AI language tokens for the given message
or chat. Note that the computation may not be 100% accurate
and is not implemented for all AIs.
"""
raise NotImplementedError

View File

@ -0,0 +1,90 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from typing import Optional
from ..tags import Tag
from ..message import Message, Answer
from ..chat import Chat
from ..ai import AI, AIResponse, Tokens
from ..config import OpenAIConfig
ChatType = list[dict[str, str]]
class OpenAI(AI):
"""
The OpenAI AI client.
"""
config: OpenAIConfig
def request(self,
question: Message,
chat: Chat,
num_answers: int = 1,
otags: Optional[set[Tag]] = None) -> AIResponse:
"""
Make an AI request, asking the given question with the given
chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'.
"""
oai_chat = self.openai_chat(chat, self.config.system, question)
response = openai.ChatCompletion.create(
model=self.config.model,
messages=oai_chat,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
n=num_answers,
frequency_penalty=self.config.frequency_penalty,
presence_penalty=self.config.presence_penalty)
answers: list[Message] = []
for choice in response['choices']: # type: ignore
answers.append(Message(question=question.question,
answer=Answer(choice['message']['content']),
tags=otags,
ai=self.name,
model=self.config.model))
return AIResponse(answers, Tokens(response['usage']['prompt'],
response['usage']['completion'],
response['usage']['total']))
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))
def openai_chat(self, chat: Chat, system: str,
question: Optional[Message] = None) -> ChatType:
"""
Create a chat history with system message in OpenAI format.
Optionally append a new question.
"""
oai_chat: ChatType = []
def append(role: str, content: str) -> None:
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
append('system', system)
for message in chat.messages:
if message.answer:
append('user', message.question)
append('assistant', message.answer)
if question:
append('user', question.question)
return oai_chat

View File

@ -7,7 +7,7 @@ from pprint import PrettyPrinter
from pydoc import pager
from dataclasses import dataclass
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in
from .message import Message, MessageFilter, MessageError, message_in
from .tags import Tag
ChatInst = TypeVar('ChatInst', bound='Chat')
@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path,
return messages
def make_file_path(dir_path: pathlib.Path,
file_suffix: str,
next_fid: Callable[[], int]) -> pathlib.Path:
"""
Create a file_path for the given directory using the
given file_suffix and ID generator function.
"""
return dir_path / f"{next_fid():04d}{file_suffix}"
def write_dir(dir_path: pathlib.Path,
messages: list[Message],
file_suffix: str,
@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path,
file_path = message.file_path
# message has no file_path: create one
if not file_path:
fid = next_fid()
fname = f"{fid:04d}{file_suffix}"
file_path = dir_path / fname
file_path = make_file_path(dir_path, file_suffix, next_fid)
# file_path does not point to given directory: modify it
elif not file_path.parent.samefile(dir_path):
file_path = dir_path / file_path.name
@ -124,11 +132,11 @@ class Chat:
"""
self.messages = []
def add_msgs(self, msgs: list[Message]) -> None:
def add_messages(self, messages: list[Message]) -> None:
"""
Add new messages and sort them if possible.
"""
self.messages += msgs
self.messages += messages
self.sort()
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
@ -162,18 +170,10 @@ class Chat:
output: list[str] = []
for message in self.messages:
if source_code_only:
output.extend(source_code(message.question, include_delims=True))
output.append(message.to_str(source_code_only=True))
continue
output.append('-' * terminal_width())
if with_tags:
output.append(message.tags_str())
if with_files:
output.append('FILE: ' + str(message.file_path))
output.append(Question.txt_header)
output.append(message.question)
if message.answer:
output.append(Answer.txt_header)
output.append(message.answer)
output.append(message.to_str(with_tags, with_files))
output.append('\n' + ('-' * terminal_width()) + '\n')
if paged:
print_paged('\n'.join(output))
else:
@ -279,25 +279,25 @@ class ChatDB(Chat):
self.messages += new_messages
self.sort()
def write_db(self, msgs: Optional[list[Message]] = None) -> None:
def write_db(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the DB directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point
to the DB directory.
"""
write_dir(self.db_path,
msgs if msgs else self.messages,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
def write_cache(self, msgs: Optional[list[Message]] = None) -> None:
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
"""
Write messages to the cache directory. If a message has no file_path, a new one
will be created. If message.file_path exists, it will be modified to point to
the cache directory.
"""
write_dir(self.cache_path,
msgs if msgs else self.messages,
messages if messages else self.messages,
self.file_suffix,
self.get_next_fid)
@ -309,3 +309,52 @@ class ChatDB(Chat):
clear_dir(self.cache_path, self.glob)
# only keep messages from DB dir (or those that have not yet been written)
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the DB directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.db_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
"""
Add the given new messages and set the file_path to the cache directory.
Only accepts messages without a file_path.
"""
if any(m.file_path is not None for m in messages):
raise ChatError("Can't add new messages with existing file_path")
if write:
write_dir(self.cache_path,
messages,
self.file_suffix,
self.get_next_fid)
else:
for m in messages:
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
self.messages += messages
self.sort()
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
"""
Write either the given messages or the internal ones to their current file_path.
If messages are given, they all must have a valid file_path. When writing the
internal messages, the ones with a valid file_path are written, the others
are ignored.
"""
if messages and any(m.file_path is None for m in messages):
raise ChatError("Can't write files without a valid file_path")
msgs = iter(messages if messages else self.messages)
while (m := next(msgs, None)):
m.to_file()

View File

@ -6,11 +6,13 @@ import yaml
import sys
import argcomplete
import argparse
import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
from pathlib import Path
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist, read_file, dump_data
from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from .chat import ChatDB
from .message import Message, MessageFilter
from itertools import zip_longest
from typing import Any
@ -18,9 +20,8 @@ default_config = '.config.yaml'
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags_unique(config, prefix)
config = Config.from_file(parsed_args.config)
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
def create_question_with_hist(args: argparse.Namespace,
@ -31,11 +32,11 @@ def create_question_with_hist(args: argparse.Namespace,
by the specified tags.
"""
tags = args.tags or []
extags = args.extags or []
etags = args.etags or []
otags = args.output_tags or []
if not args.only_source_code:
print_tag_args(tags, extags, otags)
if not args.source_code_only:
print_tag_args(tags, etags, otags)
question_parts = []
question_list = args.question if args.question is not None else []
@ -52,17 +53,24 @@ def create_question_with_hist(args: argparse.Namespace,
question_parts.append(f"```\n{r.read().strip()}\n```")
full_question = '\n\n'.join(question_parts)
chat = create_chat_hist(full_question, tags, extags, config,
args.match_all_tags, False, False)
chat = create_chat_hist(full_question, tags, etags, config,
match_all_tags=True if args.atags else False, # FIXME
with_tags=False,
with_file=False)
return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: Config) -> None:
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'tag' command.
Handler for the 'tags' command.
"""
chat = ChatDB.from_dir(cache_path=Path('.'),
db_path=Path(config.db))
if args.list:
print_tags_frequency(get_tags(config, None))
tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
# TODO: add renaming
def config_cmd(args: argparse.Namespace, config: Config) -> None:
@ -89,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None:
if args.model:
config.openai.model = args.model
chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code)
print_chat_hist(chat, False, args.source_code_only)
otags = args.output_tags or []
answers, usage = ai(chat, config, args.number)
save_answers(question, answers, tags, otags, config)
@ -101,21 +109,25 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'hist' command.
"""
tags = args.tags or []
extags = args.extags or []
chat = create_chat_hist(None, tags, extags, config,
args.match_all_tags,
args.with_tags,
args.with_files)
print_chat_hist(chat, args.dump, args.only_source_code)
mfilter = MessageFilter(tags_or=args.tags,
tags_and=args.atags,
tags_not=args.etags,
question_contains=args.question,
answer_contains=args.answer)
chat = ChatDB.from_dir(Path('.'),
Path(config.db),
mfilter=mfilter)
chat.print(args.source_code_only,
args.with_tags,
args.with_files)
def print_cmd(args: argparse.Namespace, config: Config) -> None:
"""
Handler for the 'print' command.
"""
fname = pathlib.Path(args.file)
fname = Path(args.file)
if fname.suffix == '.yaml':
with open(args.file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
@ -124,7 +136,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None:
else:
print(f"Unknown file type: {args.file}")
sys.exit(1)
if args.only_source_code:
if args.source_code_only:
display_source_code(data['answer'])
else:
print(dump_data(data).strip())
@ -144,18 +156,17 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support tag selection
tag_parser = argparse.ArgumentParser(add_help=False)
tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+',
help='List of tag names', metavar='TAGS')
help='List of tag names (one must match)', metavar='TAGS')
tag_arg.completer = tags_completer # type: ignore
extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+',
help='List of tag names to exclude', metavar='EXTAGS')
extag_arg.completer = tags_completer # type: ignore
atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+',
help='List of tag names (all must match)', metavar='TAGS')
atag_arg.completer = tags_completer # type: ignore
etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+',
help='List of tag names to exclude', metavar='ETAGS')
etag_arg.completer = tags_completer # type: ignore
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
help='List of output tag names, default is input', metavar='OTAGS')
otag_arg.completer = tags_completer # type: ignore
tag_parser.add_argument('-a', '--match-all-tags',
help="All given tags must match when selecting chat history entries",
action='store_true')
# enable autocompletion for tags
# 'ask' command parser
ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser],
@ -170,7 +181,7 @@ def create_parser() -> argparse.ArgumentParser:
ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int,
default=1)
ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query')
ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history',
ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history',
action='store_true')
# 'hist' command parser
@ -178,23 +189,25 @@ def create_parser() -> argparse.ArgumentParser:
help="Print chat history.",
aliases=['h'])
hist_cmd_parser.set_defaults(func=hist_cmd)
hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure",
action='store_true')
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
action='store_true')
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
action='store_true')
hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true')
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
# 'tag' command parser
tag_cmd_parser = cmdparser.add_parser('tag',
help="Manage tags.",
aliases=['t'])
tag_cmd_parser.set_defaults(func=tag_cmd)
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
# 'tags' command parser
tags_cmd_parser = cmdparser.add_parser('tags',
help="Manage tags.",
aliases=['t'])
tags_cmd_parser.set_defaults(func=tags_cmd)
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true')
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser
config_cmd_parser = cmdparser.add_parser('config',
@ -214,7 +227,7 @@ def create_parser() -> argparse.ArgumentParser:
aliases=['p'])
print_cmd_parser.set_defaults(func=print_cmd)
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code',
print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
action='store_true')
argcomplete.autocomplete(parser)

View File

@ -392,6 +392,30 @@ class Message():
data[cls.file_yaml_key] = file_path
return cls.from_dict(data)
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
"""
Return the current Message as a string.
"""
output: list[str] = []
if source_code_only:
output.extend(self.question.source_code(include_delims=True))
if self.answer:
output.extend(self.answer.source_code(include_delims=True))
return '\n'.join(output)
if with_tags:
output.append(self.tags_str())
if with_file:
output.append('FILE: ' + str(self.file_path))
output.append(Question.txt_header)
output.append(self.question)
if self.answer:
output.append(Answer.txt_header)
output.append(self.answer)
return '\n'.join(output)
def __str__(self) -> str:
return self.to_str(False, False, False)
def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11
"""
Write a Message to the given file. Type is determined based on the suffix.

View File

@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
print(message['content'])
else:
print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")

View File

@ -5,7 +5,7 @@ from io import StringIO
from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from .test_main import CmmTestCase
@ -22,14 +22,14 @@ class TestChat(CmmTestCase):
file_path=pathlib.Path('0002.txt'))
def test_filter(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
self.assertEqual(len(self.chat.messages), 1)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
def test_sort(self) -> None:
self.chat.add_msgs([self.message2, self.message1])
self.chat.add_messages([self.message2, self.message1])
self.chat.sort()
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
@ -38,18 +38,18 @@ class TestChat(CmmTestCase):
self.assertEqual(self.chat.messages[1].question, 'Question 1')
def test_clear(self) -> None:
self.chat.add_msgs([self.message1])
self.chat.add_messages([self.message1])
self.chat.clear()
self.assertEqual(len(self.chat.messages), 0)
def test_add_msgs(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
def test_add_messages(self) -> None:
self.chat.add_messages([self.message1, self.message2])
self.assertEqual(len(self.chat.messages), 2)
self.assertEqual(self.chat.messages[0].question, 'Question 1')
self.assertEqual(self.chat.messages[1].question, 'Question 2')
def test_tags(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
tags_all = self.chat.tags()
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
tags_pref = self.chat.tags(prefix='a')
@ -58,45 +58,53 @@ class TestChat(CmmTestCase):
self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
@patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False)
expected_output = f"""{'-'*terminal_width()}
{Question.txt_header}
expected_output = f"""{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@patch('sys.stdout', new_callable=StringIO)
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2])
self.chat.add_messages([self.message1, self.message2])
self.chat.print(paged=False, with_tags=True, with_files=True)
expected_output = f"""{'-'*terminal_width()}
{TagLine.prefix} atag1 btag2
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: 0001.txt
{Question.txt_header}
Question 1
{Answer.txt_header}
Answer 1
{'-'*terminal_width()}
{TagLine.prefix} btag2
FILE: 0002.txt
{Question.txt_header}
Question 2
{Answer.txt_header}
Answer 2
{'-'*terminal_width()}
"""
self.assertEqual(mock_stdout.getvalue(), expected_output)
@ -127,6 +135,17 @@ class TestChatDB(CmmTestCase):
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
# make the next FID match the current state
next_fname = pathlib.Path(self.db_path.name) / '.next'
with open(next_fname, 'w') as f:
f.write('4')
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
"""
List all Message files in the given TemporaryDirectory.
"""
# exclude '.next'
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
def tearDown(self) -> None:
self.db_path.cleanup()
@ -184,11 +203,11 @@ class TestChatDB(CmmTestCase):
def test_chat_db_fids(self) -> None:
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
self.assertEqual(chat_db.get_next_fid(), 1)
self.assertEqual(chat_db.get_next_fid(), 2)
self.assertEqual(chat_db.get_next_fid(), 3)
self.assertEqual(chat_db.get_next_fid(), 5)
self.assertEqual(chat_db.get_next_fid(), 6)
self.assertEqual(chat_db.get_next_fid(), 7)
with open(chat_db.next_fname, 'r') as f:
self.assertEqual(f.read(), '3')
self.assertEqual(f.read(), '7')
def test_chat_db_write(self) -> None:
# create a new ChatDB instance
@ -203,7 +222,7 @@ class TestChatDB(CmmTestCase):
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
@ -216,14 +235,14 @@ class TestChatDB(CmmTestCase):
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
# check the timestamp of the files in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
# overwrite the messages in the db directory
time.sleep(0.05)
chat_db.write_db()
# check if the written files are in the DB directory
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
@ -314,12 +333,12 @@ class TestChatDB(CmmTestCase):
# write the messages to the cache directory
chat_db.write_cache()
# check if the written files are in the cache directory
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 4)
# now rewrite them to the DB dir and check for modified paths
chat_db.write_db()
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
@ -333,15 +352,69 @@ class TestChatDB(CmmTestCase):
message_cache = Message(question=Question("What the hell am I doing here?"),
answer=Answer("You're a creep!"),
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
chat_db.add_msgs([message_empty, message_cache])
chat_db.add_messages([message_empty, message_cache])
# clear the cache and check the cache dir
chat_db.clear_cache()
cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*'))
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# make sure that the DB messages (and the new message) are still there
self.assertEqual(len(chat_db.messages), 5)
db_dir_files = list(pathlib.Path(self.db_path.name).glob('*'))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# but not the message with the cache dir path
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
def test_chat_db_add(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
# add new messages to the cache dir
message1 = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
chat_db.add_to_cache([message1])
# check if the file_path has been correctly set
self.assertIsNotNone(message1.file_path)
self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr]
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
# add new messages to the DB dir
message2 = Message(question=Question("Question 2"),
answer=Answer("Answer 2"))
chat_db.add_to_db([message2])
# check if the file_path has been correctly set
self.assertIsNotNone(message2.file_path)
self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr]
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 5)
with self.assertRaises(ChatError):
chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])
def test_chat_db_write_messages(self) -> None:
# create a new ChatDB instance
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
pathlib.Path(self.db_path.name))
db_dir_files = self.message_list(self.db_path)
self.assertEqual(len(db_dir_files), 4)
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 0)
# try to write a message without a valid file_path
message = Message(question=Question("Question 1"),
answer=Answer("Answer 1"))
with self.assertRaises(ChatError):
chat_db.write_messages([message])
# write a message with a valid file_path
message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
chat_db.write_messages([message])
cache_dir_files = self.message_list(self.cache_path)
self.assertEqual(len(cache_dir_files), 1)
self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)

View File

@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase):
self.question = "test question"
self.args = argparse.Namespace(
tags=['tag1'],
extags=['extag1'],
atags=None,
etags=['etag1'],
output_tags=None,
question=[self.question],
source=None,
only_source_code=False,
source_code_only=False,
number=3,
max_tokens=None,
temperature=None,
@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase):
with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.tags,
self.args.extags,
self.args.etags,
[])
mock_create_chat_hist.assert_called_once_with(self.question,
self.args.tags,
self.args.extags,
self.args.etags,
self.config,
False, False, False)
match_all_tags=False,
with_tags=False,
with_file=False)
mock_print_chat_hist.assert_called_once_with('test_chat',
False,
self.args.only_source_code)
self.args.source_code_only)
mock_ai.assert_called_with("test_chat",
self.config,
self.args.number)
@ -227,7 +230,7 @@ class TestCreateParser(CmmTestCase):
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config'))

View File

@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase):
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
self.assertIsNotNone(self.message.tags)
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(CmmTestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),
tags={Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('/tmp/foo/bla'))
def test_to_str(self) -> None:
expected_output = f"""{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(), expected_output)
def test_to_str_with_tags_and_file(self) -> None:
expected_output = f"""{TagLine.prefix} atag1 btag2
FILE: /tmp/foo/bla
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer."""
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)