121 lines
4.3 KiB
Python
121 lines
4.3 KiB
Python
import yaml
|
|
import io
|
|
import pathlib
|
|
from .utils import terminal_width, append_message, message_to_chat
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
    """Parse a question/answer text file.

    The first line is expected to look like ``TAGS: tag1 tag2`` (tags may
    also be comma separated, the old format). The rest of the file holds a
    ``=== QUESTION ===`` section followed by a ``==== ANSWER ====`` section.

    Args:
        fname: Path of the file to read.
        tags_only: When true, only the first line is parsed and only the
            tags are returned.

    Returns:
        ``{"tags": [...]}`` when *tags_only* is set, otherwise a dict with
        the keys ``"question"``, ``"answer"``, ``"tags"`` and ``"file"``
        (the base name of *fname*).

    Raises:
        IndexError: If the first line contains no ``:``.
        ValueError: If one of the two section marker lines is missing.
    """
    with open(fname, "r") as fd:
        # Split on the first ':' only, so tags that themselves contain a
        # colon (e.g. "lang:python") are not cut off.
        tagline = fd.readline().strip().split(':', 1)[1].strip()
        # also support tags separated by ',' (old format)
        separator = ',' if ',' in tagline else ' '
        # Drop empty entries: an empty tag line must yield [] rather than
        # [''], and repeated separators must not create blank tags.
        tags = [t.strip() for t in tagline.split(separator) if t.strip()]
        if tags_only:
            return {"tags": tags}
        text = fd.read().strip().split('\n')
    question_idx = text.index("=== QUESTION ===") + 1
    answer_idx = text.index("==== ANSWER ====")
    question = "\n".join(text[question_idx:answer_idx]).strip()
    answer = "\n".join(text[answer_idx + 1:]).strip()
    return {"question": question, "answer": answer, "tags": tags,
            "file": fname.name}
|
|
|
|
|
|
def dump_data(data: Dict[str, Any]) -> str:
    """Render a record in the on-disk text format and return it as a string.

    Args:
        data: Mapping providing ``"tags"`` (an iterable of strings),
            ``"question"`` and ``"answer"``.

    Returns:
        The tag line followed by the question and the answer sections, each
        terminated by a newline.
    """
    tag_line = " ".join(data["tags"])
    sections = (
        f'TAGS: {tag_line}\n',
        f'=== QUESTION ===\n{data["question"]}\n',
        f'==== ANSWER ====\n{data["answer"]}\n',
    )
    return "".join(sections)
|
|
|
|
|
|
def write_file(fname: str, data: Dict[str, Any]) -> None:
    """Write a record to *fname* in the tag/question/answer text format.

    Args:
        fname: Destination path; the file is opened in ``"w"`` mode, so any
            existing content is replaced.
        data: Mapping providing ``"tags"`` (an iterable of strings),
            ``"question"`` and ``"answer"``.
    """
    def sections():
        # Generated lazily so each part is only built right before it is
        # written, after the file has been opened.
        yield f'TAGS: {" ".join(data["tags"])}\n'
        yield f'=== QUESTION ===\n{data["question"]}\n'
        yield f'==== ANSWER ====\n{data["answer"]}\n'

    with open(fname, "w") as out:
        for section in sections():
            out.write(section)
|
|
|
|
|
|
def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: Dict[str, Any]
                 ) -> None:
    """Print each answer and store it as a numbered ``.txt`` file in the DB.

    File numbers continue from the counter kept in ``<db>/.next``; the
    counter is rewritten with the last number used once all answers have
    been saved.

    Args:
        question: The question text stored alongside every answer.
        answers: Answers to print and save, one file per answer.
        tags: Tags stored with each answer when *otags* is empty or None.
        otags: Tags that take precedence over *tags* when non-empty.
        config: Configuration mapping; ``config['db']`` is the database
            directory.
    """
    wtags = otags or tags
    db_dir = pathlib.Path(config['db'])
    next_fname = db_dir / '.next'
    num = 0
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except (OSError, ValueError):
        # Best effort: a missing, unreadable or malformed counter file
        # simply restarts the numbering at 0.
        pass
    for inum, answer in enumerate(answers, start=1):
        num += 1
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # Store the file inside the database directory, next to the '.next'
        # counter, instead of in the current working directory.
        write_file(str(db_dir / f"{num:04d}.txt"),
                   {"question": question, "answer": answer, "tags": wtags})
    with open(next_fname, 'w') as f:
        f.write(f'{num}')
|
|
|
|
|
|
def create_chat(question: Optional[str],
                tags: Optional[List[str]],
                extags: Optional[List[str]],
                config: Dict[str, Any],
                match_all_tags: bool = False,
                with_tags: bool = False,
                with_file: bool = False
                ) -> List[Dict[str, str]]:
    """Build a chat history from the database records selected by tag.

    A ``'system'`` message built from ``config['system']`` is appended
    first. Every ``.yaml`` and ``.txt`` file in ``config['db']`` is then
    visited in sorted order, and each record that passes the tag filters is
    handed to ``message_to_chat``. If *question* is non-empty it is appended
    last as a ``'user'`` message.

    Args:
        question: Optional question appended at the end of the chat.
        tags: Tags to select; empty or None selects every record.
        extags: Tags to exclude; a record sharing any of them is skipped.
        config: Configuration mapping with the keys ``'system'`` and ``'db'``.
        match_all_tags: When true a record must carry all of *tags*,
            otherwise a single common tag is enough.
        with_tags: Passed through to ``message_to_chat``.
        with_file: Passed through to ``message_to_chat``.

    Returns:
        The list of chat messages that was built.
    """
    chat: List[Dict[str, str]] = []
    append_message(chat, 'system', config['system'].strip())
    db_dir = pathlib.Path(config['db'])
    for entry in sorted(db_dir.iterdir()):
        suffix = entry.suffix
        if suffix not in ('.yaml', '.txt'):
            continue
        if suffix == '.txt':
            data = read_file(entry)
        else:
            with open(entry, 'r') as stream:
                # NOTE(review): FullLoader is used on the DB files; confirm
                # they are trusted input.
                data = yaml.load(stream, Loader=yaml.FullLoader)
            data['file'] = entry.name
        entry_tags = set(data.get('tags', []))
        # Selection filter: only applied when tags were requested.
        if tags:
            if match_all_tags:
                selected = set(tags).issubset(entry_tags)
            else:
                selected = bool(entry_tags.intersection(tags))
            if not selected:
                continue
        # Exclusion filter: any shared tag rules the record out.
        if extags and entry_tags.intersection(extags):
            continue
        message_to_chat(data, chat, with_tags, with_file)
    if question:
        append_message(chat, 'user', question)
    return chat
|
|
|
|
|
|
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
    """Collect the tags of every ``.yaml`` and ``.txt`` file in the database.

    Files are visited in sorted order and tags are returned in the order
    they are found; duplicates are kept.

    Args:
        config: Configuration mapping; ``config['db']`` is the database
            directory.
        prefix: When non-empty, only tags starting with this prefix are
            returned.

    Returns:
        The list of collected tags.
    """
    collected = []
    for entry in sorted(pathlib.Path(config['db']).iterdir()):
        if entry.suffix == '.txt':
            data = read_file(entry, tags_only=True)
        elif entry.suffix == '.yaml':
            with open(entry, 'r') as stream:
                data = yaml.load(stream, Loader=yaml.FullLoader)
        else:
            # Any other file type is not part of the database.
            continue
        file_tags = data.get('tags', [])
        if prefix:
            collected.extend(t for t in file_tags if t.startswith(prefix))
        else:
            collected.extend(file_tags)
    return collected
|
|
|
|
|
|
def get_tags_unique(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
    """Return every distinct tag found by ``get_tags``, each exactly once.

    Args:
        config: Configuration mapping, passed through to ``get_tags``.
        prefix: Optional tag prefix filter, passed through to ``get_tags``.

    Returns:
        The distinct tags in the order of their first occurrence.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is the same on every run; the order of list(set(...)) over
    # strings can change between processes because of hash randomisation.
    return list(dict.fromkeys(get_tags(config, prefix)))
|