Compare commits

..

No commits in common. "13c412782772ecad63af45719fe1a82a475bba45" and "ba5aa1fbc73013cee81c7bb27b0a970866b6bf25" have entirely different histories.

3 changed files with 234 additions and 266 deletions

View File

@ -17,9 +17,8 @@ class OpenAI(AI):
The OpenAI AI client. The OpenAI AI client.
""" """
def __init__(self, config: OpenAIConfig) -> None: def __init__(self, name: str, config: OpenAIConfig) -> None:
self.ai_type = config.ai_type self.name = name
self.name = config.name
self.config = config self.config = config
def request(self, def request(self,
@ -32,7 +31,8 @@ class OpenAI(AI):
chat history. The nr. of requested answers corresponds to the chat history. The nr. of requested answers corresponds to the
nr. of messages in the 'AIResponse'. nr. of messages in the 'AIResponse'.
""" """
oai_chat = self.openai_chat(chat, self.config.system, question) # FIXME: use real 'system' message (store in OpenAIConfig)
oai_chat = self.openai_chat(chat, "system", question)
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=self.config.model, model=self.config.model,
messages=oai_chat, messages=oai_chat,

View File

@ -1,7 +1,6 @@
import yaml import yaml
from pathlib import Path
from typing import Type, TypeVar, Any from typing import Type, TypeVar, Any
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
@ -12,7 +11,6 @@ class AIConfig:
""" """
The base class of all AI configurations. The base class of all AI configurations.
""" """
ai_type: str
name: str name: str
@ -21,18 +19,13 @@ class OpenAIConfig(AIConfig):
""" """
The OpenAI section of the configuration file. The OpenAI section of the configuration file.
""" """
# all members have default values, so we can easily create api_key: str
# a default configuration model: str
ai_type: str = 'openai' temperature: float
name: str = 'openai_1' max_tokens: int
system: str = 'You are an assistant' top_p: float
api_key: str = '0123456789' frequency_penalty: float
model: str = 'gpt-3.5' presence_penalty: float
temperature: float = 1.0
max_tokens: int = 4000
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
@classmethod @classmethod
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
@ -40,9 +33,7 @@ class OpenAIConfig(AIConfig):
Create OpenAIConfig from a dict. Create OpenAIConfig from a dict.
""" """
return cls( return cls(
ai_type='openai', name='OpenAI',
name=str(source['name']),
system=str(source['system']),
api_key=str(source['api_key']), api_key=str(source['api_key']),
model=str(source['model']), model=str(source['model']),
max_tokens=int(source['max_tokens']), max_tokens=int(source['max_tokens']),
@ -52,24 +43,15 @@ class OpenAIConfig(AIConfig):
presence_penalty=float(source['presence_penalty']) presence_penalty=float(source['presence_penalty'])
) )
def as_dict(self) -> dict[str, Any]:
return asdict(self)
def create_default_ai_configs() -> dict[str, AIConfig]:
openai_conf = OpenAIConfig()
return {openai_conf.name: openai_conf}
@dataclass @dataclass
class Config: class Config:
""" """
The configuration file structure. The configuration file structure.
""" """
# all members have default values, so we can easily create system: str
# a default configuration db: str
db: str = './db/' openai: OpenAIConfig
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
@classmethod @classmethod
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
@ -77,34 +59,20 @@ class Config:
Create Config from a dict. Create Config from a dict.
""" """
return cls( return cls(
system=str(source['system']),
db=str(source['db']), db=str(source['db']),
ais=source['ais'] # FIXME: call correct constructors openai=OpenAIConfig.from_dict(source['openai'])
) )
@classmethod
def create_default(self, file_path: Path) -> None:
"""
Creates a default Config in the given file.
"""
conf = Config()
conf.to_file(file_path)
@classmethod @classmethod
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI name to the config (for easy internal access)
for ai_name, ai_conf in source['ais'].items():
ai_conf['name'] = ai_name
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, file_path: Path) -> None: def to_file(self, path: str) -> None:
# remove the AI name from the config (for a cleaner format) with open(path, 'w') as f:
data = self.as_dict() yaml.dump(asdict(self), f, sort_keys=False)
for ai_name, ai_conf in data['ais'].items():
del (ai_conf['name'])
with open(file_path, 'w') as f:
yaml.dump(data, f, sort_keys=False)
def as_dict(self) -> dict[str, Any]: def as_dict(self) -> dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -1,14 +1,14 @@
import unittest import unittest
# import io import io
# import pathlib import pathlib
# import argparse import argparse
# from chatmastermind.utils import terminal_width from chatmastermind.utils import terminal_width
# from chatmastermind.main import create_parser, ask_cmd from chatmastermind.main import create_parser, ask_cmd
# from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.configuration import Config from chatmastermind.configuration import Config
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
# from unittest import mock from unittest import mock
# from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase): class CmmTestCase(unittest.TestCase):
@ -30,207 +30,207 @@ class CmmTestCase(unittest.TestCase):
'frequency_penalty': 0, 'frequency_penalty': 0,
'presence_penalty': 0}} 'presence_penalty': 0}}
) )
#
#
# class TestCreateChat(CmmTestCase): class TestCreateChat(CmmTestCase):
#
# def setUp(self) -> None: def setUp(self) -> None:
# self.config = self.dummy_config(db='test_files') self.config = self.dummy_config(db='test_files')
# self.question = "test question" self.question = "test question"
# self.tags = ['test_tag'] self.tags = ['test_tag']
#
# @patch('os.listdir') @patch('os.listdir')
# @patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
# @patch('builtins.open') @patch('builtins.open')
# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt'] listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer', {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['test_tag']})) 'tags': ['test_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config) test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 4) self.assertEqual(len(test_chat), 4)
# self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system}) {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'}) {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2], self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'}) {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3], self.assertEqual(test_chat[3],
# {'role': 'user', 'content': self.question}) {'role': 'user', 'content': self.question})
#
# @patch('os.listdir') @patch('os.listdir')
# @patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
# @patch('builtins.open') @patch('builtins.open')
# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt'] listdir_mock.return_value = ['testfile.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
# {'question': 'test_content', 'answer': 'some answer', {'question': 'test_content', 'answer': 'some answer',
# 'tags': ['other_tag']})) 'tags': ['other_tag']}))
#
# test_chat = create_chat_hist(self.question, self.tags, None, self.config) test_chat = create_chat_hist(self.question, self.tags, None, self.config)
#
# self.assertEqual(len(test_chat), 2) self.assertEqual(len(test_chat), 2)
# self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system}) {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
# {'role': 'user', 'content': self.question}) {'role': 'user', 'content': self.question})
#
# @patch('os.listdir') @patch('os.listdir')
# @patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
# @patch('builtins.open') @patch('builtins.open')
# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
# open_mock.side_effect = ( open_mock.side_effect = (
# io.StringIO(dump_data({'question': 'test_content', io.StringIO(dump_data({'question': 'test_content',
# 'answer': 'some answer', 'answer': 'some answer',
# 'tags': ['test_tag']})), 'tags': ['test_tag']})),
# io.StringIO(dump_data({'question': 'test_content2', io.StringIO(dump_data({'question': 'test_content2',
# 'answer': 'some answer2', 'answer': 'some answer2',
# 'tags': ['test_tag2']})), 'tags': ['test_tag2']})),
# ) )
#
# test_chat = create_chat_hist(self.question, [], None, self.config) test_chat = create_chat_hist(self.question, [], None, self.config)
#
# self.assertEqual(len(test_chat), 6) self.assertEqual(len(test_chat), 6)
# self.assertEqual(test_chat[0], self.assertEqual(test_chat[0],
# {'role': 'system', 'content': self.config.system}) {'role': 'system', 'content': self.config.system})
# self.assertEqual(test_chat[1], self.assertEqual(test_chat[1],
# {'role': 'user', 'content': 'test_content'}) {'role': 'user', 'content': 'test_content'})
# self.assertEqual(test_chat[2], self.assertEqual(test_chat[2],
# {'role': 'assistant', 'content': 'some answer'}) {'role': 'assistant', 'content': 'some answer'})
# self.assertEqual(test_chat[3], self.assertEqual(test_chat[3],
# {'role': 'user', 'content': 'test_content2'}) {'role': 'user', 'content': 'test_content2'})
# self.assertEqual(test_chat[4], self.assertEqual(test_chat[4],
# {'role': 'assistant', 'content': 'some answer2'}) {'role': 'assistant', 'content': 'some answer2'})
#
#
# class TestHandleQuestion(CmmTestCase): class TestHandleQuestion(CmmTestCase):
#
# def setUp(self) -> None: def setUp(self) -> None:
# self.question = "test question" self.question = "test question"
# self.args = argparse.Namespace( self.args = argparse.Namespace(
# or_tags=['tag1'], or_tags=['tag1'],
# and_tags=None, and_tags=None,
# exclude_tags=['xtag1'], exclude_tags=['xtag1'],
# output_tags=None, output_tags=None,
# question=[self.question], question=[self.question],
# source=None, source=None,
# source_code_only=False, source_code_only=False,
# num_answers=3, num_answers=3,
# max_tokens=None, max_tokens=None,
# temperature=None, temperature=None,
# model=None, model=None,
# match_all_tags=False, match_all_tags=False,
# with_tags=False, with_tags=False,
# with_file=False, with_file=False,
# ) )
# self.config = self.dummy_config(db='test_files') self.config = self.dummy_config(db='test_files')
#
# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
# @patch("chatmastermind.main.print_tag_args") @patch("chatmastermind.main.print_tag_args")
# @patch("chatmastermind.main.print_chat_hist") @patch("chatmastermind.main.print_chat_hist")
# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
# @patch("chatmastermind.utils.pp") @patch("chatmastermind.utils.pp")
# @patch("builtins.print") @patch("builtins.print")
# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
# mock_create_chat_hist: MagicMock) -> None: mock_create_chat_hist: MagicMock) -> None:
# open_mock = MagicMock() open_mock = MagicMock()
# with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
# ask_cmd(self.args, self.config) ask_cmd(self.args, self.config)
# mock_print_tag_args.assert_called_once_with(self.args.or_tags, mock_print_tag_args.assert_called_once_with(self.args.or_tags,
# self.args.exclude_tags, self.args.exclude_tags,
# []) [])
# mock_create_chat_hist.assert_called_once_with(self.question, mock_create_chat_hist.assert_called_once_with(self.question,
# self.args.or_tags, self.args.or_tags,
# self.args.exclude_tags, self.args.exclude_tags,
# self.config, self.config,
# match_all_tags=False, match_all_tags=False,
# with_tags=False, with_tags=False,
# with_file=False) with_file=False)
# mock_print_chat_hist.assert_called_once_with('test_chat', mock_print_chat_hist.assert_called_once_with('test_chat',
# False, False,
# self.args.source_code_only) self.args.source_code_only)
# mock_ai.assert_called_with("test_chat", mock_ai.assert_called_with("test_chat",
# self.config, self.config,
# self.args.num_answers) self.args.num_answers)
# expected_calls = [] expected_calls = []
# for num, answer in enumerate(mock_ai.return_value[0], start=1): for num, answer in enumerate(mock_ai.return_value[0], start=1):
# title = f'-- ANSWER {num} ' title = f'-- ANSWER {num} '
# title_end = '-' * (terminal_width() - len(title)) title_end = '-' * (terminal_width() - len(title))
# expected_calls.append(((f'{title}{title_end}',),)) expected_calls.append(((f'{title}{title_end}',),))
# expected_calls.append(((answer,),)) expected_calls.append(((answer,),))
# expected_calls.append((("-" * terminal_width(),),)) expected_calls.append((("-" * terminal_width(),),))
# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
# self.assertEqual(mock_print.call_args_list, expected_calls) self.assertEqual(mock_print.call_args_list, expected_calls)
# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
# open_mock.assert_has_calls(open_expected_calls, any_order=True) open_mock.assert_has_calls(open_expected_calls, any_order=True)
#
#
# class TestSaveAnswers(CmmTestCase): class TestSaveAnswers(CmmTestCase):
# @mock.patch('builtins.open') @mock.patch('builtins.open')
# @mock.patch('chatmastermind.storage.print') @mock.patch('chatmastermind.storage.print')
# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
# question = "Test question?" question = "Test question?"
# answers = ["Answer 1", "Answer 2"] answers = ["Answer 1", "Answer 2"]
# tags = ["tag1", "tag2"] tags = ["tag1", "tag2"]
# otags = ["otag1", "otag2"] otags = ["otag1", "otag2"]
# config = self.dummy_config(db='test_db') config = self.dummy_config(db='test_db')
#
# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
# mock.patch('chatmastermind.storage.yaml.dump'), \ mock.patch('chatmastermind.storage.yaml.dump'), \
# mock.patch('io.StringIO') as stringio_mock: mock.patch('io.StringIO') as stringio_mock:
# stringio_instance = stringio_mock.return_value stringio_instance = stringio_mock.return_value
# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
# save_answers(question, answers, tags, otags, config) save_answers(question, answers, tags, otags, config)
#
# open_calls = [ open_calls = [
# mock.call(pathlib.Path('test_db/.next'), 'r'), mock.call(pathlib.Path('test_db/.next'), 'r'),
# mock.call(pathlib.Path('test_db/.next'), 'w'), mock.call(pathlib.Path('test_db/.next'), 'w'),
# ] ]
# open_mock.assert_has_calls(open_calls, any_order=True) open_mock.assert_has_calls(open_calls, any_order=True)
#
#
# class TestAI(CmmTestCase): class TestAI(CmmTestCase):
#
# @patch("openai.ChatCompletion.create") @patch("openai.ChatCompletion.create")
# def test_ai(self, mock_create: MagicMock) -> None: def test_ai(self, mock_create: MagicMock) -> None:
# mock_create.return_value = { mock_create.return_value = {
# 'choices': [ 'choices': [
# {'message': {'content': 'response_text_1'}}, {'message': {'content': 'response_text_1'}},
# {'message': {'content': 'response_text_2'}} {'message': {'content': 'response_text_2'}}
# ], ],
# 'usage': {'tokens': 10} 'usage': {'tokens': 10}
# } }
#
# chat = [{"role": "system", "content": "hello ai"}] chat = [{"role": "system", "content": "hello ai"}]
# config = self.dummy_config(db='dummy') config = self.dummy_config(db='dummy')
# config.openai.model = "text-davinci-002" config.openai.model = "text-davinci-002"
# config.openai.max_tokens = 150 config.openai.max_tokens = 150
# config.openai.temperature = 0.5 config.openai.temperature = 0.5
#
# result = ai(chat, config, 2) result = ai(chat, config, 2)
# expected_result = (['response_text_1', 'response_text_2'], expected_result = (['response_text_1', 'response_text_2'],
# {'tokens': 10}) {'tokens': 10})
# self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
#
#
# class TestCreateParser(CmmTestCase): class TestCreateParser(CmmTestCase):
# def test_create_parser(self) -> None: def test_create_parser(self) -> None:
# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
# mock_cmdparser = Mock() mock_cmdparser = Mock()
# mock_add_subparsers.return_value = mock_cmdparser mock_add_subparsers.return_value = mock_cmdparser
# parser = create_parser() parser = create_parser()
# self.assertIsInstance(parser, argparse.ArgumentParser) self.assertIsInstance(parser, argparse.ArgumentParser)
# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True)
# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY)
# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY)
# self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertTrue('.config.yaml' in parser.get_default('config'))