Compare commits
2 Commits
13c4127827
...
b2401d57ae
| Author | SHA1 | Date | |
|---|---|---|---|
| b2401d57ae | |||
| 893917e455 |
@ -1,12 +1,20 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar, Any
|
||||
from typing import Type, TypeVar, Any, Optional
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
|
||||
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||
|
||||
|
||||
supported_ais: list[str] = ['openai']
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""
|
||||
@ -56,9 +64,24 @@ class OpenAIConfig(AIConfig):
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
||||
"""
|
||||
Creates an AIConfig instance of the given type.
|
||||
"""
|
||||
if ai_type.lower() == 'openai':
|
||||
if conf_dict is None:
|
||||
return OpenAIConfig()
|
||||
else:
|
||||
return OpenAIConfig.from_dict(conf_dict)
|
||||
else:
|
||||
raise ConfigError(f"AI type '{ai_type}' is not supported")
|
||||
|
||||
|
||||
def create_default_ai_configs() -> dict[str, AIConfig]:
|
||||
openai_conf = OpenAIConfig()
|
||||
return {openai_conf.name: openai_conf}
|
||||
"""
|
||||
Create a dict containing default configurations for all supported AIs.
|
||||
"""
|
||||
return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -76,9 +99,14 @@ class Config:
|
||||
"""
|
||||
Create Config from a dict.
|
||||
"""
|
||||
# create the correct AI type instances
|
||||
ais: dict[str, AIConfig] = {}
|
||||
for name, conf in source['ais'].items():
|
||||
ai_conf = ai_type_instance(conf['type'], conf)
|
||||
ais[name] = ai_conf
|
||||
return cls(
|
||||
db=str(source['db']),
|
||||
ais=source['ais'] # FIXME: call correct constructors
|
||||
ais=ais
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -94,8 +122,8 @@ class Config:
|
||||
with open(path, 'r') as f:
|
||||
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||
# add the AI name to the config (for easy internal access)
|
||||
for ai_name, ai_conf in source['ais'].items():
|
||||
ai_conf['name'] = ai_name
|
||||
for name, conf in source['ais'].items():
|
||||
conf['name'] = name
|
||||
return cls.from_dict(source)
|
||||
|
||||
def to_file(self, file_path: Path) -> None:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
@ -6,10 +7,9 @@ from unittest.mock import patch
|
||||
from chatmastermind.tags import TagLine
|
||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
|
||||
from .test_main import CmmTestCase
|
||||
|
||||
|
||||
class TestChat(CmmTestCase):
|
||||
class TestChat(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.chat = Chat([])
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
@ -131,7 +131,7 @@ Answer 2
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
|
||||
class TestChatDB(CmmTestCase):
|
||||
class TestChatDB(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.db_path = tempfile.TemporaryDirectory()
|
||||
self.cache_path = tempfile.TemporaryDirectory()
|
||||
|
||||
@ -1,35 +1,35 @@
|
||||
import unittest
|
||||
# import unittest
|
||||
# import io
|
||||
# import pathlib
|
||||
# import argparse
|
||||
# from chatmastermind.utils import terminal_width
|
||||
# from chatmastermind.main import create_parser, ask_cmd
|
||||
# from chatmastermind.api_client import ai
|
||||
from chatmastermind.configuration import Config
|
||||
# from chatmastermind.configuration import Config
|
||||
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
||||
# from unittest import mock
|
||||
# from unittest.mock import patch, MagicMock, Mock, ANY
|
||||
|
||||
|
||||
class CmmTestCase(unittest.TestCase):
|
||||
"""
|
||||
Base class for all cmm testcases.
|
||||
"""
|
||||
def dummy_config(self, db: str) -> Config:
|
||||
"""
|
||||
Creates a dummy configuration.
|
||||
"""
|
||||
return Config.from_dict(
|
||||
{'system': 'dummy_system',
|
||||
'db': db,
|
||||
'openai': {'api_key': 'dummy_key',
|
||||
'model': 'dummy_model',
|
||||
'max_tokens': 4000,
|
||||
'temperature': 1.0,
|
||||
'top_p': 1,
|
||||
'frequency_penalty': 0,
|
||||
'presence_penalty': 0}}
|
||||
)
|
||||
# class CmmTestCase(unittest.TestCase):
|
||||
# """
|
||||
# Base class for all cmm testcases.
|
||||
# """
|
||||
# def dummy_config(self, db: str) -> Config:
|
||||
# """
|
||||
# Creates a dummy configuration.
|
||||
# """
|
||||
# return Config.from_dict(
|
||||
# {'system': 'dummy_system',
|
||||
# 'db': db,
|
||||
# 'openai': {'api_key': 'dummy_key',
|
||||
# 'model': 'dummy_model',
|
||||
# 'max_tokens': 4000,
|
||||
# 'temperature': 1.0,
|
||||
# 'top_p': 1,
|
||||
# 'frequency_penalty': 0,
|
||||
# 'presence_penalty': 0}}
|
||||
# )
|
||||
#
|
||||
#
|
||||
# class TestCreateChat(CmmTestCase):
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
import tempfile
|
||||
from typing import cast
|
||||
from .test_main import CmmTestCase
|
||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
||||
from chatmastermind.tags import Tag, TagLine
|
||||
|
||||
|
||||
class SourceCodeTestCase(CmmTestCase):
|
||||
class SourceCodeTestCase(unittest.TestCase):
|
||||
def test_source_code_with_include_delims(self) -> None:
|
||||
text = """
|
||||
Some text before the code block
|
||||
@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase):
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
class QuestionTestCase(CmmTestCase):
|
||||
class QuestionTestCase(unittest.TestCase):
|
||||
def test_question_with_header(self) -> None:
|
||||
with self.assertRaises(MessageError):
|
||||
Question(f"{Question.txt_header}\nWhat is your name?")
|
||||
@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase):
|
||||
self.assertEqual(question, "What is your favorite color?")
|
||||
|
||||
|
||||
class AnswerTestCase(CmmTestCase):
|
||||
class AnswerTestCase(unittest.TestCase):
|
||||
def test_answer_with_header(self) -> None:
|
||||
with self.assertRaises(MessageError):
|
||||
Answer(f"{Answer.txt_header}\nno")
|
||||
@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase):
|
||||
self.assertEqual(answer, "No")
|
||||
|
||||
|
||||
class MessageToFileTxtTestCase(CmmTestCase):
|
||||
class MessageToFileTxtTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
@ -160,7 +160,7 @@ This is a question.
|
||||
self.message_complete.file_path = self.file_path
|
||||
|
||||
|
||||
class MessageToFileYamlTestCase(CmmTestCase):
|
||||
class MessageToFileYamlTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase):
|
||||
self.assertEqual(content, expected_content)
|
||||
|
||||
|
||||
class MessageFromFileTxtTestCase(CmmTestCase):
|
||||
class MessageFromFileTxtTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
@ -388,7 +388,7 @@ This is a question.
|
||||
self.assertIsNone(message)
|
||||
|
||||
|
||||
class MessageFromFileYamlTestCase(CmmTestCase):
|
||||
class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase):
|
||||
self.assertIsNone(message)
|
||||
|
||||
|
||||
class TagsFromFileTestCase(CmmTestCase):
|
||||
class TagsFromFileTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
||||
@ -663,7 +663,7 @@ This is an answer.
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
|
||||
class TagsFromDirTestCase(CmmTestCase):
|
||||
class TagsFromDirTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
|
||||
@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase):
|
||||
self.assertSetEqual(all_tags, set())
|
||||
|
||||
|
||||
class MessageIDTestCase(CmmTestCase):
|
||||
class MessageIDTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase):
|
||||
self.message_no_file_path.msg_id()
|
||||
|
||||
|
||||
class MessageHashTestCase(CmmTestCase):
|
||||
class MessageHashTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message1 = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1')},
|
||||
@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase):
|
||||
self.assertIn(msg, msgs)
|
||||
|
||||
|
||||
class MessageTagsStrTestCase(CmmTestCase):
|
||||
class MessageTagsStrTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1')},
|
||||
@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase):
|
||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
||||
|
||||
|
||||
class MessageFilterTagsTestCase(CmmTestCase):
|
||||
class MessageFilterTagsTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase):
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
|
||||
class MessageInTestCase(CmmTestCase):
|
||||
class MessageInTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message1 = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase):
|
||||
self.assertFalse(message_in(self.message1, [self.message2]))
|
||||
|
||||
|
||||
class MessageRenameTagsTestCase(CmmTestCase):
|
||||
class MessageRenameTagsTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase):
|
||||
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
|
||||
|
||||
|
||||
class MessageToStrTestCase(CmmTestCase):
|
||||
class MessageToStrTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from .test_main import CmmTestCase
|
||||
import unittest
|
||||
from chatmastermind.tags import Tag, TagLine, TagError
|
||||
|
||||
|
||||
class TestTag(CmmTestCase):
|
||||
class TestTag(unittest.TestCase):
|
||||
def test_valid_tag(self) -> None:
|
||||
tag = Tag('mytag')
|
||||
self.assertEqual(tag, 'mytag')
|
||||
@ -18,7 +18,7 @@ class TestTag(CmmTestCase):
|
||||
self.assertEqual(Tag.alternative_separators, [','])
|
||||
|
||||
|
||||
class TestTagLine(CmmTestCase):
|
||||
class TestTagLine(unittest.TestCase):
|
||||
def test_valid_tagline(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1 tag2')
|
||||
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user