"""Read, write and query question/answer entries stored as files in a db directory."""

import io
import pathlib
from typing import List, Dict, Any, Optional

import yaml

from .utils import terminal_width, append_message, message_to_chat, ConfigType

# Section markers of the plain-text entry format (shared by reader and writer).
_QUESTION_HEADER = "=== QUESTION ==="
_ANSWER_HEADER = "==== ANSWER ===="


def _parse_tags(tagline: str) -> List[str]:
    """Split the value part of a ``TAGS:`` line into a list of tags.

    Tags are separated by ',' if the line contains a comma (old format),
    otherwise by a single space. Empty items (from an empty tag line,
    repeated separators or a trailing separator) are dropped.
    """
    separator = ',' if ',' in tagline else ' '
    stripped = (t.strip() for t in tagline.split(separator))
    return [t for t in stripped if t]


def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]:
    """Parse a plain-text entry file.

    The first line must have the form ``<label>: <tags>``; the remainder
    holds the question and the answer, each introduced by its marker line.

    Args:
        fname: Path of the file to read.
        tags_only: If true, only the first line is parsed.

    Returns:
        ``{"tags": [...]}`` when ``tags_only`` is set, otherwise a dict with
        the keys ``"question"``, ``"answer"``, ``"tags"`` and ``"file"``
        (the file's base name).

    Raises:
        IndexError: If the first line contains no ':'.
        ValueError: If one of the section marker lines is missing.
    """
    with open(fname, "r") as fd:
        tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip()
        tags = _parse_tags(tagline)
        if tags_only:
            return {"tags": tags}
        text = fd.read().strip().split('\n')
    question_idx = text.index(_QUESTION_HEADER) + 1
    answer_idx = text.index(_ANSWER_HEADER)
    question = "\n".join(text[question_idx:answer_idx]).strip()
    answer = "\n".join(text[answer_idx + 1:]).strip()
    return {"question": question, "answer": answer, "tags": tags,
            "file": fname.name}


def dump_data(data: Dict[str, Any]) -> str:
    """Render an entry in the plain-text format understood by ``read_file``.

    Args:
        data: Mapping with the keys ``"tags"`` (list of strings),
            ``"question"`` and ``"answer"``.

    Returns:
        The serialized entry, terminated by a newline.
    """
    return (
        f'TAGS: {" ".join(data["tags"])}\n'
        f'{_QUESTION_HEADER}\n{data["question"]}\n'
        f'{_ANSWER_HEADER}\n{data["answer"]}\n'
    )


def write_file(fname: str, data: Dict[str, Any]) -> None:
    """Write an entry to ``fname`` in the format produced by ``dump_data``.

    Args:
        fname: Path of the file to create or overwrite.
        data: Mapping with the keys ``"tags"``, ``"question"`` and ``"answer"``.
    """
    with open(fname, "w") as fd:
        fd.write(dump_data(data))


def save_answers(question: str,
                 answers: list[str],
                 tags: list[str],
                 otags: Optional[list[str]],
                 config: ConfigType
                 ) -> None:
    """Print every answer and store each one in its own numbered file.

    The running file number is read from and written back to the ``.next``
    file inside ``config['db']``; numbering starts at 1 when that file is
    missing or unreadable.

    Args:
        question: The question all answers belong to.
        answers: The answers to print and store.
        tags: Tags stored with each answer unless ``otags`` is non-empty.
        otags: Tags that take precedence over ``tags`` when non-empty.
        config: Configuration; only ``config['db']`` is read here.
    """
    wtags = otags or tags
    next_fname = pathlib.Path(str(config['db'])) / '.next'
    num = 0
    try:
        with open(next_fname, 'r') as f:
            num = int(f.read())
    except (OSError, ValueError):
        # Best effort: no (or a garbled) counter file means we start over.
        pass
    for inum, answer in enumerate(answers, start=1):
        num += 1
        title = f'-- ANSWER {inum} '
        title_end = '-' * (terminal_width() - len(title))
        print(f'{title}{title_end}')
        print(answer)
        # NOTE(review): the file name is relative to the current working
        # directory, whereas the counter and the history are read from
        # config['db'] - confirm that this is intended.
        write_file(f"{num:04d}.txt",
                   {"question": question, "answer": answer, "tags": wtags})
    with open(next_fname, 'w') as f:
        f.write(f'{num}')


def _load_entry(file: pathlib.Path,
                tags_only: bool = False
                ) -> Optional[Dict[str, Any]]:
    """Load one db file, or return None if it should be skipped.

    ``.yaml`` files are parsed with PyYAML and get a ``"file"`` key holding
    the base name; ``.txt`` files are parsed by ``read_file``. Files with any
    other suffix, and YAML files whose top level is not a mapping (e.g. an
    empty file), yield None.
    """
    if file.suffix == '.yaml':
        with open(file, 'r') as f:
            # NOTE(review): yaml.safe_load is the recommended choice if db
            # files can come from an untrusted source.
            data = yaml.load(f, Loader=yaml.FullLoader)
        if not isinstance(data, dict):
            return None
        data['file'] = file.name
        return data
    if file.suffix == '.txt':
        return read_file(file, tags_only=tags_only)
    return None


def create_chat_hist(question: Optional[str],
                     tags: Optional[List[str]],
                     extags: Optional[List[str]],
                     config: ConfigType,
                     match_all_tags: bool = False,
                     with_tags: bool = False,
                     with_file: bool = False
                     ) -> List[Dict[str, str]]:
    """Build a chat history from the entries stored in ``config['db']``.

    The history starts with the system message ``config['system']``,
    continues with every selected entry (files are visited in sorted order)
    and ends with ``question`` as a user message if one is given.

    Args:
        question: Optional question appended last.
        tags: Select only entries carrying these tags; empty or None selects
            all entries.
        extags: Drop entries carrying any of these tags.
        config: Configuration; ``config['db']`` and ``config['system']`` are
            read.
        match_all_tags: If true an entry must carry all of ``tags``,
            otherwise a single common tag suffices.
        with_tags: Passed through to ``message_to_chat``.
        with_file: Passed through to ``message_to_chat``.

    Returns:
        The list of chat messages.
    """
    chat: List[Dict[str, str]] = []
    append_message(chat, 'system', str(config['system']).strip())
    wanted = set(tags) if tags else set()
    for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
        data = _load_entry(file)
        if data is None:
            continue
        data_tags = set(data.get('tags') or [])
        if not wanted:
            tags_match = True
        elif match_all_tags:
            tags_match = wanted.issubset(data_tags)
        else:
            tags_match = bool(data_tags.intersection(wanted))
        excluded = bool(extags) and bool(data_tags.intersection(extags))
        if tags_match and not excluded:
            message_to_chat(data, chat, with_tags, with_file)
    if question:
        append_message(chat, 'user', question)
    return chat


def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]:
    """Collect the tags of all entries in ``config['db']``.

    Args:
        config: Configuration; only ``config['db']`` is read here.
        prefix: If non-empty, only tags starting with it are returned.

    Returns:
        The tags in file order; a tag used by several entries appears
        several times.
    """
    result: List[str] = []
    for file in sorted(pathlib.Path(str(config['db'])).iterdir()):
        data = _load_entry(file, tags_only=True)
        if data is None:
            continue
        for tag in data.get('tags') or []:
            if not prefix or tag.startswith(prefix):
                result.append(tag)
    return result


def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]:
    """Return the tags found by ``get_tags`` without duplicates (unordered)."""
    return list(set(get_tags(config, prefix)))