Compare commits

..

2 Commits

5 changed files with 78 additions and 50 deletions

View File

@ -1,12 +1,20 @@
import yaml import yaml
from pathlib import Path from pathlib import Path
from typing import Type, TypeVar, Any from typing import Type, TypeVar, Any, Optional
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config') ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
class ConfigError(Exception):
pass
@dataclass @dataclass
class AIConfig: class AIConfig:
""" """
@ -56,9 +64,24 @@ class OpenAIConfig(AIConfig):
return asdict(self) return asdict(self)
def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given type.
"""
if ai_type.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"AI type '{ai_type}' is not supported")
def create_default_ai_configs() -> dict[str, AIConfig]: def create_default_ai_configs() -> dict[str, AIConfig]:
openai_conf = OpenAIConfig() """
return {openai_conf.name: openai_conf} Create a dict containing default configurations for all supported AIs.
"""
return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais}
@dataclass @dataclass
@ -76,9 +99,14 @@ class Config:
""" """
Create Config from a dict. Create Config from a dict.
""" """
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for name, conf in source['ais'].items():
ai_conf = ai_type_instance(conf['type'], conf)
ais[name] = ai_conf
return cls( return cls(
db=str(source['db']), db=str(source['db']),
ais=source['ais'] # FIXME: call correct constructors ais=ais
) )
@classmethod @classmethod
@ -94,8 +122,8 @@ class Config:
with open(path, 'r') as f: with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader) source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI name to the config (for easy internal access) # add the AI name to the config (for easy internal access)
for ai_name, ai_conf in source['ais'].items(): for name, conf in source['ais'].items():
ai_conf['name'] = ai_name conf['name'] = name
return cls.from_dict(source) return cls.from_dict(source)
def to_file(self, file_path: Path) -> None: def to_file(self, file_path: Path) -> None:

View File

@ -1,3 +1,4 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
import time import time
@ -6,10 +7,9 @@ from unittest.mock import patch
from chatmastermind.tags import TagLine from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from .test_main import CmmTestCase
class TestChat(CmmTestCase): class TestChat(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
@ -131,7 +131,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output) self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(CmmTestCase): class TestChatDB(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory() self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory()

View File

@ -1,35 +1,35 @@
import unittest # import unittest
# import io # import io
# import pathlib # import pathlib
# import argparse # import argparse
# from chatmastermind.utils import terminal_width # from chatmastermind.utils import terminal_width
# from chatmastermind.main import create_parser, ask_cmd # from chatmastermind.main import create_parser, ask_cmd
# from chatmastermind.api_client import ai # from chatmastermind.api_client import ai
from chatmastermind.configuration import Config # from chatmastermind.configuration import Config
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data # from chatmastermind.storage import create_chat_hist, save_answers, dump_data
# from unittest import mock # from unittest import mock
# from unittest.mock import patch, MagicMock, Mock, ANY # from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase): # class CmmTestCase(unittest.TestCase):
""" # """
Base class for all cmm testcases. # Base class for all cmm testcases.
""" # """
def dummy_config(self, db: str) -> Config: # def dummy_config(self, db: str) -> Config:
""" # """
Creates a dummy configuration. # Creates a dummy configuration.
""" # """
return Config.from_dict( # return Config.from_dict(
{'system': 'dummy_system', # {'system': 'dummy_system',
'db': db, # 'db': db,
'openai': {'api_key': 'dummy_key', # 'openai': {'api_key': 'dummy_key',
'model': 'dummy_model', # 'model': 'dummy_model',
'max_tokens': 4000, # 'max_tokens': 4000,
'temperature': 1.0, # 'temperature': 1.0,
'top_p': 1, # 'top_p': 1,
'frequency_penalty': 0, # 'frequency_penalty': 0,
'presence_penalty': 0}} # 'presence_penalty': 0}}
) # )
# #
# #
# class TestCreateChat(CmmTestCase): # class TestCreateChat(CmmTestCase):

View File

@ -1,12 +1,12 @@
import unittest
import pathlib import pathlib
import tempfile import tempfile
from typing import cast from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(CmmTestCase): class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None: def test_source_code_with_include_delims(self) -> None:
text = """ text = """
Some text before the code block Some text before the code block
@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase):
self.assertEqual(result, expected_result) self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase): class QuestionTestCase(unittest.TestCase):
def test_question_with_header(self) -> None: def test_question_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?") Question(f"{Question.txt_header}\nWhat is your name?")
@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase):
self.assertEqual(question, "What is your favorite color?") self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase): class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno") Answer(f"{Answer.txt_header}\nno")
@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase):
self.assertEqual(answer, "No") self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(CmmTestCase): class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -160,7 +160,7 @@ This is a question.
self.message_complete.file_path = self.file_path self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(CmmTestCase): class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase):
self.assertEqual(content, expected_content) self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(CmmTestCase): class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -388,7 +388,7 @@ This is a question.
self.assertIsNone(message) self.assertIsNone(message)
class MessageFromFileYamlTestCase(CmmTestCase): class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertIsNone(message) self.assertIsNone(message)
class TagsFromFileTestCase(CmmTestCase): class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name) self.file_path_txt = pathlib.Path(self.file_txt.name)
@ -663,7 +663,7 @@ This is an answer.
self.assertSetEqual(tags, set()) self.assertSetEqual(tags, set())
class TagsFromDirTestCase(CmmTestCase): class TagsFromDirTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory()
@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase):
self.assertSetEqual(all_tags, set()) self.assertSetEqual(all_tags, set())
class MessageIDTestCase(CmmTestCase): class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name) self.file_path = pathlib.Path(self.file.name)
@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase):
self.message_no_file_path.msg_id() self.message_no_file_path.msg_id()
class MessageHashTestCase(CmmTestCase): class MessageHashTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase):
self.assertIn(msg, msgs) self.assertIn(msg, msgs)
class MessageTagsStrTestCase(CmmTestCase): class MessageTagsStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('tag1')}, tags={Tag('tag1')},
@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase):
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(CmmTestCase): class MessageFilterTagsTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase):
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(CmmTestCase): class MessageInTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'), self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase):
self.assertFalse(message_in(self.message1, [self.message2])) self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(CmmTestCase): class MessageRenameTagsTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')}, tags={Tag('atag1'), Tag('btag2')},
@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase):
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(CmmTestCase): class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.message = Message(Question('This is a question.'), self.message = Message(Question('This is a question.'),
Answer('This is an answer.'), Answer('This is an answer.'),

View File

@ -1,8 +1,8 @@
from .test_main import CmmTestCase import unittest
from chatmastermind.tags import Tag, TagLine, TagError from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(CmmTestCase): class TestTag(unittest.TestCase):
def test_valid_tag(self) -> None: def test_valid_tag(self) -> None:
tag = Tag('mytag') tag = Tag('mytag')
self.assertEqual(tag, 'mytag') self.assertEqual(tag, 'mytag')
@ -18,7 +18,7 @@ class TestTag(CmmTestCase):
self.assertEqual(Tag.alternative_separators, [',']) self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(CmmTestCase): class TestTagLine(unittest.TestCase):
def test_valid_tagline(self) -> None: def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2') tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2')