Better type annotations and stricter mypy checks (#3) #4

Merged
ok merged 6 commits from mypy_fixes into main 2023-08-16 13:01:27 +02:00
7 changed files with 166 additions and 84 deletions

View File

@ -1,11 +1,17 @@
import openai import openai
from .utils import ChatType
from .configuration import Config
def openai_api_key(api_key: str) -> None: def openai_api_key(api_key: str) -> None:
openai.api_key = api_key openai.api_key = api_key
def print_models() -> None: def print_models() -> None:
"""
Print all models supported by the current AI.
"""
not_ready = [] not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']: if engine['ready']:
@ -16,10 +22,16 @@ def print_models() -> None:
print('\nNot ready: ' + ', '.join(not_ready)) print('\nNot ready: ' + ', '.join(not_ready))
def ai(chat: list[dict[str, str]], def ai(chat: ChatType,
config: dict, config: Config,
number: int number: int
) -> tuple[list[str], dict[str, int]]: ) -> tuple[list[str], dict[str, int]]:
"""
Make AI request with the given chat history and configuration.
Return AI response and tokens used.
"""
if not isinstance(config['openai'], dict):
raise RuntimeError('Configuration openai is not a dict.')
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model=config['openai']['model'], model=config['openai']['model'],
messages=chat, messages=chat,

View File

@ -0,0 +1,63 @@
import pathlib
from typing import TypedDict, Any, Union
class OpenAIConfig(TypedDict):
    """
    Typed layout of the 'openai' section of the configuration file.

    All keys are required (the TypedDict is total by default).
    """
    # Credentials and model selection.
    api_key: str
    model: str
    # Numeric request settings; presumably forwarded to the OpenAI
    # request as-is -- TODO confirm against the API client.
    temperature: float
    max_tokens: int
    top_p: float
    frequency_penalty: float
    presence_penalty: float
def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool:
    """
    Tell whether an OpenAI configuration section is usable.

    Every required key must be present and its value must be convertible
    to the expected type. On the first failure the reason is printed and
    False is returned; otherwise True is returned.
    """
    # (key, converter) pairs, tried in this order so that the reported
    # error always belongs to the first offending entry.
    required = (
        ('api_key', str),
        ('model', str),
        ('max_tokens', int),
        ('temperature', float),
        ('top_p', float),
        ('frequency_penalty', float),
        ('presence_penalty', float),
    )
    try:
        for key, convert in required:
            convert(conf[key])
    except Exception as e:
        print(f"OpenAI configuration is invalid: {e}")
        return False
    return True
class Config(TypedDict):
    """
    Typed layout of the whole configuration file.

    All keys are required (the TypedDict is total by default).
    """
    # Top-level string settings.
    system: str
    db: str
    # Nested section holding the OpenAI-specific settings.
    openai: OpenAIConfig
def config_valid(conf: dict[str, Any]) -> bool:
    """
    Check that the given configuration dict is complete and contains
    valid types and values.

    The top-level entries 'system' and 'db' must be present, and the
    'openai' section must exist and pass 'openai_config_valid'.
    A reason is printed for every failure.

    Returns True if the configuration is valid, False otherwise.
    """
    try:
        str(conf['system'])
        pathlib.Path(str(conf['db']))
    except Exception as e:
        # This validator reports problems instead of raising, so anything
        # going wrong while reading the top-level keys counts as "invalid".
        print(f"Configuration is invalid: {e}")
        return False
    # NOTE: there must be no early 'return True' above -- it would make the
    # check of the 'openai' section below unreachable.
    if 'openai' in conf:
        return openai_config_valid(conf['openai'])
    # required as long as we only support OpenAI
    print("Section 'openai' is missing in the configuration!")
    return False

View File

@ -7,23 +7,25 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config
from itertools import zip_longest from itertools import zip_longest
from typing import Any
default_config = '.config.yaml' default_config = '.config.yaml'
def tags_completer(prefix, parsed_args, **kwargs): def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
with open(parsed_args.config, 'r') as f: with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags_unique(config, prefix) return get_tags_unique(config, prefix)
def create_question_with_hist(args: argparse.Namespace, def create_question_with_hist(args: argparse.Namespace,
config: ConfigType, config: Config,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[ChatType, str, list[str]]:
""" """
Creates the "AI request", including the question and chat history as determined Creates the "AI request", including the question and chat history as determined
by the specified tags. by the specified tags.
@ -55,7 +57,7 @@ def create_question_with_hist(args: argparse.Namespace,
return chat, full_question, tags return chat, full_question, tags
def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: def tag_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
@ -63,13 +65,10 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print_tags_frequency(get_tags(config, None)) print_tags_frequency(get_tags(config, None))
def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: def config_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'config' command. Handler for the 'config' command.
""" """
if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
if args.list_models: if args.list_models:
print_models() print_models()
elif args.print_model: elif args.print_model:
@ -79,19 +78,16 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None:
write_config(args.config, config) write_config(args.config, config)
def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: def ask_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'ask' command. Handler for the 'ask' command.
""" """
if type(config['openai']) is not dict:
raise RuntimeError('Configuration openai is not a dict.')
config_openai = config['openai']
if args.max_tokens: if args.max_tokens:
config_openai['max_tokens'] = args.max_tokens config['openai']['max_tokens'] = args.max_tokens
if args.temperature: if args.temperature:
config_openai['temperature'] = args.temperature config['openai']['temperature'] = args.temperature
if args.model: if args.model:
config_openai['model'] = args.model config['openai']['model'] = args.model
chat, question, tags = create_question_with_hist(args, config) chat, question, tags = create_question_with_hist(args, config)
print_chat_hist(chat, False, args.only_source_code) print_chat_hist(chat, False, args.only_source_code)
otags = args.output_tags or [] otags = args.output_tags or []
@ -101,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print(f"Usage: {usage}") print(f"Usage: {usage}")
def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: def hist_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'hist' command. Handler for the 'hist' command.
""" """
@ -115,7 +111,7 @@ def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None:
print_chat_hist(chat, args.dump, args.only_source_code) print_chat_hist(chat, args.dump, args.only_source_code)
def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: def print_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'print' command. Handler for the 'print' command.
""" """
@ -231,10 +227,7 @@ def main() -> int:
command = parser.parse_args() command = parser.parse_args()
config = read_config(args.config) config = read_config(args.config)
if type(config['openai']) is dict and type(config['openai']['api_key']) is str:
openai_api_key(config['openai']['api_key']) openai_api_key(config['openai']['api_key'])
else:
raise RuntimeError("Configuration openai.api_key is wrong.")
command.func(command, config) command.func(command, config)

View File

@ -1,11 +1,13 @@
import yaml import yaml
import sys
import io import io
import pathlib import pathlib
from .utils import terminal_width, append_message, message_to_chat, ConfigType from .utils import terminal_width, append_message, message_to_chat, ChatType
from typing import List, Dict, Any, Optional from .configuration import Config, config_valid
from typing import Any, Optional
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]:
with open(fname, "r") as fd: with open(fname, "r") as fd:
tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
# also support tags separated by ',' (old format) # also support tags separated by ',' (old format)
@ -22,18 +24,20 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
"file": fname.name} "file": fname.name}
def read_config(path: str) -> ConfigType: def read_config(path: str) -> Config:
with open(path, 'r') as f: with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader) config = yaml.load(f, Loader=yaml.FullLoader)
if not config_valid(config):
sys.exit(1)
return config return config
def write_config(path: str, config: ConfigType) -> None: def write_config(path: str, config: Config) -> None:
with open(path, 'w') as f: with open(path, 'w') as f:
yaml.dump(config, f) yaml.dump(config, f)
def dump_data(data: Dict[str, Any]) -> str: def dump_data(data: dict[str, Any]) -> str:
with io.StringIO() as fd: with io.StringIO() as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -41,7 +45,7 @@ def dump_data(data: Dict[str, Any]) -> str:
return fd.getvalue() return fd.getvalue()
def write_file(fname: str, data: Dict[str, Any]) -> None: def write_file(fname: str, data: dict[str, Any]) -> None:
with open(fname, "w") as fd: with open(fname, "w") as fd:
fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'TAGS: {" ".join(data["tags"])}\n')
fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n')
@ -52,7 +56,7 @@ def save_answers(question: str,
answers: list[str], answers: list[str],
tags: list[str], tags: list[str],
otags: Optional[list[str]], otags: Optional[list[str]],
config: ConfigType config: Config
) -> None: ) -> None:
wtags = otags or tags wtags = otags or tags
num, inum = 0, 0 num, inum = 0, 0
@ -75,14 +79,14 @@ def save_answers(question: str,
def create_chat_hist(question: Optional[str], def create_chat_hist(question: Optional[str],
tags: Optional[List[str]], tags: Optional[list[str]],
extags: Optional[List[str]], extags: Optional[list[str]],
config: ConfigType, config: Config,
match_all_tags: bool = False, match_all_tags: bool = False,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> List[Dict[str, str]]: ) -> ChatType:
chat: List[Dict[str, str]] = [] chat: ChatType = []
append_message(chat, 'system', str(config['system']).strip()) append_message(chat, 'system', str(config['system']).strip())
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -108,7 +112,7 @@ def create_chat_hist(question: Optional[str],
return chat return chat
def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags(config: Config, prefix: Optional[str]) -> list[str]:
result = [] result = []
for file in sorted(pathlib.Path(str(config['db'])).iterdir()): for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
if file.suffix == '.yaml': if file.suffix == '.yaml':
@ -127,5 +131,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
return result return result
def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]:
return list(set(get_tags(config, prefix))) return list(set(get_tags(config, prefix)))

View File

@ -1,14 +1,15 @@
import shutil import shutil
from pprint import PrettyPrinter from pprint import PrettyPrinter
from typing import Any
ConfigType = dict[str, str | dict[str, str | int | float]] ChatType = list[dict[str, str]]
def terminal_width() -> int: def terminal_width() -> int:
return shutil.get_terminal_size().columns return shutil.get_terminal_size().columns
def pp(*args, **kwargs) -> None: def pp(*args: Any, **kwargs: Any) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
@ -30,7 +31,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None
print() print()
def append_message(chat: list[dict[str, str]], def append_message(chat: ChatType,
role: str, role: str,
content: str content: str
) -> None: ) -> None:
@ -38,7 +39,7 @@ def append_message(chat: list[dict[str, str]],
def message_to_chat(message: dict[str, str], def message_to_chat(message: dict[str, str],
chat: list[dict[str, str]], chat: ChatType,
with_tags: bool = False, with_tags: bool = False,
with_file: bool = False with_file: bool = False
) -> None: ) -> None:
@ -61,7 +62,7 @@ def display_source_code(content: str) -> None:
pass pass
def print_chat_hist(chat, dump=False, source_code=False) -> None: def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None:
if dump: if dump:
pp(chat) pp(chat)
return return

View File

@ -5,3 +5,4 @@ strict_optional = True
warn_unused_ignores = False warn_unused_ignores = False
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_configs = True warn_unused_configs = True
disallow_untyped_defs = True

View File

@ -5,25 +5,46 @@ import argparse
from chatmastermind.utils import terminal_width from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, ask_cmd from chatmastermind.main import create_parser, ask_cmd
from chatmastermind.api_client import ai from chatmastermind.api_client import ai
from chatmastermind.configuration import Config, OpenAIConfig
from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.storage import create_chat_hist, save_answers, dump_data
from unittest import mock from unittest import mock
from unittest.mock import patch, MagicMock, Mock, ANY from unittest.mock import patch, MagicMock, Mock, ANY
class TestCreateChat(unittest.TestCase): class CmmTestCase(unittest.TestCase):
"""
Base class for all cmm testcases.
"""
def dummy_config(self, db: str) -> Config:
"""
Creates a dummy configuration.
"""
return Config(
system='dummy_system',
db=db,
openai=OpenAIConfig(
api_key='dummy_key',
model='dummy_model',
max_tokens=4000,
temperature=1.0,
top_p=1,
frequency_penalty=0,
presence_penalty=0
)
)
def setUp(self):
self.config = { class TestCreateChat(CmmTestCase):
'system': 'System text',
'db': 'test_files' def setUp(self) -> None:
} self.config = self.dummy_config(db='test_files')
self.question = "test question" self.question = "test question"
self.tags = ['test_tag'] self.tags = ['test_tag']
@patch('os.listdir') @patch('os.listdir')
@patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
@patch('builtins.open') @patch('builtins.open')
def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock): def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
@ -45,7 +66,7 @@ class TestCreateChat(unittest.TestCase):
@patch('os.listdir') @patch('os.listdir')
@patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
@patch('builtins.open') @patch('builtins.open')
def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock): def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt'] listdir_mock.return_value = ['testfile.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
@ -63,7 +84,7 @@ class TestCreateChat(unittest.TestCase):
@patch('os.listdir') @patch('os.listdir')
@patch('pathlib.Path.iterdir') @patch('pathlib.Path.iterdir')
@patch('builtins.open') @patch('builtins.open')
def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock): def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None:
listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value]
open_mock.side_effect = ( open_mock.side_effect = (
@ -90,9 +111,9 @@ class TestCreateChat(unittest.TestCase):
{'role': 'assistant', 'content': 'some answer2'}) {'role': 'assistant', 'content': 'some answer2'})
class TestHandleQuestion(unittest.TestCase): class TestHandleQuestion(CmmTestCase):
def setUp(self): def setUp(self) -> None:
self.question = "test question" self.question = "test question"
self.args = argparse.Namespace( self.args = argparse.Namespace(
tags=['tag1'], tags=['tag1'],
@ -109,12 +130,7 @@ class TestHandleQuestion(unittest.TestCase):
with_tags=False, with_tags=False,
with_file=False, with_file=False,
) )
self.config = { self.config = self.dummy_config(db='test_files')
'db': 'test_files',
'setting1': 'value1',
'setting2': 'value2',
'openai': {},
}
@patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.create_chat_hist", return_value="test_chat")
@patch("chatmastermind.main.print_tag_args") @patch("chatmastermind.main.print_tag_args")
@ -122,9 +138,9 @@ class TestHandleQuestion(unittest.TestCase):
@patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
@patch("chatmastermind.utils.pp") @patch("chatmastermind.utils.pp")
@patch("builtins.print") @patch("builtins.print")
def test_ask_cmd(self, mock_print, mock_pp, mock_ai, def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock,
mock_print_chat_hist, mock_print_tag_args, mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock,
mock_create_chat_hist): mock_create_chat_hist: MagicMock) -> None:
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
ask_cmd(self.args, self.config) ask_cmd(self.args, self.config)
@ -155,15 +171,15 @@ class TestHandleQuestion(unittest.TestCase):
open_mock.assert_has_calls(open_expected_calls, any_order=True) open_mock.assert_has_calls(open_expected_calls, any_order=True)
class TestSaveAnswers(unittest.TestCase): class TestSaveAnswers(CmmTestCase):
@mock.patch('builtins.open') @mock.patch('builtins.open')
@mock.patch('chatmastermind.storage.print') @mock.patch('chatmastermind.storage.print')
def test_save_answers(self, print_mock, open_mock): def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None:
question = "Test question?" question = "Test question?"
answers = ["Answer 1", "Answer 2"] answers = ["Answer 1", "Answer 2"]
tags = ["tag1", "tag2"] tags = ["tag1", "tag2"]
otags = ["otag1", "otag2"] otags = ["otag1", "otag2"]
config = {'db': 'test_db'} config = self.dummy_config(db='test_db')
with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
mock.patch('chatmastermind.storage.yaml.dump'), \ mock.patch('chatmastermind.storage.yaml.dump'), \
@ -179,10 +195,10 @@ class TestSaveAnswers(unittest.TestCase):
open_mock.assert_has_calls(open_calls, any_order=True) open_mock.assert_has_calls(open_calls, any_order=True)
class TestAI(unittest.TestCase): class TestAI(CmmTestCase):
@patch("openai.ChatCompletion.create") @patch("openai.ChatCompletion.create")
def test_ai(self, mock_create: MagicMock): def test_ai(self, mock_create: MagicMock) -> None:
mock_create.return_value = { mock_create.return_value = {
'choices': [ 'choices': [
{'message': {'content': 'response_text_1'}}, {'message': {'content': 'response_text_1'}},
@ -191,28 +207,20 @@ class TestAI(unittest.TestCase):
'usage': {'tokens': 10} 'usage': {'tokens': 10}
} }
number = 2
chat = [{"role": "system", "content": "hello ai"}] chat = [{"role": "system", "content": "hello ai"}]
config = { config = self.dummy_config(db='dummy')
"openai": { config['openai']['model'] = "text-davinci-002"
"model": "text-davinci-002", config['openai']['max_tokens'] = 150
"temperature": 0.5, config['openai']['temperature'] = 0.5
"max_tokens": 150,
"top_p": 1,
"n": number,
"frequency_penalty": 0,
"presence_penalty": 0
}
}
result = ai(chat, config, number) result = ai(chat, config, 2)
expected_result = (['response_text_1', 'response_text_2'], expected_result = (['response_text_1', 'response_text_2'],
{'tokens': 10}) {'tokens': 10})
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class TestCreateParser(unittest.TestCase): class TestCreateParser(CmmTestCase):
def test_create_parser(self): def test_create_parser(self) -> None:
with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers:
mock_cmdparser = Mock() mock_cmdparser = Mock()
mock_add_subparsers.return_value = mock_cmdparser mock_add_subparsers.return_value = mock_cmdparser