Compare commits
3 Commits
1a4e56391c
...
1b5e84793a
| Author | SHA1 | Date | |
|---|---|---|---|
| 1b5e84793a | |||
| d1124dd0b3 | |||
| c630a956c6 |
@ -15,10 +15,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:
|
|||||||
and configuration file.
|
and configuration file.
|
||||||
"""
|
"""
|
||||||
if args.ai:
|
if args.ai:
|
||||||
try:
|
ai_conf = config.ais[args.ai]
|
||||||
ai_conf = config.ais[args.ai]
|
|
||||||
except KeyError:
|
|
||||||
raise AIError(f"AI ID '{args.ai}' does not exist in this configuration")
|
|
||||||
elif default_ai_ID in config.ais:
|
elif default_ai_ID in config.ais:
|
||||||
ai_conf = config.ais[default_ai_ID]
|
ai_conf = config.ais[default_ai_ID]
|
||||||
else:
|
else:
|
||||||
@ -26,12 +23,12 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI:
|
|||||||
|
|
||||||
if ai_conf.name == 'openai':
|
if ai_conf.name == 'openai':
|
||||||
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
||||||
if args.model:
|
|
||||||
ai.config.model = args.model
|
|
||||||
if args.max_tokens:
|
if args.max_tokens:
|
||||||
ai.config.max_tokens = args.max_tokens
|
ai.config.max_tokens = args.max_tokens
|
||||||
if args.temperature:
|
if args.config.temperature:
|
||||||
ai.config.temperature = args.temperature
|
ai.config.temperature = args.temperature
|
||||||
|
if args.model:
|
||||||
|
ai.config.model = args.model
|
||||||
return ai
|
return ai
|
||||||
else:
|
else:
|
||||||
raise AIError(f"AI '{args.ai}' is not supported")
|
raise AIError(f"AI '{args.ai}' is not supported")
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
# a parent parser for all commands that support AI configuration
|
# a parent parser for all commands that support AI configuration
|
||||||
ai_parser = argparse.ArgumentParser(add_help=False)
|
ai_parser = argparse.ArgumentParser(add_help=False)
|
||||||
ai_parser.add_argument('-A', '--AI', help='AI ID to use')
|
ai_parser.add_argument('-A', '--AI', help='AI to use')
|
||||||
ai_parser.add_argument('-M', '--model', help='Model to use')
|
ai_parser.add_argument('-M', '--model', help='Model to use')
|
||||||
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
|
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
|
||||||
ai_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
ai_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
||||||
|
|||||||
@ -1,48 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from chatmastermind.ai_factory import create_ai
|
|
||||||
from chatmastermind.configuration import Config
|
|
||||||
from chatmastermind.ai import AIError
|
|
||||||
from chatmastermind.ais.openai_cmm import OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateAI(unittest.TestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.args = MagicMock(spec=argparse.Namespace)
|
|
||||||
self.args.ai = 'default'
|
|
||||||
self.args.model = None
|
|
||||||
self.args.max_tokens = None
|
|
||||||
self.args.temperature = None
|
|
||||||
|
|
||||||
def test_create_ai_from_args(self) -> None:
|
|
||||||
# Create an AI with the default configuration
|
|
||||||
config = Config()
|
|
||||||
self.args.ai = 'default'
|
|
||||||
ai = create_ai(self.args, config)
|
|
||||||
self.assertIsInstance(ai, OpenAI)
|
|
||||||
|
|
||||||
def test_create_ai_from_default(self) -> None:
|
|
||||||
self.args.ai = None
|
|
||||||
# Create an AI with the default configuration
|
|
||||||
config = Config()
|
|
||||||
ai = create_ai(self.args, config)
|
|
||||||
self.assertIsInstance(ai, OpenAI)
|
|
||||||
|
|
||||||
def test_create_empty_ai_error(self) -> None:
|
|
||||||
self.args.ai = None
|
|
||||||
# Create Config with empty AIs
|
|
||||||
config = Config()
|
|
||||||
config.ais = {}
|
|
||||||
# Call create_ai function and assert that it raises AIError
|
|
||||||
with self.assertRaises(AIError):
|
|
||||||
create_ai(self.args, config)
|
|
||||||
|
|
||||||
def test_create_unsupported_ai_error(self) -> None:
|
|
||||||
# Mock argparse.Namespace with ai='invalid_ai'
|
|
||||||
self.args.ai = 'invalid_ai'
|
|
||||||
# Create default Config
|
|
||||||
config = Config()
|
|
||||||
# Call create_ai function and assert that it raises AIError
|
|
||||||
with self.assertRaises(AIError):
|
|
||||||
create_ai(self.args, config)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user