From 05557cceb56715fffa740ba9f8e7474de765e520 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 18 Sep 2023 14:34:10 +0200 Subject: [PATCH] configuration: the cache folder can now be specified in the configuration file --- chatmastermind/commands/hist.py | 2 +- chatmastermind/commands/question.py | 2 +- chatmastermind/commands/tags.py | 2 +- chatmastermind/configuration.py | 2 ++ tests/test_configuration.py | 7 +++++++ 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py index 88ed3be..5b14bd2 100644 --- a/chatmastermind/commands/hist.py +++ b/chatmastermind/commands/hist.py @@ -15,7 +15,7 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: tags_not=args.exclude_tags, question_contains=args.question, answer_contains=args.answer) - chat = ChatDB.from_dir(Path('.'), + chat = ChatDB.from_dir(Path(config.cache), Path(config.db), mfilter=mfilter) chat.print(args.source_code_only, diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 78a6c4e..faba681 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -84,7 +84,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), tags_and=args.and_tags if args.and_tags is not None else set(), tags_not=args.exclude_tags if args.exclude_tags is not None else set()) - chat = ChatDB.from_dir(cache_path=Path('.'), + chat = ChatDB.from_dir(cache_path=Path(config.cache), db_path=Path(config.db), mfilter=mfilter) # if it's a new question, create and store it immediately diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py index 71574ff..61af13d 100644 --- a/chatmastermind/commands/tags.py +++ b/chatmastermind/commands/tags.py @@ -8,7 +8,7 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. """ - chat = ChatDB.from_dir(cache_path=Path('.'), + chat = ChatDB.from_dir(cache_path=Path(config.cache), db_path=Path(config.db)) if args.list: tags_freq = chat.msg_tags_frequency(args.prefix, args.contain) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d1f9601..7dfa78a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -116,6 +116,7 @@ class Config: """ # all members have default values, so we can easily create # a default configuration + cache: str = '.' db: str = './db/' ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @@ -132,6 +133,7 @@ class Config: ai_conf = ai_config_instance(conf['name'], conf) ais[ID] = ai_conf return cls( + cache=str(source['cache']) if 'cache' in source else '.', db=str(source['db']), ais=ais ) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index ba8a5aa..3e866f2 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -57,6 +57,7 @@ class TestConfig(unittest.TestCase): def test_from_dict_should_create_config_from_dict(self) -> None: source_dict = { + 'cache': '.', 'db': './test_db/', 'ais': { 'myopenai': { @@ -73,6 +74,7 @@ class TestConfig(unittest.TestCase): } } config = Config.from_dict(source_dict) + self.assertEqual(config.cache, '.') self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) self.assertEqual(config.ais['myopenai'].name, 'openai') @@ -89,6 +91,7 @@ class TestConfig(unittest.TestCase): def test_from_file_should_load_config_from_file(self) -> None: source_dict = { + 'cache': './test_cache/', 'db': './test_db/', 'ais': { 'default': { @@ -108,6 +111,7 @@ class TestConfig(unittest.TestCase): yaml.dump(source_dict, f) config = Config.from_file(self.test_file.name) self.assertIsInstance(config, Config) + self.assertEqual(config.cache, './test_cache/') self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) self.assertIsInstance(config.ais['default'], AIConfig) @@ -115,6 +119,7 @@ class TestConfig(unittest.TestCase): def test_to_file_should_save_config_to_file(self) -> None: config = Config( + cache='./test_cache/', db='./test_db/', ais={ 'myopenai': OpenAIConfig( @@ -133,12 +138,14 @@ class TestConfig(unittest.TestCase): config.to_file(Path(self.test_file.name)) with open(self.test_file.name, 'r') as f: saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['cache'], './test_cache/') self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = { + 'cache': './test_cache/', 'db': './test_db/', 'ais': { 'default': {