Compare commits
2 Commits
13c4127827
...
b2401d57ae
| Author | SHA1 | Date | |
|---|---|---|---|
| b2401d57ae | |||
| 893917e455 |
@ -1,12 +1,20 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Type, TypeVar, Any
|
from typing import Type, TypeVar, Any, Optional
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||||
|
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
|
||||||
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||||
|
|
||||||
|
|
||||||
|
supported_ais: list[str] = ['openai']
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AIConfig:
|
class AIConfig:
|
||||||
"""
|
"""
|
||||||
@ -56,9 +64,24 @@ class OpenAIConfig(AIConfig):
|
|||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
def ai_type_instance(ai_type: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
||||||
|
"""
|
||||||
|
Creates an AIConfig instance of the given type.
|
||||||
|
"""
|
||||||
|
if ai_type.lower() == 'openai':
|
||||||
|
if conf_dict is None:
|
||||||
|
return OpenAIConfig()
|
||||||
|
else:
|
||||||
|
return OpenAIConfig.from_dict(conf_dict)
|
||||||
|
else:
|
||||||
|
raise ConfigError(f"AI type '{ai_type}' is not supported")
|
||||||
|
|
||||||
|
|
||||||
def create_default_ai_configs() -> dict[str, AIConfig]:
|
def create_default_ai_configs() -> dict[str, AIConfig]:
|
||||||
openai_conf = OpenAIConfig()
|
"""
|
||||||
return {openai_conf.name: openai_conf}
|
Create a dict containing default configurations for all supported AIs.
|
||||||
|
"""
|
||||||
|
return {ai_type_instance(ai_type).name: ai_type_instance(ai_type) for ai_type in supported_ais}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -76,9 +99,14 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
Create Config from a dict.
|
Create Config from a dict.
|
||||||
"""
|
"""
|
||||||
|
# create the correct AI type instances
|
||||||
|
ais: dict[str, AIConfig] = {}
|
||||||
|
for name, conf in source['ais'].items():
|
||||||
|
ai_conf = ai_type_instance(conf['type'], conf)
|
||||||
|
ais[name] = ai_conf
|
||||||
return cls(
|
return cls(
|
||||||
db=str(source['db']),
|
db=str(source['db']),
|
||||||
ais=source['ais'] # FIXME: call correct constructors
|
ais=ais
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -94,8 +122,8 @@ class Config:
|
|||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
source = yaml.load(f, Loader=yaml.FullLoader)
|
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
# add the AI name to the config (for easy internal access)
|
# add the AI name to the config (for easy internal access)
|
||||||
for ai_name, ai_conf in source['ais'].items():
|
for name, conf in source['ais'].items():
|
||||||
ai_conf['name'] = ai_name
|
conf['name'] = name
|
||||||
return cls.from_dict(source)
|
return cls.from_dict(source)
|
||||||
|
|
||||||
def to_file(self, file_path: Path) -> None:
|
def to_file(self, file_path: Path) -> None:
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import unittest
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
@ -6,10 +7,9 @@ from unittest.mock import patch
|
|||||||
from chatmastermind.tags import TagLine
|
from chatmastermind.tags import TagLine
|
||||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||||
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
|
from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError
|
||||||
from .test_main import CmmTestCase
|
|
||||||
|
|
||||||
|
|
||||||
class TestChat(CmmTestCase):
|
class TestChat(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.chat = Chat([])
|
self.chat = Chat([])
|
||||||
self.message1 = Message(Question('Question 1'),
|
self.message1 = Message(Question('Question 1'),
|
||||||
@ -131,7 +131,7 @@ Answer 2
|
|||||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||||
|
|
||||||
|
|
||||||
class TestChatDB(CmmTestCase):
|
class TestChatDB(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.db_path = tempfile.TemporaryDirectory()
|
self.db_path = tempfile.TemporaryDirectory()
|
||||||
self.cache_path = tempfile.TemporaryDirectory()
|
self.cache_path = tempfile.TemporaryDirectory()
|
||||||
|
|||||||
@ -1,35 +1,35 @@
|
|||||||
import unittest
|
# import unittest
|
||||||
# import io
|
# import io
|
||||||
# import pathlib
|
# import pathlib
|
||||||
# import argparse
|
# import argparse
|
||||||
# from chatmastermind.utils import terminal_width
|
# from chatmastermind.utils import terminal_width
|
||||||
# from chatmastermind.main import create_parser, ask_cmd
|
# from chatmastermind.main import create_parser, ask_cmd
|
||||||
# from chatmastermind.api_client import ai
|
# from chatmastermind.api_client import ai
|
||||||
from chatmastermind.configuration import Config
|
# from chatmastermind.configuration import Config
|
||||||
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
# from chatmastermind.storage import create_chat_hist, save_answers, dump_data
|
||||||
# from unittest import mock
|
# from unittest import mock
|
||||||
# from unittest.mock import patch, MagicMock, Mock, ANY
|
# from unittest.mock import patch, MagicMock, Mock, ANY
|
||||||
|
|
||||||
|
|
||||||
class CmmTestCase(unittest.TestCase):
|
# class CmmTestCase(unittest.TestCase):
|
||||||
"""
|
# """
|
||||||
Base class for all cmm testcases.
|
# Base class for all cmm testcases.
|
||||||
"""
|
# """
|
||||||
def dummy_config(self, db: str) -> Config:
|
# def dummy_config(self, db: str) -> Config:
|
||||||
"""
|
# """
|
||||||
Creates a dummy configuration.
|
# Creates a dummy configuration.
|
||||||
"""
|
# """
|
||||||
return Config.from_dict(
|
# return Config.from_dict(
|
||||||
{'system': 'dummy_system',
|
# {'system': 'dummy_system',
|
||||||
'db': db,
|
# 'db': db,
|
||||||
'openai': {'api_key': 'dummy_key',
|
# 'openai': {'api_key': 'dummy_key',
|
||||||
'model': 'dummy_model',
|
# 'model': 'dummy_model',
|
||||||
'max_tokens': 4000,
|
# 'max_tokens': 4000,
|
||||||
'temperature': 1.0,
|
# 'temperature': 1.0,
|
||||||
'top_p': 1,
|
# 'top_p': 1,
|
||||||
'frequency_penalty': 0,
|
# 'frequency_penalty': 0,
|
||||||
'presence_penalty': 0}}
|
# 'presence_penalty': 0}}
|
||||||
)
|
# )
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# class TestCreateChat(CmmTestCase):
|
# class TestCreateChat(CmmTestCase):
|
||||||
|
|||||||
@ -1,12 +1,12 @@
|
|||||||
|
import unittest
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from .test_main import CmmTestCase
|
|
||||||
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
|
||||||
from chatmastermind.tags import Tag, TagLine
|
from chatmastermind.tags import Tag, TagLine
|
||||||
|
|
||||||
|
|
||||||
class SourceCodeTestCase(CmmTestCase):
|
class SourceCodeTestCase(unittest.TestCase):
|
||||||
def test_source_code_with_include_delims(self) -> None:
|
def test_source_code_with_include_delims(self) -> None:
|
||||||
text = """
|
text = """
|
||||||
Some text before the code block
|
Some text before the code block
|
||||||
@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase):
|
|||||||
self.assertEqual(result, expected_result)
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
|
||||||
class QuestionTestCase(CmmTestCase):
|
class QuestionTestCase(unittest.TestCase):
|
||||||
def test_question_with_header(self) -> None:
|
def test_question_with_header(self) -> None:
|
||||||
with self.assertRaises(MessageError):
|
with self.assertRaises(MessageError):
|
||||||
Question(f"{Question.txt_header}\nWhat is your name?")
|
Question(f"{Question.txt_header}\nWhat is your name?")
|
||||||
@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase):
|
|||||||
self.assertEqual(question, "What is your favorite color?")
|
self.assertEqual(question, "What is your favorite color?")
|
||||||
|
|
||||||
|
|
||||||
class AnswerTestCase(CmmTestCase):
|
class AnswerTestCase(unittest.TestCase):
|
||||||
def test_answer_with_header(self) -> None:
|
def test_answer_with_header(self) -> None:
|
||||||
with self.assertRaises(MessageError):
|
with self.assertRaises(MessageError):
|
||||||
Answer(f"{Answer.txt_header}\nno")
|
Answer(f"{Answer.txt_header}\nno")
|
||||||
@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase):
|
|||||||
self.assertEqual(answer, "No")
|
self.assertEqual(answer, "No")
|
||||||
|
|
||||||
|
|
||||||
class MessageToFileTxtTestCase(CmmTestCase):
|
class MessageToFileTxtTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
@ -160,7 +160,7 @@ This is a question.
|
|||||||
self.message_complete.file_path = self.file_path
|
self.message_complete.file_path = self.file_path
|
||||||
|
|
||||||
|
|
||||||
class MessageToFileYamlTestCase(CmmTestCase):
|
class MessageToFileYamlTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase):
|
|||||||
self.assertEqual(content, expected_content)
|
self.assertEqual(content, expected_content)
|
||||||
|
|
||||||
|
|
||||||
class MessageFromFileTxtTestCase(CmmTestCase):
|
class MessageFromFileTxtTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
@ -388,7 +388,7 @@ This is a question.
|
|||||||
self.assertIsNone(message)
|
self.assertIsNone(message)
|
||||||
|
|
||||||
|
|
||||||
class MessageFromFileYamlTestCase(CmmTestCase):
|
class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase):
|
|||||||
self.assertIsNone(message)
|
self.assertIsNone(message)
|
||||||
|
|
||||||
|
|
||||||
class TagsFromFileTestCase(CmmTestCase):
|
class TagsFromFileTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||||
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
||||||
@ -663,7 +663,7 @@ This is an answer.
|
|||||||
self.assertSetEqual(tags, set())
|
self.assertSetEqual(tags, set())
|
||||||
|
|
||||||
|
|
||||||
class TagsFromDirTestCase(CmmTestCase):
|
class TagsFromDirTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
|
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
|
||||||
@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase):
|
|||||||
self.assertSetEqual(all_tags, set())
|
self.assertSetEqual(all_tags, set())
|
||||||
|
|
||||||
|
|
||||||
class MessageIDTestCase(CmmTestCase):
|
class MessageIDTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||||
self.file_path = pathlib.Path(self.file.name)
|
self.file_path = pathlib.Path(self.file.name)
|
||||||
@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase):
|
|||||||
self.message_no_file_path.msg_id()
|
self.message_no_file_path.msg_id()
|
||||||
|
|
||||||
|
|
||||||
class MessageHashTestCase(CmmTestCase):
|
class MessageHashTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message1 = Message(Question('This is a question.'),
|
self.message1 = Message(Question('This is a question.'),
|
||||||
tags={Tag('tag1')},
|
tags={Tag('tag1')},
|
||||||
@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase):
|
|||||||
self.assertIn(msg, msgs)
|
self.assertIn(msg, msgs)
|
||||||
|
|
||||||
|
|
||||||
class MessageTagsStrTestCase(CmmTestCase):
|
class MessageTagsStrTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message = Message(Question('This is a question.'),
|
self.message = Message(Question('This is a question.'),
|
||||||
tags={Tag('tag1')},
|
tags={Tag('tag1')},
|
||||||
@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase):
|
|||||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
||||||
|
|
||||||
|
|
||||||
class MessageFilterTagsTestCase(CmmTestCase):
|
class MessageFilterTagsTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message = Message(Question('This is a question.'),
|
self.message = Message(Question('This is a question.'),
|
||||||
tags={Tag('atag1'), Tag('btag2')},
|
tags={Tag('atag1'), Tag('btag2')},
|
||||||
@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase):
|
|||||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||||
|
|
||||||
|
|
||||||
class MessageInTestCase(CmmTestCase):
|
class MessageInTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message1 = Message(Question('This is a question.'),
|
self.message1 = Message(Question('This is a question.'),
|
||||||
tags={Tag('atag1'), Tag('btag2')},
|
tags={Tag('atag1'), Tag('btag2')},
|
||||||
@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase):
|
|||||||
self.assertFalse(message_in(self.message1, [self.message2]))
|
self.assertFalse(message_in(self.message1, [self.message2]))
|
||||||
|
|
||||||
|
|
||||||
class MessageRenameTagsTestCase(CmmTestCase):
|
class MessageRenameTagsTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message = Message(Question('This is a question.'),
|
self.message = Message(Question('This is a question.'),
|
||||||
tags={Tag('atag1'), Tag('btag2')},
|
tags={Tag('atag1'), Tag('btag2')},
|
||||||
@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase):
|
|||||||
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
|
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
|
||||||
|
|
||||||
|
|
||||||
class MessageToStrTestCase(CmmTestCase):
|
class MessageToStrTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.message = Message(Question('This is a question.'),
|
self.message = Message(Question('This is a question.'),
|
||||||
Answer('This is an answer.'),
|
Answer('This is an answer.'),
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from .test_main import CmmTestCase
|
import unittest
|
||||||
from chatmastermind.tags import Tag, TagLine, TagError
|
from chatmastermind.tags import Tag, TagLine, TagError
|
||||||
|
|
||||||
|
|
||||||
class TestTag(CmmTestCase):
|
class TestTag(unittest.TestCase):
|
||||||
def test_valid_tag(self) -> None:
|
def test_valid_tag(self) -> None:
|
||||||
tag = Tag('mytag')
|
tag = Tag('mytag')
|
||||||
self.assertEqual(tag, 'mytag')
|
self.assertEqual(tag, 'mytag')
|
||||||
@ -18,7 +18,7 @@ class TestTag(CmmTestCase):
|
|||||||
self.assertEqual(Tag.alternative_separators, [','])
|
self.assertEqual(Tag.alternative_separators, [','])
|
||||||
|
|
||||||
|
|
||||||
class TestTagLine(CmmTestCase):
|
class TestTagLine(unittest.TestCase):
|
||||||
def test_valid_tagline(self) -> None:
|
def test_valid_tagline(self) -> None:
|
||||||
tagline = TagLine('TAGS: tag1 tag2')
|
tagline = TagLine('TAGS: tag1 tag2')
|
||||||
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user