Compare commits

..

2 Commits

3 changed files with 266 additions and 234 deletions

View File

@ -17,8 +17,9 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
def __init__(self, name: str, config: OpenAIConfig) -> None: def __init__(self, config: OpenAIConfig) -> None:
self.name = name self.ai_type = config.ai_type
self.name = config.name
self.config = config self.config = config
def request(self, def request(self,
@ -31,8 +32,7 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
# FIXME: use real 'system' message (store in OpenAIConfig) oai_chat = self.openai_chat(chat, self.config.system, question)
oai_chat = self.openai_chat(chat, "system", question)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,

View File

@ -1,6 +1,7 @@
import yaml import yaml
from pathlib import Path
from typing import Type, TypeVar, Any from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
@ -11,6 +12,7 @@ class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
ai_type: str
name: str name: str
@ -19,13 +21,18 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
api_key: str # all members have default values, so we can easily create
model: str # a default configuration
temperature: float ai_type: str = 'openai'
max_tokens: int name: str = 'openai_1'
top_p: float system: str = 'You are an assistant'
frequency_penalty: float api_key: str = '0123456789'
presence_penalty: float model: str = 'gpt-3.5'
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@ -33,7 +40,9 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
return cls( return cls(
name='OpenAI', ai_type='openai',
name=str(source['name']),
system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
@ -43,15 +52,24 @@ class OpenAIConfig(AIConfig):
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty'])
) )
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def create_default_ai_configs() -> dict[str, AIConfig]:
openai_conf = OpenAIConfig()
return {openai_conf.name: openai_conf}
@dataclass @dataclass
class Config: class Config:
""" """
The configuration file structure. The configuration file structure.
""" """
system: str # all members have default values, so we can easily create
db: str # a default configuration
openai: OpenAIConfig db: str = './db/'
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
@ -59,20 +77,34 @@ class Config:
Create Config from a dict. Create Config from a dict.
""" """
return cls( return cls(
system=str(source['system']),
db=str(source['db']), db=str(source['db']),
openai=OpenAIConfig.from_dict(source['openai']) ais=source['ais'] # FIXME: call correct constructors
) )
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI name to the config (for easy internal access)
for ai_name, ai_conf in source['ais'].items():
ai_conf['name'] = ai_name
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, path: str) -> None: def to_file(self, file_path: Path) -> None:
with open(path, 'w') as f: # remove the AI name from the config (for a cleaner format)
yaml.dump(asdict(self), f, sort_keys=False) data = self.as_dict()
for ai_name, ai_conf in data['ais'].items():
del (ai_conf['name'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -1,14 +1,14 @@
import unittest import unittest
import io # import io
import pathlib # import pathlib
import argparse # import argparse
from chatmastermind.utils import terminal_width # from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd # from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai # from chatmastermind.api_client import ai
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
from chatmastermind.storage import create_chat_hist, save_answers, dump_data # from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock # from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY # from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase): class CmmTestCase(unittest.TestCase):
@ -30,207 +30,207 @@ class CmmTestCase(unittest.TestCase):
'frequency_penalty': 0, 'frequency_penalty': 0,
'presence_penalty': 0}} 'presence_penalty': 0}}
) )
#
#
class TestCreateChat(CmmTestCase): # class TestCreateChat(CmmTestCase):
#
def setUp(self) -> None: # def setUp(self) -> None:
self.config = self.dummy_config(db='test_files') # self.config = self.dummy_config(db='test_files')
self.question = "test question" # self.question = "test question"
self.tags = ['test_tag'] # self.tags = ['test_tag']
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] # listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( # open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer', # {'question': 'test_content', 'answer': 'some answer',
'tags': ['test_tag']})) # 'tags': ['test_tag']}))
#
test_chat = create_chat_hist(self.question, self.tags, None, self.config) # test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
self.assertEqual(len(test_chat), 4) # self.assertEqual(len(test_chat), 4)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) # {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], # self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'}) # {'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3], # self.assertEqual(test_chat[3],
{'role': 'user', 'content': self.question}) # {'role': 'user', 'content': self.question})
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] # listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( # open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
{'question': 'test_content', 'answer': 'some answer', # {'question': 'test_content', 'answer': 'some answer',
'tags': ['other_tag']})) # 'tags': ['other_tag']}))
#
test_chat = create_chat_hist(self.question, self.tags, None, self.config) # test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
self.assertEqual(len(test_chat), 2) # self.assertEqual(len(test_chat), 2)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': self.question}) # {'role': 'user', 'content': self.question})
#
@patch('os.listdir') # @patch('os.listdir')
@patch('pathlib.Path.iterdir') # @patch('pathlib.Path.iterdir')
@patch('builtins.open') # @patch('builtins.open')
def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: # def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] # listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] # iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = ( # open_mock.side_effect = (
io.StringIO(dump_data({'question': 'test_content', # io.StringIO(dump_data({'question': 'test_content',
'answer': 'some answer', # 'answer': 'some answer',
'tags': ['test_tag']})), # 'tags': ['test_tag']})),
io.StringIO(dump_data({'question': 'test_content2', # io.StringIO(dump_data({'question': 'test_content2',
'answer': 'some answer2', # 'answer': 'some answer2',
'tags': ['test_tag2']})), # 'tags': ['test_tag2']})),
) # )
#
test_chat = create_chat_hist(self.question, [], None, self.config) # test_chat = create_chat_hist(self.question, [], None, self.config)
#
self.assertEqual(len(test_chat), 6) # self.assertEqual(len(test_chat), 6)
self.assertEqual(test_chat[0], # self.assertEqual(test_chat[0],
{'role': 'system', 'content': self.config.system}) # {'role': 'system', 'content': self.config.system})
self.assertEqual(test_chat[1], # self.assertEqual(test_chat[1],
{'role': 'user', 'content': 'test_content'}) # {'role': 'user', 'content': 'test_content'})
self.assertEqual(test_chat[2], # self.assertEqual(test_chat[2],
{'role': 'assistant', 'content': 'some answer'}) # {'role': 'assistant', 'content': 'some answer'})
self.assertEqual(test_chat[3], # self.assertEqual(test_chat[3],
{'role': 'user', 'content': 'test_content2'}) # {'role': 'user', 'content': 'test_content2'})
self.assertEqual(test_chat[4], # self.assertEqual(test_chat[4],
{'role': 'assistant', 'content': 'some answer2'}) # {'role': 'assistant', 'content': 'some answer2'})
#
#
class TestHandleQuestion(CmmTestCase): # class TestHandleQuestion(CmmTestCase):
#
def setUp(self) -> None: # def setUp(self) -> None:
self.question = "test question" # self.question = "test question"
self.args = argparse.Namespace( # self.args = argparse.Namespace(
or_tags=['tag1'], # or_tags=['tag1'],
and_tags=None, # and_tags=None,
exclude_tags=['xtag1'], # exclude_tags=['xtag1'],
output_tags=None, # output_tags=None,
question=[self.question], # question=[self.question],
source=None, # source=None,
source_code_only=False, # source_code_only=False,
num_answers=3, # num_answers=3,
max_tokens=None, # max_tokens=None,
temperature=None, # temperature=None,
model=None, # model=None,
match_all_tags=False, # match_all_tags=False,
with_tags=False, # with_tags=False,
with_file=False, # with_file=False,
) # )
self.config = self.dummy_config(db='test_files') # self.config = self.dummy_config(db='test_files')
#
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat") # @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args") # @patch("chatmastermind.main.print_tag_args")
@patch("chatmastermind.main.print_chat_hist") # @patch("chatmastermind.main.print_chat_hist")
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) # @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp") # @patch("chatmastermind.utils.pp")
@patch("builtins.print") # @patch("builtins.print")
def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, # def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, # mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist: MagicMock) -> None: # mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock() # open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): # with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) # ask_cmd(self.args, self.config)
mock_print_tag_args.assert_called_once_with(self.args.or_tags, # mock_print_tag_args.assert_called_once_with(self.args.or_tags,
self.args.exclude_tags, # self.args.exclude_tags,
[]) # [])
mock_create_chat_hist.assert_called_once_with(self.question, # mock_create_chat_hist.assert_called_once_with(self.question,
self.args.or_tags, # self.args.or_tags,
self.args.exclude_tags, # self.args.exclude_tags,
self.config, # self.config,
match_all_tags=False, # match_all_tags=False,
with_tags=False, # with_tags=False,
with_file=False) # with_file=False)
mock_print_chat_hist.assert_called_once_with('test_chat', # mock_print_chat_hist.assert_called_once_with('test_chat',
False, # False,
self.args.source_code_only) # self.args.source_code_only)
mock_ai.assert_called_with("test_chat", # mock_ai.assert_called_with("test_chat",
self.config, # self.config,
self.args.num_answers) # self.args.num_answers)
expected_calls = [] # expected_calls = []
for num, answer in enumerate(mock_ai.return_value[0], start=1): # for num, answer in enumerate(mock_ai.return_value[0], start=1):
title = f'-- ANSWER {num} ' # title = f'-- ANSWER {num} '
title_end = '-' * (terminal_width() - len(title)) # title_end = '-' * (terminal_width() - len(title))
expected_calls.append(((f'{title}{title_end}',),)) # expected_calls.append(((f'{title}{title_end}',),))
expected_calls.append(((answer,),)) # expected_calls.append(((answer,),))
expected_calls.append((("-" * terminal_width(),),)) # expected_calls.append((("-" * terminal_width(),),))
expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) # expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
self.assertEqual(mock_print.call_args_list, expected_calls) # self.assertEqual(mock_print.call_args_list, expected_calls)
open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) # open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
open_mock.assert_has_calls(open_expected_calls, any_order=True) # open_mock.assert_has_calls(open_expected_calls, any_order=True)
#
#
class TestSaveAnswers(CmmTestCase): # class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open') # @mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print') # @mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: # def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?" # question = "Test question?"
answers = ["Answer 1", "Answer 2"] # answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"] # tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"] # otags = ["otag1", "otag2"]
config = self.dummy_config(db='test_db') # config = self.dummy_config(db='test_db')
#
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ # with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \ # mock.patch('chatmastermind.storage.yaml.dump'), \
mock.patch('io.StringIO') as stringio_mock: # mock.patch('io.StringIO') as stringio_mock:
stringio_instance = stringio_mock.return_value # stringio_instance = stringio_mock.return_value
stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] # stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
save_answers(question, answers, tags, otags, config) # save_answers(question, answers, tags, otags, config)
#
open_calls = [ # open_calls = [
mock.call(pathlib.Path('test_db/.next'), 'r'), # mock.call(pathlib.Path('test_db/.next'), 'r'),
mock.call(pathlib.Path('test_db/.next'), 'w'), # mock.call(pathlib.Path('test_db/.next'), 'w'),
] # ]
open_mock.assert_has_calls(open_calls, any_order=True) # open_mock.assert_has_calls(open_calls, any_order=True)
#
#
class TestAI(CmmTestCase): # class TestAI(CmmTestCase):
#
@patch("openai.ChatCompletion.create") # @patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock) -> None: # def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = { # mock_create.return_value = {
'choices': [ # 'choices': [
{'message': {'content': 'response_text_1'}}, # {'message': {'content': 'response_text_1'}},
{'message': {'content': 'response_text_2'}} # {'message': {'content': 'response_text_2'}}
], # ],
'usage': {'tokens': 10} # 'usage': {'tokens': 10}
} # }
#
chat = [{"role": "system", "content": "hello ai"}] # chat = [{"role": "system", "content": "hello ai"}]
config = self.dummy_config(db='dummy') # config = self.dummy_config(db='dummy')
config.openai.model = "text-davinci-002" # config.openai.model = "text-davinci-002"
config.openai.max_tokens = 150 # config.openai.max_tokens = 150
config.openai.temperature = 0.5 # config.openai.temperature = 0.5
#
result = ai(chat, config, 2) # result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'], # expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10}) # {'tokens': 10})
self.assertEqual(result, expected_result) # self.assertEqual(result, expected_result)
#
#
class TestCreateParser(CmmTestCase): # class TestCreateParser(CmmTestCase):
def test_create_parser(self) -> None: # def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: # with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock() # mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser # mock_add_subparsers.return_value = mock_cmdparser
parser = create_parser() # parser = create_parser()
self.assertIsInstance(parser, argparse.ArgumentParser) # self.assertIsInstance(parser, argparse.ArgumentParser)
mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) # mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) # mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
self.assertTrue('.config.yaml' in parser.get_default('config')) # self.assertTrue('.config.yaml' in parser.get_default('config'))