Compare commits
5 Commits
0ec606cdaa
...
0dbb0d3c4d
| Author | SHA1 | Date | |
|---|---|---|---|
| 0dbb0d3c4d | |||
| f07c3a58a1 | |||
| 96980bc4a8 | |||
| 26f72ed002 | |||
| a5fa79a4e5 |
48
chatmastermind/ais/openai.py
Normal file
48
chatmastermind/ais/openai.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""
|
||||||
|
Implements the OpenAI client classes and functions.
|
||||||
|
"""
|
||||||
|
import openai
|
||||||
|
from ..message import Message
|
||||||
|
from ..chat import Chat
|
||||||
|
from ..ai import AI, AIResponse
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI(AI):
|
||||||
|
"""
|
||||||
|
The OpenAI AI client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def request(self,
|
||||||
|
question: Message,
|
||||||
|
context: Chat,
|
||||||
|
num_answers: int = 1) -> AIResponse:
|
||||||
|
"""
|
||||||
|
Make an AI request, asking the given question with the given
|
||||||
|
context (i. e. chat history). The nr. of requested answers
|
||||||
|
corresponds to the nr. of messages in the 'AIResponse'.
|
||||||
|
"""
|
||||||
|
# TODO:
|
||||||
|
# * transform given message and chat context into OpenAI format
|
||||||
|
# * make request
|
||||||
|
# * create a new Message for each answer and return them
|
||||||
|
# (writing Messages is done by the calles)
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def models(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return all models supported by this AI.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def print_models(self) -> None:
|
||||||
|
"""
|
||||||
|
Print all models supported by the current AI.
|
||||||
|
"""
|
||||||
|
not_ready = []
|
||||||
|
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
|
||||||
|
if engine['ready']:
|
||||||
|
print(engine['id'])
|
||||||
|
else:
|
||||||
|
not_ready.append(engine['id'])
|
||||||
|
if len(not_ready) > 0:
|
||||||
|
print('\nNot ready: ' + ', '.join(not_ready))
|
||||||
@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path,
|
|||||||
messages: list[Message] = []
|
messages: list[Message] = []
|
||||||
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
||||||
for file_path in sorted(file_iter):
|
for file_path in sorted(file_iter):
|
||||||
if file_path.is_file():
|
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
|
||||||
try:
|
try:
|
||||||
message = Message.from_file(file_path, mfilter)
|
message = Message.from_file(file_path, mfilter)
|
||||||
if message:
|
if message:
|
||||||
@ -127,7 +127,16 @@ class Chat:
|
|||||||
tags: set[Tag] = set()
|
tags: set[Tag] = set()
|
||||||
for m in self.messages:
|
for m in self.messages:
|
||||||
tags |= m.filter_tags(prefix, contain)
|
tags |= m.filter_tags(prefix, contain)
|
||||||
return tags
|
return set(sorted(tags))
|
||||||
|
|
||||||
|
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
|
||||||
|
"""
|
||||||
|
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
|
||||||
|
"""
|
||||||
|
tags: list[Tag] = []
|
||||||
|
for m in self.messages:
|
||||||
|
tags += [tag for tag in m.filter_tags(prefix, contain)]
|
||||||
|
return {tag: tags.count(tag) for tag in sorted(tags)}
|
||||||
|
|
||||||
def tokens(self) -> int:
|
def tokens(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -7,10 +7,11 @@ import sys
|
|||||||
import argcomplete
|
import argcomplete
|
||||||
import argparse
|
import argparse
|
||||||
import pathlib
|
import pathlib
|
||||||
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType
|
from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType
|
||||||
from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data
|
from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data
|
||||||
from .api_client import ai, openai_api_key, print_models
|
from .api_client import ai, openai_api_key, print_models
|
||||||
from .configuration import Config
|
from .configuration import Config
|
||||||
|
from .chat import ChatDB
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -61,8 +62,12 @@ def tag_cmd(args: argparse.Namespace, config: Config) -> None:
|
|||||||
"""
|
"""
|
||||||
Handler for the 'tag' command.
|
Handler for the 'tag' command.
|
||||||
"""
|
"""
|
||||||
|
chat = ChatDB.from_dir(cache_path=pathlib.Path('.'),
|
||||||
|
db_path=pathlib.Path(config.db))
|
||||||
if args.list:
|
if args.list:
|
||||||
print_tags_frequency(get_tags(config, None))
|
tags_freq = chat.tags_frequency(args.prefix, args.contain)
|
||||||
|
for tag, freq in tags_freq.items():
|
||||||
|
print(f"- {tag}: {freq}")
|
||||||
|
|
||||||
|
|
||||||
def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
def config_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||||
@ -195,6 +200,8 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
|
tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||||
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
tag_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||||
action='store_true')
|
action='store_true')
|
||||||
|
tag_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
|
||||||
|
tag_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
|
||||||
|
|
||||||
# 'config' command parser
|
# 'config' command parser
|
||||||
config_cmd_parser = cmdparser.add_parser('config',
|
config_cmd_parser = cmdparser.add_parser('config',
|
||||||
|
|||||||
@ -128,29 +128,29 @@ class ModelLine(str):
|
|||||||
return cls(' '.join([cls.prefix, model]))
|
return cls(' '.join([cls.prefix, model]))
|
||||||
|
|
||||||
|
|
||||||
class Question(str):
|
class Answer(str):
|
||||||
"""
|
"""
|
||||||
A single question with a defined header.
|
A single answer with a defined header.
|
||||||
"""
|
"""
|
||||||
tokens: int = 0 # tokens used by this question
|
tokens: int = 0 # tokens used by this answer
|
||||||
txt_header: ClassVar[str] = '=== QUESTION ==='
|
txt_header: ClassVar[str] = '=== ANSWER ==='
|
||||||
yaml_key: ClassVar[str] = 'question'
|
yaml_key: ClassVar[str] = 'answer'
|
||||||
|
|
||||||
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
|
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
|
||||||
"""
|
"""
|
||||||
Make sure the question string does not contain the header.
|
Make sure the answer string does not contain the header as a whole line.
|
||||||
"""
|
"""
|
||||||
if cls.txt_header in string:
|
if cls.txt_header in string.split('\n'):
|
||||||
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
|
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
|
||||||
instance = super().__new__(cls, string)
|
instance = super().__new__(cls, string)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
|
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
|
||||||
"""
|
"""
|
||||||
Build Question from a list of strings. Make sure strings do not contain the header.
|
Build Question from a list of strings. Make sure strings do not contain the header.
|
||||||
"""
|
"""
|
||||||
if any(cls.txt_header in string for string in strings):
|
if cls.txt_header in strings:
|
||||||
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
||||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||||
return instance
|
return instance
|
||||||
@ -162,29 +162,33 @@ class Question(str):
|
|||||||
return source_code(self, include_delims)
|
return source_code(self, include_delims)
|
||||||
|
|
||||||
|
|
||||||
class Answer(str):
|
class Question(str):
|
||||||
"""
|
"""
|
||||||
A single answer with a defined header.
|
A single question with a defined header.
|
||||||
"""
|
"""
|
||||||
tokens: int = 0 # tokens used by this answer
|
tokens: int = 0 # tokens used by this question
|
||||||
txt_header: ClassVar[str] = '=== ANSWER ==='
|
txt_header: ClassVar[str] = '=== QUESTION ==='
|
||||||
yaml_key: ClassVar[str] = 'answer'
|
yaml_key: ClassVar[str] = 'question'
|
||||||
|
|
||||||
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
|
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
|
||||||
"""
|
"""
|
||||||
Make sure the answer string does not contain the header.
|
Make sure the question string does not contain the header as a whole line
|
||||||
|
(also not that from 'Answer', so it's always clear where the answer starts).
|
||||||
"""
|
"""
|
||||||
if cls.txt_header in string:
|
string_lines = string.split('\n')
|
||||||
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
|
if cls.txt_header in string_lines:
|
||||||
|
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
|
||||||
|
if Answer.txt_header in string_lines:
|
||||||
|
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
|
||||||
instance = super().__new__(cls, string)
|
instance = super().__new__(cls, string)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
|
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
|
||||||
"""
|
"""
|
||||||
Build Question from a list of strings. Make sure strings do not contain the header.
|
Build Question from a list of strings. Make sure strings do not contain the header.
|
||||||
"""
|
"""
|
||||||
if any(cls.txt_header in string for string in strings):
|
if cls.txt_header in strings:
|
||||||
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
||||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||||
return instance
|
return instance
|
||||||
|
|||||||
@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals
|
|||||||
print(message['content'])
|
print(message['content'])
|
||||||
else:
|
else:
|
||||||
print(f"{message['role'].upper()}: {message['content']}")
|
print(f"{message['role'].upper()}: {message['content']}")
|
||||||
|
|
||||||
|
|
||||||
def print_tags_frequency(tags: list[str]) -> None:
|
|
||||||
for tag in sorted(set(tags)):
|
|
||||||
print(f"- {tag}: {tags.count(tag)}")
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ class TestChat(CmmTestCase):
|
|||||||
self.chat = Chat([])
|
self.chat = Chat([])
|
||||||
self.message1 = Message(Question('Question 1'),
|
self.message1 = Message(Question('Question 1'),
|
||||||
Answer('Answer 1'),
|
Answer('Answer 1'),
|
||||||
{Tag('atag1')},
|
{Tag('atag1'), Tag('btag2')},
|
||||||
file_path=pathlib.Path('0001.txt'))
|
file_path=pathlib.Path('0001.txt'))
|
||||||
self.message2 = Message(Question('Question 2'),
|
self.message2 = Message(Question('Question 2'),
|
||||||
Answer('Answer 2'),
|
Answer('Answer 2'),
|
||||||
@ -57,6 +57,11 @@ class TestChat(CmmTestCase):
|
|||||||
tags_cont = self.chat.tags(contain='2')
|
tags_cont = self.chat.tags(contain='2')
|
||||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||||
|
|
||||||
|
def test_tags_frequency(self) -> None:
|
||||||
|
self.chat.add_msgs([self.message1, self.message2])
|
||||||
|
tags_freq = self.chat.tags_frequency()
|
||||||
|
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
||||||
|
|
||||||
@patch('sys.stdout', new_callable=StringIO)
|
@patch('sys.stdout', new_callable=StringIO)
|
||||||
def test_print(self, mock_stdout: StringIO) -> None:
|
def test_print(self, mock_stdout: StringIO) -> None:
|
||||||
self.chat.add_msgs([self.message1, self.message2])
|
self.chat.add_msgs([self.message1, self.message2])
|
||||||
@ -83,7 +88,7 @@ Answer 2
|
|||||||
Question 1
|
Question 1
|
||||||
{Answer.txt_header}
|
{Answer.txt_header}
|
||||||
Answer 1
|
Answer 1
|
||||||
{TagLine.prefix} atag1
|
{TagLine.prefix} atag1 btag2
|
||||||
FILE: 0001.txt
|
FILE: 0001.txt
|
||||||
{'-'*terminal_width()}
|
{'-'*terminal_width()}
|
||||||
{Question.txt_header}
|
{Question.txt_header}
|
||||||
|
|||||||
@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase):
|
|||||||
|
|
||||||
|
|
||||||
class QuestionTestCase(CmmTestCase):
|
class QuestionTestCase(CmmTestCase):
|
||||||
def test_question_with_prefix(self) -> None:
|
def test_question_with_header(self) -> None:
|
||||||
with self.assertRaises(MessageError):
|
with self.assertRaises(MessageError):
|
||||||
Question("=== QUESTION === What is your name?")
|
Question(f"{Question.txt_header}\nWhat is your name?")
|
||||||
|
|
||||||
def test_question_without_prefix(self) -> None:
|
def test_question_with_answer_header(self) -> None:
|
||||||
|
with self.assertRaises(MessageError):
|
||||||
|
Question(f"{Answer.txt_header}\nBob")
|
||||||
|
|
||||||
|
def test_question_with_legal_header(self) -> None:
|
||||||
|
"""
|
||||||
|
If the header is just a part of a line, it's fine.
|
||||||
|
"""
|
||||||
|
question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
|
||||||
|
self.assertIsInstance(question, Question)
|
||||||
|
self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?")
|
||||||
|
|
||||||
|
def test_question_without_header(self) -> None:
|
||||||
question = Question("What is your favorite color?")
|
question = Question("What is your favorite color?")
|
||||||
self.assertIsInstance(question, Question)
|
self.assertIsInstance(question, Question)
|
||||||
self.assertEqual(question, "What is your favorite color?")
|
self.assertEqual(question, "What is your favorite color?")
|
||||||
|
|
||||||
|
|
||||||
class AnswerTestCase(CmmTestCase):
|
class AnswerTestCase(CmmTestCase):
|
||||||
def test_answer_with_prefix(self) -> None:
|
def test_answer_with_header(self) -> None:
|
||||||
with self.assertRaises(MessageError):
|
with self.assertRaises(MessageError):
|
||||||
Answer("=== ANSWER === Yes")
|
Answer(f"{Answer.txt_header}\nno")
|
||||||
|
|
||||||
def test_answer_without_prefix(self) -> None:
|
def test_answer_with_legal_header(self) -> None:
|
||||||
|
answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
|
||||||
|
self.assertIsInstance(answer, Answer)
|
||||||
|
self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.")
|
||||||
|
|
||||||
|
def test_answer_without_header(self) -> None:
|
||||||
answer = Answer("No")
|
answer = Answer("No")
|
||||||
self.assertIsInstance(answer, Answer)
|
self.assertIsInstance(answer, Answer)
|
||||||
self.assertEqual(answer, "No")
|
self.assertEqual(answer, "No")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user