Compare commits

..

3 Commits

3 changed files with 57 additions and 6 deletions

View File

@ -15,7 +15,10 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:
and configuration file. and configuration file.
""" """
if args.ai: if args.ai:
ai_conf = config.ais[args.ai] try:
ai_conf = config.ais[args.ai]
except KeyError:
raise AIError(f"AI ID '{args.ai}' does not exist in this configuration")
elif default_ai_ID in config.ais: elif default_ai_ID in config.ais:
ai_conf = config.ais[default_ai_ID] ai_conf = config.ais[default_ai_ID]
else: else:
@ -23,12 +26,12 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:
if ai_conf.name == 'openai': if ai_conf.name == 'openai':
ai = OpenAI(cast(OpenAIConfig, ai_conf)) ai = OpenAI(cast(OpenAIConfig, ai_conf))
if args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.config.temperature:
ai.config.temperature = args.temperature
if args.model: if args.model:
ai.config.model = args.model ai.config.model = args.model
if args.max_tokens:
ai.config.max_tokens = args.max_tokens
if args.temperature:
ai.config.temperature = args.temperature
return ai return ai
else: else:
raise AIError(f"AI '{args.ai}' is not supported") raise AIError(f"AI '{args.ai}' is not supported")

View File

@ -49,7 +49,7 @@ def create_parser() -> argparse.ArgumentParser:
# a parent parser for all commands that support AI configuration # a parent parser for all commands that support AI configuration
ai_parser = argparse.ArgumentParser(add_help=False) ai_parser = argparse.ArgumentParser(add_help=False)
ai_parser.add_argument('-A', '--AI', help='AI to use') ai_parser.add_argument('-A', '--AI', help='AI ID to use')
ai_parser.add_argument('-M', '--model', help='Model to use') ai_parser.add_argument('-M', '--model', help='Model to use')
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
ai_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ai_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)

48
tests/test_ai_factory.py Normal file
View File

@ -0,0 +1,48 @@
import argparse
import unittest
from unittest.mock import MagicMock
from chatmastermind.ai_factory import create_ai
from chatmastermind.configuration import Config
from chatmastermind.ai import AIError
from chatmastermind.ais.openai_cmm import OpenAI
class TestCreateAI(unittest.TestCase):
def setUp(self) -> None:
self.args = MagicMock(spec=argparse.Namespace)
self.args.ai = 'default'
self.args.model = None
self.args.max_tokens = None
self.args.temperature = None
def test_create_ai_from_args(self) -> None:
# Create an AI with the default configuration
config = Config()
self.args.ai = 'default'
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_ai_from_default(self) -> None:
self.args.ai = None
# Create an AI with the default configuration
config = Config()
ai = create_ai(self.args, config)
self.assertIsInstance(ai, OpenAI)
def test_create_empty_ai_error(self) -> None:
self.args.ai = None
# Create Config with empty AIs
config = Config()
config.ais = {}
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)
def test_create_unsupported_ai_error(self) -> None:
# Mock argparse.Namespace with ai='invalid_ai'
self.args.ai = 'invalid_ai'
# Create default Config
config = Config()
# Call create_ai function and assert that it raises AIError
with self.assertRaises(AIError):
create_ai(self.args, config)