Compare commits

..

2 Commits

5 changed files with 78 additions and 50 deletions

View File

@ -1,12 +1,20 @@
import yaml
from pathlib import Path
from typing import Type, TypeVar, Any
from typing import Type, TypeVar, Any, Optional
from dataclasses import dataclass, asdict, field
ConfigInst = TypeVar('ConfigInst', bound='Config')
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
supported_ais: list[str] = ['openai']
class ConfigError(Exception):
pass
@dataclass
class AIConfig:
"""
@ -56,9 +64,24 @@ class OpenAIConfig(AIConfig):
return asdict(self)
def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
"""
Creates an AIConfig instance of the given type.
"""
if ai_type.lower() == 'openai':
if conf_dict is None:
return OpenAIConfig()
else:
return OpenAIConfig.from_dict(conf_dict)
else:
raise ConfigError(f"AI type '{ai_type}' is not supported")
def create_default_ai_configs() -> dict[str, AIConfig]:
openai_conf = OpenAIConfig()
return {openai_conf.name: openai_conf}
"""
Create a dict containing default configurations for all supported AIs.
"""
return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais}
@dataclass
@ -76,9 +99,14 @@ class Config:
"""
Create Config from a dict.
"""
# create the correct AI type instances
ais: dict[str, AIConfig] = {}
for name, conf in source['ais'].items():
ai_conf = ai_type_instance(conf['type'], conf)
ais[name] = ai_conf
return cls(
db=str(source['db']),
ais=source['ais'] # FIXME: call correct constructors
ais=ais
)
@classmethod
@ -94,8 +122,8 @@ class Config:
with open(path, 'r') as f:
source = yaml.load(f, Loader=yaml.FullLoader)
# add the AI name to the config (for easy internal access)
for ai_name, ai_conf in source['ais'].items():
ai_conf['name'] = ai_name
for name, conf in source['ais'].items():
conf['name'] = name
return cls.from_dict(source)
def to_file(self, file_path: Path) -> None:

View File

@ -1,3 +1,4 @@
import unittest
import pathlib
import tempfile
import time
@ -6,10 +7,9 @@ from unittest.mock import patch
from chatmastermind.tags import TagLine
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
from .test_main import CmmTestCase
class TestChat(CmmTestCase):
class TestChat(unittest.TestCase):
def setUp(self) -> None:
self.chat = Chat([])
self.message1 = Message(Question('Question 1'),
@ -131,7 +131,7 @@ Answer 2
self.assertEqual(mock_stdout.getvalue(), expected_output)
class TestChatDB(CmmTestCase):
class TestChatDB(unittest.TestCase):
def setUp(self) -> None:
self.db_path = tempfile.TemporaryDirectory()
self.cache_path = tempfile.TemporaryDirectory()

View File

@ -1,35 +1,35 @@
import unittest
# import unittest
# import io
# import pathlib
# import argparse
# from chatmastermind.utils import terminal_width
# from chatmastermind.main import create_parser, ask_cmd
# from chatmastermind.api_client import ai
from chatmastermind.configuration import Config
# from chatmastermind.configuration import Config
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
# from unittest import mock
# from unittest.mock import patch, MagicMock, Mock, ANY
class CmmTestCase(unittest.TestCase):
"""
Base class for all cmm testcases.
"""
def dummy_config(self, db: str) -> Config:
"""
Creates a dummy configuration.
"""
return Config.from_dict(
{'system': 'dummy_system',
'db': db,
'openai': {'api_key': 'dummy_key',
'model': 'dummy_model',
'max_tokens': 4000,
'temperature': 1.0,
'top_p': 1,
'frequency_penalty': 0,
'presence_penalty': 0}}
)
# class CmmTestCase(unittest.TestCase):
# """
# Base class for all cmm testcases.
# """
# def dummy_config(self, db: str) -> Config:
# """
# Creates a dummy configuration.
# """
# return Config.from_dict(
# {'system': 'dummy_system',
# 'db': db,
# 'openai': {'api_key': 'dummy_key',
# 'model': 'dummy_model',
# 'max_tokens': 4000,
# 'temperature': 1.0,
# 'top_p': 1,
# 'frequency_penalty': 0,
# 'presence_penalty': 0}}
# )
#
#
# class TestCreateChat(CmmTestCase):

View File

@ -1,12 +1,12 @@
import unittest
import pathlib
import tempfile
from typing import cast
from .test_main import CmmTestCase
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine
class SourceCodeTestCase(CmmTestCase):
class SourceCodeTestCase(unittest.TestCase):
def test_source_code_with_include_delims(self) -> None:
text = """
Some text before the code block
@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase):
self.assertEqual(result, expected_result)
class QuestionTestCase(CmmTestCase):
class QuestionTestCase(unittest.TestCase):
def test_question_with_header(self) -> None:
with self.assertRaises(MessageError):
Question(f"{Question.txt_header}\nWhat is your name?")
@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase):
self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase):
class AnswerTestCase(unittest.TestCase):
def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError):
Answer(f"{Answer.txt_header}\nno")
@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase):
self.assertEqual(answer, "No")
class MessageToFileTxtTestCase(CmmTestCase):
class MessageToFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
@ -160,7 +160,7 @@ This is a question.
self.message_complete.file_path = self.file_path
class MessageToFileYamlTestCase(CmmTestCase):
class MessageToFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase):
self.assertEqual(content, expected_content)
class MessageFromFileTxtTestCase(CmmTestCase):
class MessageFromFileTxtTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
@ -388,7 +388,7 @@ This is a question.
self.assertIsNone(message)
class MessageFromFileYamlTestCase(CmmTestCase):
class MessageFromFileYamlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
self.file_path = pathlib.Path(self.file.name)
@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase):
self.assertIsNone(message)
class TagsFromFileTestCase(CmmTestCase):
class TagsFromFileTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path_txt = pathlib.Path(self.file_txt.name)
@ -663,7 +663,7 @@ This is an answer.
self.assertSetEqual(tags, set())
class TagsFromDirTestCase(CmmTestCase):
class TagsFromDirTestCase(unittest.TestCase):
def setUp(self) -> None:
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase):
self.assertSetEqual(all_tags, set())
class MessageIDTestCase(CmmTestCase):
class MessageIDTestCase(unittest.TestCase):
def setUp(self) -> None:
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
self.file_path = pathlib.Path(self.file.name)
@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase):
self.message_no_file_path.msg_id()
class MessageHashTestCase(CmmTestCase):
class MessageHashTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('tag1')},
@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase):
self.assertIn(msg, msgs)
class MessageTagsStrTestCase(CmmTestCase):
class MessageTagsStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('tag1')},
@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase):
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
class MessageFilterTagsTestCase(CmmTestCase):
class MessageFilterTagsTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase):
self.assertSetEqual(tags_cont, {Tag('btag2')})
class MessageInTestCase(CmmTestCase):
class MessageInTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message1 = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase):
self.assertFalse(message_in(self.message1, [self.message2]))
class MessageRenameTagsTestCase(CmmTestCase):
class MessageRenameTagsTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
tags={Tag('atag1'), Tag('btag2')},
@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase):
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
class MessageToStrTestCase(CmmTestCase):
class MessageToStrTestCase(unittest.TestCase):
def setUp(self) -> None:
self.message = Message(Question('This is a question.'),
Answer('This is an answer.'),

View File

@ -1,8 +1,8 @@
from .test_main import CmmTestCase
import unittest
from chatmastermind.tags import Tag, TagLine, TagError
class TestTag(CmmTestCase):
class TestTag(unittest.TestCase):
def test_valid_tag(self) -> None:
tag = Tag('mytag')
self.assertEqual(tag, 'mytag')
@ -18,7 +18,7 @@ class TestTag(CmmTestCase):
self.assertEqual(Tag.alternative_separators, [','])
class TestTagLine(CmmTestCase):
class TestTagLine(unittest.TestCase):
def test_valid_tagline(self) -> None:
tagline = TagLine('TAGS: tag1 tag2')
self.assertEqual(tagline, 'TAGS: tag1 tag2')