85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
import shutil
|
|
from pprint import PrettyPrinter
|
|
|
|
ConfigType = dict[str, str | dict[str, str | int | float]]
|
|
|
|
|
|
def terminal_width() -> int:
    """
    Return the current terminal width in columns.

    Delegates to shutil.get_terminal_size(), which honours the COLUMNS
    environment variable and falls back to 80 columns when no terminal
    is attached.
    """
    columns, _lines = shutil.get_terminal_size()
    return columns
|
|
|
|
|
|
def pp(*args, **kwargs) -> None:
    """
    Pretty-print each positional argument, wrapped to the terminal width.

    Keyword arguments are forwarded to pprint.PrettyPrinter (e.g. indent,
    depth, compact, sort_dicts, stream); width defaults to the current
    terminal width. For backward compatibility, an ``object=`` keyword is
    printed as well, after any positional arguments.
    """
    objects = list(args)
    # PrettyPrinter.pprint names its single parameter ``object``; callers of
    # the old pass-through may have used it as a keyword.
    if 'object' in kwargs:
        objects.append(kwargs.pop('object'))
    # Only query the terminal when the caller did not choose a width.
    if 'width' not in kwargs:
        kwargs['width'] = terminal_width()
    printer = PrettyPrinter(**kwargs)
    for obj in objects:
        printer.pprint(obj)
|
|
|
|
|
|
def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None:
    """
    Print one labelled line for each non-empty tag list given on the command line.

    Nothing is printed when all three lists are empty; otherwise the
    summary is followed by an empty line.
    """
    labelled = (
        ("Tags", tags),
        ("Excluding tags", extags),
        ("Output tags", otags),
    )
    lines = [f"{label}: {' '.join(values)}" for label, values in labelled if values]
    if not lines:
        return
    print("\n".join(lines))
    print()
|
|
|
|
|
|
def append_message(chat: list[dict[str, str]],
                   role: str,
                   content: str
                   ) -> None:
    """
    Append a {'role': ..., 'content': ...} entry to ``chat`` in place.

    Every doubled single quote ('') in ``content`` is collapsed to a single
    quote before storing. NOTE(review): presumably this undoes an escaping
    applied where messages are stored; confirm against the writer.
    """
    unescaped = content.replace("''", "'")
    chat.append(dict(role=role, content=unescaped))
|
|
|
|
|
|
def message_to_chat(message: dict[str, str],
                    chat: list[dict[str, str]],
                    with_tags: bool = False,
                    with_file: bool = False
                    ) -> None:
    """
    Append a stored message to ``chat`` as a user/assistant exchange.

    message['question'] becomes a 'user' entry and message['answer'] an
    'assistant' entry. With ``with_tags`` a 'tags' entry (the tags joined
    by single spaces) follows, and with ``with_file`` a 'file' entry.
    """
    for role, key in (('user', 'question'), ('assistant', 'answer')):
        append_message(chat, role, message[key])
    if with_tags:
        # NOTE(review): 'tags' is joined as a sequence of strings although the
        # annotation says dict[str, str]; presumably a list of tags, confirm.
        append_message(chat, 'tags', " ".join(message['tags']))
    if with_file:
        append_message(chat, 'file', message['file'])
|
|
|
|
|
|
def display_source_code(content: str) -> None:
    """
    Print the source code found in Markdown fenced blocks (```) of ``content``.

    The body of every complete fenced block is printed with surrounding
    whitespace stripped; the info string on the opening fence line (e.g.
    "python" in ```python) is not printed. Several blocks are separated by
    an empty line. Empty blocks are skipped, and content without a complete
    fenced block prints nothing.
    """
    pieces = content.split('```')
    # Odd-indexed pieces lie between an opening and a closing fence; the
    # final piece is always outside any complete block, so drop it.
    snippets = []
    for piece in pieces[1:-1:2]:
        first_line, newline, rest = piece.partition('\n')
        # A multi-line block starts with the fence's info string; a single-line
        # one (```code```) has none, so keep all of it.
        code = (rest if newline else first_line).strip()
        if code:
            snippets.append(code)
    if snippets:
        print("\n\n".join(snippets))
|
|
|
|
|
|
def print_chat_hist(chat: list[dict[str, str]],
                    dump: bool = False,
                    source_code: bool = False) -> None:
    """
    Print a chat history (a list of {'role', 'content'} dicts) to stdout.

    With ``dump`` the raw list is pretty-printed and nothing else happens.
    With ``source_code`` only the fenced source code of each message is
    printed (via display_source_code). Otherwise each message is printed
    as "ROLE: content", with a dashed rule before every 'user' message;
    content that would not fit on the line after the "ROLE: " prefix is
    printed on its own line(s) below "ROLE:".
    """
    if dump:
        pp(chat)
        return
    if source_code:
        for message in chat:
            display_source_code(message['content'])
        return
    # The terminal width is loop-invariant; query it once per call.
    width = terminal_width()
    for message in chat:
        role = message['role']
        content = message['content']
        if role == 'user':
            print('-' * width)
        # Leave room for the "ROLE: " prefix (role plus ": ").
        if len(content) > width - len(role) - 2:
            print(f"{role.upper()}:")
            print(content)
        else:
            print(f"{role.upper()}: {content}")
|
|
|
|
|
|
def print_tags_frequency(tags: list[str]) -> None:
    """
    Print each distinct tag with its number of occurrences, one
    "- tag: count" line per tag, sorted by tag.
    """
    from collections import Counter

    # Count in a single pass; calling tags.count() per distinct tag was O(n*k).
    for tag, count in sorted(Counter(tags).items()):
        print(f"- {tag}: {count}")
|