Compare commits

...

5 Commits

7 changed files with 125 additions and 40 deletions

View File

@ -0,0 +1,48 @@
"""
Implements the OpenAI client classes and functions.
"""
import openai
from ..message import Message
from ..chat import Chat
from ..ai import AI, AIResponse
class OpenAI(AI):
"""
The OpenAI AI client.
"""
def request(self,
question: Message,
context: Chat,
num_answers: int = 1) -> AIResponse:
"""
Make an AI request, asking the given question with the given
context (i. e. chat history). The nr. of requested answers
corresponds to the nr. of messages in the 'AIResponse'.
"""
# TODO:
# * transform given message and chat context into OpenAI format
# * make request
# * create a new Message for each answer and return them
# (writing Messages is done by the calles)
raise NotImplementedError
def models(self) -> list[str]:
"""
Return all models supported by this AI.
"""
raise NotImplementedError
def print_models(self) -> None:
"""
Print all models supported by the current AI.
"""
not_ready = []
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
if engine['ready']:
print(engine['id'])
else:
not_ready.append(engine['id'])
if len(not_ready) > 0:
print('\nNot ready: ' + ', '.join(not_ready))

View File

@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path,
messages: list[Message] = [] messages: list[Message] = []
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
for file_path in sorted(file_iter): for file_path in sorted(file_iter):
if file_path.is_file(): if file_path.is_file() and file_path.suffix in Message.file_suffixes:
try: try:
message = Message.from_file(file_path, mfilter) message = Message.from_file(file_path, mfilter)
if message: if message:
@ -127,7 +127,16 @@ class Chat:
tags: set[Tag] = set() tags: set[Tag] = set()
for m in self.messages: for m in self.messages:
tags |= m.filter_tags(prefix, contain) tags |= m.filter_tags(prefix, contain)
return tags return set(sorted(tags))
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
"""
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
"""
tags: list[Tag] = []
for m in self.messages:
tags += [tag for tag in m.filter_tags(prefix, contain)]
return {tag: tags.count(tag) for tag in sorted(tags)}
def tokens(self) -> int: def tokens(self) -> int:
""" """

View File

@ -7,10 +7,11 @@ import sys
import argcomplete import argcomplete
import argparse import argparse
import pathlib import pathlib
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data
from .api_client import ai, openai_api_key, print_models from .api_client import ai, openai_api_key, print_models
from .configuration import Config from .configuration import Config
from .chat import ChatDB
from itertools import zip_longest from itertools import zip_longest
from typing import Any from typing import Any
@ -61,8 +62,12 @@ def tag_cmd(args: argparse.Namespace, config: Config) -> None:
""" """
Handler for the 'tag' command. Handler for the 'tag' command.
""" """
chat = ChatDB.from_dir(cache_path=pathlib.Path('.'),
db_path=pathlib.Path(config.db))
if args.list: if args.list:
print_tags_frequency(get_tags(config, None)) tags_freq = chat.tags_frequency(args.prefix, args.contain)
for tag, freq in tags_freq.items():
print(f"- {tag}: {freq}")
def config_cmd(args: argparse.Namespace, config: Config) -> None: def config_cmd(args: argparse.Namespace, config: Config) -> None:
@ -195,6 +200,8 @@ def create_parser() -> argparse.ArgumentParser:
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
tag_group.add_argument('-l', '--list', help="List all tags and their frequency", tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
action='store_true') action='store_true')
tag_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
tag_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
# 'config' command parser # 'config' command parser
config_cmd_parser = cmdparser.add_parser('config', config_cmd_parser = cmdparser.add_parser('config',

View File

@ -128,29 +128,29 @@ class ModelLine(str):
return cls(' '.join([cls.prefix, model])) return cls(' '.join([cls.prefix, model]))
class Question(str): class Answer(str):
""" """
A single question with a defined header. A single answer with a defined header.
""" """
tokens: int = 0 # tokens used by this question tokens: int = 0 # tokens used by this answer
txt_header: ClassVar[str] = '=== QUESTION ===' txt_header: ClassVar[str] = '=== ANSWER ==='
yaml_key: ClassVar[str] = 'question' yaml_key: ClassVar[str] = 'answer'
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
""" """
Make sure the question string does not contain the header. Make sure the answer string does not contain the header as a whole line.
""" """
if cls.txt_header in string: if cls.txt_header in string.split('\n'):
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@classmethod @classmethod
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Question from a list of strings. Make sure strings do not contain the header.
""" """
if any(cls.txt_header in string for string in strings): if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'") raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip()) instance = super().__new__(cls, '\n'.join(strings).strip())
return instance return instance
@ -162,29 +162,33 @@ class Question(str):
return source_code(self, include_delims) return source_code(self, include_delims)
class Answer(str): class Question(str):
""" """
A single answer with a defined header. A single question with a defined header.
""" """
tokens: int = 0 # tokens used by this answer tokens: int = 0 # tokens used by this question
txt_header: ClassVar[str] = '=== ANSWER ===' txt_header: ClassVar[str] = '=== QUESTION ==='
yaml_key: ClassVar[str] = 'answer' yaml_key: ClassVar[str] = 'question'
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
""" """
Make sure the answer string does not contain the header. Make sure the question string does not contain the header as a whole line
(also not that from 'Answer', so it's always clear where the answer starts).
""" """
if cls.txt_header in string: string_lines = string.split('\n')
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") if cls.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
if Answer.txt_header in string_lines:
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
instance = super().__new__(cls, string) instance = super().__new__(cls, string)
return instance return instance
@classmethod @classmethod
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
""" """
Build Question from a list of strings. Make sure strings do not contain the header. Build Question from a list of strings. Make sure strings do not contain the header.
""" """
if any(cls.txt_header in string for string in strings): if cls.txt_header in strings:
raise MessageError(f"Question contains the header '{cls.txt_header}'") raise MessageError(f"Question contains the header '{cls.txt_header}'")
instance = super().__new__(cls, '\n'.join(strings).strip()) instance = super().__new__(cls, '\n'.join(strings).strip())
return instance return instance

View File

@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
print(message['content']) print(message['content'])
else: else:
print(f"{message['role'].upper()}: {message['content']}") print(f"{message['role'].upper()}: {message['content']}")
def print_tags_frequency(tags: list[str]) -> None:
for tag in sorted(set(tags)):
print(f"- {tag}: {tags.count(tag)}")

View File

@ -14,7 +14,7 @@ class TestChat(CmmTestCase):
self.chat = Chat([]) self.chat = Chat([])
self.message1 = Message(Question('Question 1'), self.message1 = Message(Question('Question 1'),
Answer('Answer 1'), Answer('Answer 1'),
{Tag('atag1')}, {Tag('atag1'), Tag('btag2')},
file_path=pathlib.Path('0001.txt')) file_path=pathlib.Path('0001.txt'))
self.message2 = Message(Question('Question 2'), self.message2 = Message(Question('Question 2'),
Answer('Answer 2'), Answer('Answer 2'),
@ -57,6 +57,11 @@ class TestChat(CmmTestCase):
tags_cont = self.chat.tags(contain='2') tags_cont = self.chat.tags(contain='2')
self.assertSetEqual(tags_cont, {Tag('btag2')}) self.assertSetEqual(tags_cont, {Tag('btag2')})
def test_tags_frequency(self) -> None:
self.chat.add_msgs([self.message1, self.message2])
tags_freq = self.chat.tags_frequency()
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
@patch('sys.stdout', new_callable=StringIO) @patch('sys.stdout', new_callable=StringIO)
def test_print(self, mock_stdout: StringIO) -> None: def test_print(self, mock_stdout: StringIO) -> None:
self.chat.add_msgs([self.message1, self.message2]) self.chat.add_msgs([self.message1, self.message2])
@ -83,7 +88,7 @@ Answer 2
Question 1 Question 1
{Answer.txt_header} {Answer.txt_header}
Answer 1 Answer 1
{TagLine.prefix} atag1 {TagLine.prefix} atag1 btag2
FILE: 0001.txt FILE: 0001.txt
{'-'*terminal_width()} {'-'*terminal_width()}
{Question.txt_header} {Question.txt_header}

View File

@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase):
class QuestionTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase):
def test_question_with_prefix(self) -> None: def test_question_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Question("=== QUESTION === What is your name?") Question(f"{Question.txt_header}\nWhat is your name?")
def test_question_without_prefix(self) -> None: def test_question_with_answer_header(self) -> None:
with self.assertRaises(MessageError):
Question(f"{Answer.txt_header}\nBob")
def test_question_with_legal_header(self) -> None:
"""
If the header is just a part of a line, it's fine.
"""
question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
self.assertIsInstance(question, Question)
self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
def test_question_without_header(self) -> None:
question = Question("What is your favorite color?") question = Question("What is your favorite color?")
self.assertIsInstance(question, Question) self.assertIsInstance(question, Question)
self.assertEqual(question, "What is your favorite color?") self.assertEqual(question, "What is your favorite color?")
class AnswerTestCase(CmmTestCase): class AnswerTestCase(CmmTestCase):
def test_answer_with_prefix(self) -> None: def test_answer_with_header(self) -> None:
with self.assertRaises(MessageError): with self.assertRaises(MessageError):
Answer("=== ANSWER === Yes") Answer(f"{Answer.txt_header}\nno")
def test_answer_without_prefix(self) -> None: def test_answer_with_legal_header(self) -> None:
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
self.assertIsInstance(answer, Answer)
self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
def test_answer_without_header(self) -> None:
answer = Answer("No") answer = Answer("No")
self.assertIsInstance(answer, Answer) self.assertIsInstance(answer, Answer)
self.assertEqual(answer, "No") self.assertEqual(answer, "No")