106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
import yaml
|
|
import io
|
|
import pathlib
|
|
from .utils import terminal_width, append_message, message_to_chat
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
def _parse_tags(line: str) -> List[str]:
    """Parse a ``TAGS: a, b, c`` header line into a list of tag names.

    Everything after the first ``:`` is treated as the comma-separated tag
    list, so tags that themselves contain ``:`` are kept intact. Whitespace
    around each tag is stripped and empty entries (``TAGS:`` with nothing
    after it, or a trailing comma) are dropped, so an entry written with an
    empty tag list reads back as ``[]``. A line without ``:`` yields ``[]``.
    """
    _, _, tag_text = line.partition(':')
    return [tag.strip() for tag in tag_text.split(',') if tag.strip()]


def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]:
    """Read a question/answer entry stored in the plain-text format.

    Expected layout::

        TAGS: tag1, tag2
        === QUESTION ===
        <question text>
        ==== ANSWER ====
        <answer text>

    Args:
        fname: Path of the file to read; passed directly to ``open`` (a
            ``pathlib.Path`` works too).
        tags_only: If true, only the first (``TAGS:``) line is read and
            parsed; the rest of the file is not read.

    Returns:
        ``{"tags": [...]}`` when *tags_only* is set, otherwise
        ``{"question": str, "answer": str, "tags": [...]}`` with the question
        and answer text stripped of surrounding whitespace.

    Raises:
        ValueError: If the question marker, or an answer marker after it, is
            missing (full read only).
    """
    with open(fname, "r") as fd:
        if tags_only:
            # Only the header line is needed; avoid reading the whole entry.
            return {"tags": _parse_tags(fd.readline().strip())}
        lines = fd.read().strip().split('\n')

    tags = _parse_tags(lines.pop(0))
    # Marker lines must match exactly. The question is everything between the
    # two markers; the answer marker is searched for only after the question
    # marker, and the answer is everything following it.
    question_idx = lines.index("=== QUESTION ===") + 1
    answer_idx = lines.index("==== ANSWER ====", question_idx)
    question = "\n".join(lines[question_idx:answer_idx]).strip()
    answer = "\n".join(lines[answer_idx + 1:]).strip()
    return {"question": question, "answer": answer, "tags": tags}
|
|
|
|
|
|
def dump_data(data: Dict[str, Any]) -> str:
    """Render an entry as text in the on-disk question/answer format.

    The result has the same layout that ``write_file`` writes to disk::

        TAGS: <tags joined by ", ">
        === QUESTION ===
        <question>
        ==== ANSWER ====
        <answer>

    Every line, including the last, ends with a newline. A ``KeyError`` is
    raised if *data* lacks ``"tags"``, ``"question"`` or ``"answer"``.
    """
    joined_tags = ", ".join(data["tags"])
    parts = [
        f"TAGS: {joined_tags}",
        "=== QUESTION ===",
        f"{data['question']}",
        "==== ANSWER ====",
        f"{data['answer']}",
    ]
    return "\n".join(parts) + "\n"
|
|
|
|
|
|
def write_file(fname: str, data: Dict[str, Any]) -> None:
    """Write a question/answer entry to *fname* in the plain-text format.

    Layout (the same text ``dump_data`` returns)::

        TAGS: tag1, tag2
        === QUESTION ===
        <question>
        ==== ANSWER ====
        <answer>

    Args:
        fname: Destination path; an existing file is overwritten.
        data: Mapping with ``"tags"`` (iterable of str), ``"question"`` and
            ``"answer"`` entries.

    Raises:
        KeyError: If *data* lacks one of the required keys. In that case the
            file is not opened, so an existing entry is left intact.
    """
    # Render the whole entry before opening the file: open(..., "w")
    # truncates immediately, so a bad *data* used to leave an empty or
    # half-written file behind.
    text = (
        f'TAGS: {", ".join(data["tags"])}\n'
        f'=== QUESTION ===\n{data["question"]}\n'
        f'==== ANSWER ====\n{data["answer"]}\n'
    )
    with open(fname, "w") as fd:
        fd.write(text)
|
|
|
|
|
|
def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: Dict[str, Any]
                 ) -> None:
    """Print each answer and store it as a numbered entry in the database.

    Each answer is printed under a ``-- ANSWER n ---`` header sized to
    ``terminal_width()``, then saved with ``write_file`` as ``NNNN.txt``
    (zero-padded running counter) in ``config['db']``. The last used counter
    is persisted in ``<db>/.next`` so numbering continues across runs.

    Args:
        question: The question the answers belong to.
        answers: Answer texts; each one becomes its own entry file.
        tags: Tags stored with the entries when *otags* is empty or None.
        otags: Optional override tags, used instead of *tags* when truthy.
        config: Configuration; ``config['db']`` is the database directory.
    """
    db_dir = pathlib.Path(config['db'])
    next_fname = db_dir / '.next'
    wtags = otags or tags

    num = 0
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except (OSError, ValueError):
        # No counter yet (first run) or an unreadable/corrupt counter file:
        # start numbering from 0, without hiding unrelated programming errors.
        pass

    for inum, answer in enumerate(answers, start=1):
        num += 1
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # Save next to the counter file, in the database directory that
        # create_chat()/get_tags() scan, instead of the current directory.
        write_file(db_dir / f"{num:04d}.txt",
                   {"question": question, "answer": answer, "tags": wtags})

    with open(next_fname, 'w') as f:
        f.write(f'{num}')
|
|
|
|
|
|
def create_chat(question: Optional[str],
                tags: Optional[List[str]],
                extags: Optional[List[str]],
                config: Dict[str, Any]
                ) -> List[Dict[str, str]]:
    """Build a chat message list from the system prompt and stored entries.

    The chat starts with ``config['system']`` (stripped) as the system
    message. Then every ``.yaml`` or ``.txt`` file in ``config['db']`` is
    visited in sorted order, and entries passing the tag filters are added
    with ``message_to_chat``. Finally *question*, if given, is appended as a
    user message.

    Args:
        question: New user question to append last; skipped if None/empty.
        tags: If non-empty, only entries sharing at least one of these tags
            are included.
        extags: If non-empty, entries having any of these tags are excluded.
        config: Configuration; reads ``config['system']`` and ``config['db']``.

    Returns:
        The list of chat messages built via ``append_message`` /
        ``message_to_chat``.
    """
    chat: List[Dict[str, str]] = []
    append_message(chat, 'system', config['system'].strip())
    for file in sorted(pathlib.Path(config['db']).iterdir()):
        if file.suffix == '.yaml':
            with open(file, 'r') as f:
                # NOTE(review): FullLoader can construct some Python objects;
                # acceptable for a self-written local db, but consider
                # yaml.safe_load if these files may come from elsewhere.
                data = yaml.load(f, Loader=yaml.FullLoader)
        elif file.suffix == '.txt':
            data = read_file(file)
        else:
            continue
        if data is None:
            # An empty YAML document loads as None: nothing to add.
            continue
        # A bare `tags:` key loads as None, so fall back to an empty list.
        data_tags = set(data.get('tags') or [])
        tags_match = not tags or bool(data_tags.intersection(tags))
        extags_do_not_match = not extags or not data_tags.intersection(extags)
        if tags_match and extags_do_not_match:
            message_to_chat(data, chat)
    if question:
        append_message(chat, 'user', question)
    return chat
|
|
|
|
|
|
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
    """Collect the distinct tags used by entries in the database directory.

    Scans the ``.yaml`` and ``.txt`` files in ``config['db']`` (other files
    are ignored). For ``.txt`` files only the ``TAGS:`` header line is read.

    Args:
        config: Configuration; ``config['db']`` is the database directory.
        prefix: If non-empty, only tags starting with this string are kept.

    Returns:
        The unique matching tags in sorted order. The previous
        ``list(set(...))`` order varied between runs because of string hash
        randomization.
    """
    found = set()
    for file in sorted(pathlib.Path(config['db']).iterdir()):
        if file.suffix == '.yaml':
            with open(file, 'r') as f:
                data = yaml.load(f, Loader=yaml.FullLoader)
        elif file.suffix == '.txt':
            data = read_file(file, tags_only=True)
        else:
            continue
        if data is None:
            # An empty YAML document loads as None: no tags to collect.
            continue
        # A bare `tags:` key loads as None, so fall back to an empty list.
        for tag in data.get('tags') or []:
            if not prefix or tag.startswith(prefix):
                found.add(tag)
    # key=str keeps sorting safe if a YAML file lists non-string tags.
    return sorted(found, key=str)
|