import sys
import argparse
from pathlib import Path
from itertools import zip_longest
from copy import deepcopy
from ..configuration import Config
from ..chat import ChatDB
from ..message import Message, MessageFilter, MessageError, Question, source_code
from ..ai_factory import create_ai
from ..ai import AI, AIResponse


def _read_file_content(file: str) -> str:
    """
    Return the content of the given file for use in a question.

    If the file can be loaded as a Message, its answer is returned, or an
    empty string if the message has no answer. If loading it as a Message
    raises a MessageError, the file is read as plain text and returned with
    surrounding whitespace stripped.
    """
    try:
        message = Message.from_file(Path(file))
    except MessageError:
        # not a message file -> use the raw file content
        with open(file) as r:
            return r.read().strip()
    if message and message.answer:
        return message.answer
    # a message without an answer contributes nothing
    # (previously this left 'content' unbound and crashed)
    return ''


def add_file_as_text(question_parts: list[str], file: str) -> None:
    """
    Add the given file as plain text to the question part list.
    If the file is a Message, add the answer.
    Nothing is added if the resulting content is empty.
    """
    content = _read_file_content(file)
    if len(content) > 0:
        question_parts.append(content)


def add_file_as_code(question_parts: list[str], file: str) -> None:
    """
    Add all source code from the given file. If no code segments can be
    extracted, the whole content is added as source code segment.
    If the file is a Message, extract the source code from the answer.
    Nothing is added if the resulting content is empty.
    """
    content = _read_file_content(file)
    if len(content) == 0:
        # avoid adding an empty code block
        return
    # extract and add source code
    code_parts = source_code(content, include_delims=True)
    if len(code_parts) > 0:
        question_parts.extend(code_parts)
    else:
        question_parts.append(f"```\n{content}\n```")


def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
    """
    Create a new message from the given arguments and write it
    to the cache directory.

    The question is assembled from 'args.ask', 'args.source_text' and
    'args.source_code'. Entries are taken round-robin (one of each per
    step, via zip_longest) and joined with blank lines. The message gets
    'args.output_tags', 'args.AI' and 'args.model'.
    """
    question_parts: list[str] = []
    question_list = args.ask if args.ask is not None else []
    text_files = args.source_text if args.source_text is not None else []
    code_files = args.source_code if args.source_code is not None else []

    for question, text_file, code_file in zip_longest(question_list,
                                                      text_files,
                                                      code_files,
                                                      fillvalue=None):
        if question is not None and len(question.strip()) > 0:
            question_parts.append(question)
        if text_file is not None and len(text_file) > 0:
            add_file_as_text(question_parts, text_file)
        if code_file is not None and len(code_file) > 0:
            add_file_as_code(question_parts, code_file)

    full_question = '\n\n'.join(question_parts)

    message = Message(question=Question(full_question),
                      tags=args.output_tags,
                      ai=args.AI,
                      model=args.model)
    # only write the new message to the cache,
    # don't add it to the internal list
    chat.cache_write([message])
    return message


def make_request(ai: AI, chat: ChatDB, message: Message, args: argparse.Namespace) -> None:
    """
    Make an AI request with the given AI, chat history, message and arguments.
    Write the response(s) to the cache directory, without appending it to the
    given chat history. Then print the response(s).

    Uses 'args.num_answers' and 'args.output_tags' for the request.
    """
    # print history and message question before making the request
    ai.print()
    chat.print(paged=False)
    print(message.to_str())
    response: AIResponse = ai.request(message,
                                      chat,
                                      args.num_answers,
                                      args.output_tags)
    # only write the response messages to the cache,
    # don't add them to the internal list
    chat.cache_write(response.messages)
    for idx, msg in enumerate(response.messages):
        print(f"=== ANSWER {idx+1} ===")
        print(msg.answer)
    if response.tokens:
        print("===============")
        print(response.tokens)


def make_msg_args(msg: Message, args: argparse.Namespace) -> argparse.Namespace:
    """
    Takes an existing message and CLI arguments, and returns modified args based
    on the members of the given message. Used e.g. when repeating messages, where
    it's necessary to determine the correct AI, model and output tags to use
    (either from the existing message or the given args).

    The given 'args' are never modified: if anything has to be filled in from
    the message, a deep copy is returned; otherwise 'args' itself is returned.
    """
    needs_tags = args.output_tags is None or len(args.output_tags) == 0
    if args.AI is not None and args.model is not None and not needs_tags:
        # everything was specified explicitly
        return args

    # if AI, model or output tags have not been specified,
    # use those from the original message
    msg_args = deepcopy(args)
    if args.AI is None and msg.ai is not None:
        msg_args.AI = msg.ai
    if args.model is None and msg.model is not None:
        msg_args.model = msg.model
    if needs_tags and msg.tags is not None:
        msg_args.output_tags = msg.tags
    return msg_args


def repeat_messages(messages: list[Message],
                    chat: ChatDB,
                    args: argparse.Namespace,
                    config: Config) -> None:
    """
    Repeat the given messages using the given arguments.

    For each message, the AI, model and output tags come from 'args' if
    specified, otherwise from the message itself (see 'make_msg_args').
    These per-message arguments are used both for creating the AI and for
    the stored message / request. The caller's 'args' are not modified.
    """
    ai: AI
    for msg in messages:
        msg_args = make_msg_args(msg, args)
        ai = create_ai(msg_args, config)
        print(f"--------- Repeating message '{msg.msg_id()}': ---------")
        # overwrite the message if requested or if it has no answer
        # -> but not if it's in the DB!
        if ((msg.answer is None or args.overwrite is True)
                and not chat.msg_in_db(msg)):  # noqa: W503
            msg.clear_answer()
            make_request(ai, chat, msg, msg_args)
        # otherwise create a new message with the same question
        else:
            # shallow copy so the caller's args (and msg_args, which may be
            # the caller's args) are not mutated
            new_args = argparse.Namespace(**vars(msg_args))
            new_args.ask = [msg.question]
            message = create_message(chat, new_args)
            make_request(ai, chat, message, new_args)


def modify_tag_args(args: argparse.Namespace) -> None:
    """
    Changes the semantics of the tags for this command:
        * no tags specified -> no tags are selected
        * empty tags specified -> all tags are selected

    Applied in place to 'or_tags', 'and_tags' and 'exclude_tags'.
    """
    for attr in ('or_tags', 'and_tags', 'exclude_tags'):
        value = getattr(args, attr)
        if value is None:
            setattr(args, attr, set())
        elif len(value) == 0:
            setattr(args, attr, None)


def question_cmd(args: argparse.Namespace, config: Config) -> None:
    """
    Handler for the 'question' command.

    Depending on the arguments, this creates a new question ('create'), asks a
    new question ('ask'), repeats existing messages ('repeat') or processes
    messages ('process', not yet implemented).
    """
    modify_tag_args(args)
    mfilter = MessageFilter(tags_or=args.or_tags,
                            tags_and=args.and_tags,
                            tags_not=args.exclude_tags)
    chat = ChatDB.from_dir(cache_path=Path(config.cache),
                           db_path=Path(config.db),
                           mfilter=mfilter)
    # if it's a new question, create and store it immediately
    if args.ask or args.create:
        message = create_message(chat, args)
        if args.create:
            return

    # === ASK ===
    if args.ask:
        ai: AI = create_ai(args, config)
        make_request(ai, chat, message, args)
    # === REPEAT ===
    elif args.repeat is not None:
        repeat_msgs: list[Message] = []
        # repeat latest message
        if len(args.repeat) == 0:
            lmessage = chat.msg_latest(loc='cache')
            if lmessage is None:
                print("No message found to repeat!")
                sys.exit(1)
            repeat_msgs.append(lmessage)
        # repeat given message(s)
        else:
            repeat_msgs = chat.msg_find(args.repeat, loc='disk')
            if not repeat_msgs:
                # consistent with the 'latest message' case above
                print("No message found to repeat!")
                sys.exit(1)
        repeat_messages(repeat_msgs, chat, args, config)
    # === PROCESS ===
    elif args.process is not None:
        # TODO: process either all questions without an
        #       answer or the one(s) given in 'args.process'
        pass