diff --git a/.github/workflows/pypi-publish.yml b/.github/workflows/pypi-publish.yml new file mode 100644 index 0000000..cb82597 --- /dev/null +++ b/.github/workflows/pypi-publish.yml @@ -0,0 +1,36 @@ +name: publish + +on: + release: + types: [published] + +permissions: + contents: read + id-token: write + +jobs: + pypi: + name: build and publish to PyPI + runs-on: ubuntu-latest + environment: pypi + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build tooling + run: python -m pip install --upgrade build twine + + - name: Build distribution + run: python -m build + + - name: Check distribution + run: python -m twine check dist/* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 7082077..9e5bf64 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ __pycache__/ .pytest_cache/ .mypy_cache/ .ruff_cache/ +.coverage +coverage.xml +htmlcov/ .venv/ dist/ build/ @@ -11,3 +14,34 @@ workspace/ output/ .codex/ .agents/ + +# Local secrets/configuration +.env +.env.* +!.env.example + +# Runtime data and generated indexes +*.db +*.db-* +*.sqlite +*.sqlite3 +*.sqlite3-* + +# Logs and temporary run output +*.log +*.out +*.tmp +*.bak + +# Large local artifacts +*.tar +*.tar.gz +*.tgz +*.zip +*.7z + +# Local staging and evaluation artifacts +staging/ + +# Local generated previews +docs/*preview*.html diff --git a/README.md b/README.md index 94fafab..768892a 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,41 @@ # Little Heta -Little Heta is a lightweight command line tool for personal knowledge, memory, -and document intelligence workflows. It converts local documents into a -Markdown wiki, keeps wiki page identity stable, and can maintain a SQLite -vector index for faster semantic retrieval. +

+ Little Heta banner +

+ +

+ English · + 简体中文 · + 繁體中文 · + 日本語 · + 한국어 · + Español · + Português · + Français · + Deutsch +

+ +

+ PyPI v0.1.0 + Python 3.10+ + License: MIT + KnowledgeXLab +

+ +Little Heta is a local CLI knowledge infrastructure for personal documents, +agent memory, and document intelligence. It turns PDFs, Office files, images, +audio, code, HTML, Markdown, and notes into a stable Markdown wiki, adds +semantic vector retrieval, and lets agents reuse distilled knowledge through a +memory layer. -## Status - -This repository is an early `v0.1.0` implementation. The current focus is a -fast local workflow for initialization, document insertion, wiki maintenance, -and optional vector indexing. - -## Features +## Install -- Interactive first-time setup with `heta init` -- Provider configuration for Qwen, ChatGPT, or Gemini -- Optional MinerU integration for PDF parsing -- Markdown wiki generation under the Little Heta workspace -- Stable numeric wiki page ids in page filenames -- Optional SQLite + sqlite-vec wiki chunk index -- CLI status view with provider, MinerU, KB, wiki, and space usage summaries +Install from PyPI: -## Install +```bash +pip install little-heta +``` From a local checkout: @@ -29,78 +43,198 @@ From a local checkout: pip install -e . ``` -For development dependencies: +For development: ```bash pip install -e ".[dev]" ``` -## Quick Start - -Initialize Little Heta: +The package installs the `heta` command: ```bash -heta init +heta --help ``` -The wizard writes configuration to: - -```text -~/.heta/heta.yaml -``` +## Initialize -Check the current workspace and provider status: +Run the first-time setup: ```bash -heta status +heta init ``` -Insert one file or a directory: +You need to prepare: -```bash -heta insert ./docs +- An LLM API key for one provider: Qwen, ChatGPT, or Gemini. +- Optional MinerU access for PDF and Office parsing. Apply or learn more at + [MinerU](https://mineru.net/apiManage/docs). + +`heta init` writes config and workspace data under: + +```text +~/.heta/ ``` -Large PDFs are profiled and split before parsing by default. Little Heta gives a -lightweight PDF profile to a planning agent, validates the returned page ranges, -and falls back to fixed page windows when planning is unavailable. Disable this -behavior when you want to parse a PDF as one source file: +It also installs the Little Heta agent skill automatically into: -```bash -heta insert --no-pdf-planning ./large.pdf +```text +~/.codex/skills/heta +~/.claude/skills/heta ``` -Ask a read-only question against the wiki: +## Use with Codex and Claude Code + +After `heta init`, Codex and Claude Code can discover the Little Heta skill +globally. The skill tells the agent when to use: ```bash -heta query "What is HetaGen?" +heta ask "..." +heta query "..." +heta recall "..." +heta remember "..." ``` -Clean wiki pages and the vector database while keeping raw files: +You can refresh or reinstall the skill at any time: ```bash -heta clean +heta skill ``` -Manage vector indexing: +For other agent frameworks, copy these two files: -```bash -heta vector status -heta vector on -heta vector off +```text +~/.heta/skills/heta/SKILL.md +~/.heta/skills/heta/COMMANDS.md ``` +## What You Get + +Most personal knowledge bases eventually become a `/raw` folder: papers, +slides, screenshots, audio clips, code files, notes, and half-finished drafts +all pile up together. A normal agent can read those files directly, but every +question pays the same cost again: open the index, guess which page matters, +read long pages, and spend tokens rediscovering context it already found before. + +Little Heta separates the external knowledge base from the agent's internal +memory. The KB remains the source of truth: a structured, versioned wiki built +from the user's files. Memory, by contrast, is the agent's persistent working +layer, storing reusable information that helps the agent reason, route, and +avoid repeated deep retrieval. This creates a memory-first, KB-grounded +retrieval loop. + +Little Heta turns that pile into a persistent agent workspace: + +- **Wiki foundation**: raw files are compiled into stable Markdown pages with + numeric page ids, clean `[[Wiki Links]]`, and Git history. +- **Vector Wiki**: each page is chunked by Markdown structure, so `heta query` + can jump to the right section instead of relying only on sparse `index.md` + summaries. +- **Memory-first retrieval**: `heta ask` stores distilled KB insights after + expensive lookups, allowing later questions to reuse prior KB understanding + instead of repeating the same deep wiki traversal. +- **Synchronized memory + KB management**: memory stays tied to the evolving + wiki. When KB content changes, related memories can be invalidated to prevent + stale cached insights from drifting away from the source of truth. +- **Agent reuse**: larger teams and multi-agent workflows benefit because useful + KB discoveries can be reused across later questions, sessions, and agents. + +Heta's memory architecture stores four complementary types of information: + +- **Raw dialogue memory**: original user-agent interaction history, preserving + full context and wording. +- **Atomic fact memory**: compact factual statements extracted from + interactions, useful for precise attribute or preference recall. +- **Episodic memory**: event-level summaries that capture tasks, decisions, + temporal context, and multi-step work sessions. +- **KB insight memory**: distilled insights produced after KB retrieval, + storing what the agent learned from external documents so future questions + can reuse that understanding without repeating the same expensive traversal. + +Retrieval quality depends heavily on corpus structure. In corpora where +important details are buried deep inside long wiki pages and poorly represented +by summaries, index-only wiki navigation can suffer severe retrieval collapse. +In our initial stress scenarios, Vector Wiki and memory-backed retrieval +improved answer accuracy by roughly **1.25x-5x+**, with some cases recovering +from **0% to 100%** accuracy. + +Memory-backed reuse used **82.1% fewer tokens** than index-only wiki query and +answered **2.58x faster** in a multi-page comparison setting. This gap is expected to +grow in larger or messier workspaces, because index-only wiki navigation scales +with the number and length of pages an agent may need to inspect, while +memory-backed reuse resolves repeated questions from previously distilled +insights. The main extra cost is the first pass that creates the reusable +insight. + +## Core CLI + +The main commands are: + +- `heta init`: set up providers, workspace, and agent skills. +- `heta status`: show provider, MinerU, wiki, memory, and space status. +- `heta insert`: add files or folders to the knowledge base. +- `heta query`: ask a read-only question against inserted documents. +- `heta ask`: answer using memory and the document KB together. +- `heta remember`: save a fact, decision, or preference. +- `heta recall`: retrieve saved memory. +- `heta clean`: remove generated wiki pages and vector DB while keeping raw files. +- `heta vector`: turn document vector indexing on, off, or show status. +- `heta insert-planning`: turn smart insert planning on, off, or show status. +- `heta mem-show`: inspect stored KB memories. +- `heta mem-clean`: erase memory data. +- `heta skill`: install or refresh agent skills. + +Detailed command docs: + +- [init](docs/cli/init.md) +- [status](docs/cli/status.md) +- [insert](docs/cli/insert.md) +- [query](docs/cli/query.md) +- [ask](docs/cli/ask.md) +- [remember](docs/cli/remember.md) +- [recall](docs/cli/recall.md) +- [clean](docs/cli/clean.md) +- [vector](docs/cli/vector.md) +- [insert-planning](docs/cli/insert-planning.md) +- [mem-show](docs/cli/mem-show.md) +- [mem-clean](docs/cli/mem-clean.md) +- [skill](docs/cli/skill.md) + +## Supported Files + +Little Heta can insert: + +- Markdown and text: `.md`, `.markdown`, `.txt` +- PDF and Office: `.pdf`, `.doc`, `.docx`, `.ppt`, `.pptx`, `.xls`, `.xlsx` +- Images: `.png`, `.jpg`, `.jpeg`, `.webp`, `.gif`, `.bmp` +- Audio and video transcripts: `.mp3`, `.wav`, `.m4a`, `.flac`, `.ogg`, `.mp4` +- Code and config files: `.py`, `.js`, `.ts`, `.tsx`, `.jsx`, `.java`, `.go`, + `.rs`, `.cpp`, `.c`, `.h`, `.hpp`, `.sh`, `.sql`, `.yaml`, `.yml`, `.json`, + `.toml` +- HTML: `.html`, `.htm` + +PDF and Office parsing require MinerU. Images and audio/video require a +multimodal or transcription-capable LLM provider. + ## Workspace -Little Heta stores local runtime data under: +Runtime data lives under: ```text ~/.heta/ ``` -The workspace contains raw source files, generated wiki pages, worktrees, and -the local database used by the vector index. Runtime workspace data is not -intended to be committed to this repository. +Important paths: + +```text +~/.heta/heta.yaml config +~/.heta/workspace/kb/raw archived source files +~/.heta/workspace/kb/wiki/index.md wiki entry index +~/.heta/workspace/kb/wiki/pages/ generated Markdown wiki pages +~/.heta/workspace/kb/wiki/log.md wiki operation log +~/.heta/workspace/kb/db/wiki_vectors.sqlite3 local wiki vector database +~/.heta/workspace/mem/mem.sqlite3 local memory database +~/.heta/skills/heta/ portable Little Heta agent skill +``` ## Development @@ -113,11 +247,18 @@ pytest Project layout: ```text -src/heta/ CLI, config, providers, and KB implementation +src/heta/ CLI, config, assistants, memory, and KB implementation +docs/ user and technical documentation tests/ unit tests pyproject.toml package metadata and dependencies ``` +## Community + +If Little Heta is useful to you, please consider giving the project a star. If +you run into bugs, rough edges, or missing workflows, open an issue and tell us +what happened. + ## License Little Heta is released under the MIT License. See [LICENSE](LICENSE). diff --git a/docs/assets/little-heta-banner.png b/docs/assets/little-heta-banner.png new file mode 100644 index 0000000..9e36ab5 Binary files /dev/null and b/docs/assets/little-heta-banner.png differ diff --git a/docs/cli/ask.md b/docs/cli/ask.md new file mode 100644 index 0000000..88d4622 --- /dev/null +++ b/docs/cli/ask.md @@ -0,0 +1,22 @@ +# heta ask + +Ask using memory and inserted documents together. + +```bash +heta ask "How does our auth flow refresh tokens?" +``` + +This is the default command for agent workflows. It can: + +- Search saved memory first. +- Query the document wiki when memory is not enough. +- Store distilled KB insights for later reuse. + +Options: + +```bash +heta ask "..." --top-k 5 --debug +``` + +`--debug` shows agent steps, memory evidence, and KB output. + diff --git a/docs/cli/clean.md b/docs/cli/clean.md new file mode 100644 index 0000000..0b051a5 --- /dev/null +++ b/docs/cli/clean.md @@ -0,0 +1,18 @@ +# heta clean + +Clean generated wiki knowledge while keeping original raw files. + +```bash +heta clean +heta clean --yes +``` + +What it does: + +- Clears generated wiki pages. +- Resets `wiki/index.md`. +- Appends a clean operation to `wiki/log.md`. +- Deletes the local wiki vector database. +- Keeps `~/.heta/workspace/kb/raw`. +- Commits the clean operation to the wiki Git repo. + diff --git a/docs/cli/dynamic-insert.md b/docs/cli/dynamic-insert.md new file mode 100644 index 0000000..44dfd60 --- /dev/null +++ b/docs/cli/dynamic-insert.md @@ -0,0 +1,18 @@ +# heta dynamic-insert + +Control whether `heta insert` uses dynamic LLM wiki merging. + +```bash +heta dynamic-insert status +heta dynamic-insert on +heta dynamic-insert off +``` + +Default after `heta init` is off. With dynamic insert off, `heta insert` +uses static insertion: the LLM writes only the page summary, while Little Heta +writes the page structure, source section, index, log, Git commit, and vector +index updates. + +Turn dynamic insert on to use the older tool-calling merge agent that can read +existing pages, update related pages, and merge a new document into an existing +topic page. diff --git a/docs/cli/init.md b/docs/cli/init.md new file mode 100644 index 0000000..dad6e76 --- /dev/null +++ b/docs/cli/init.md @@ -0,0 +1,21 @@ +# heta init + +Set up Little Heta for the first time. + +```bash +heta init +``` + +What it does: + +- Creates `~/.heta/heta.yaml`. +- Configures one LLM provider: Qwen, ChatGPT, or Gemini. +- Optionally configures MinerU for PDF and Office parsing. +- Enables vector indexing and insert planning by default. +- Installs the Little Heta skill into Codex and Claude Code. + +Prepare before running: + +- Your LLM provider API key. +- Optional MinerU API key from https://mineru.net/apiManage/docs. + diff --git a/docs/cli/insert-planning.md b/docs/cli/insert-planning.md new file mode 100644 index 0000000..24a08c0 --- /dev/null +++ b/docs/cli/insert-planning.md @@ -0,0 +1,14 @@ +# heta insert-planning + +Manage smart insert planning. + +```bash +heta insert-planning status +heta insert-planning on +heta insert-planning off +``` + +When enabled, large PDFs are profiled before parsing. Little Heta samples PDF +metadata, outline, page count, and page text, asks a planning agent for split +ranges, validates the plan, and then parses smaller parts. + diff --git a/docs/cli/insert.md b/docs/cli/insert.md new file mode 100644 index 0000000..8a8deee --- /dev/null +++ b/docs/cli/insert.md @@ -0,0 +1,26 @@ +# heta insert + +Insert files or directories into the Little Heta knowledge base. + +```bash +heta insert ./docs +heta insert report.pdf notes.md +``` + +What it does: + +- Copies original files into `~/.heta/workspace/kb/raw`. +- Parses supported formats into Markdown. +- Runs the wiki merge agent. +- Updates wiki pages under `~/.heta/workspace/kb/wiki`. +- Updates the vector index when vector indexing is enabled. +- Commits wiki changes with Git. + +Large PDFs are planned and split by default before parsing. Control that with: + +```bash +heta insert-planning status +heta insert-planning on +heta insert-planning off +``` + diff --git a/docs/cli/mem-clean.md b/docs/cli/mem-clean.md new file mode 100644 index 0000000..9b7aeff --- /dev/null +++ b/docs/cli/mem-clean.md @@ -0,0 +1,12 @@ +# heta mem-clean + +Erase all Little Heta memory data. + +```bash +heta mem-clean +heta mem-clean --yes +``` + +This clears personal memory and KB insight memory while preserving the database +schema. It does not remove wiki pages or raw source files. + diff --git a/docs/cli/mem-show.md b/docs/cli/mem-show.md new file mode 100644 index 0000000..b537396 --- /dev/null +++ b/docs/cli/mem-show.md @@ -0,0 +1,19 @@ +# heta mem-show + +Inspect saved memory. + +```bash +heta mem-show insights +``` + +Options: + +```bash +heta mem-show insights --source pages/1-heta.md +heta mem-show insights --question "rate limits" +heta mem-show insights --limit 20 +heta mem-show insights --full +``` + +This is mainly for inspecting KB insight memories created by `heta ask`. + diff --git a/docs/cli/query.md b/docs/cli/query.md new file mode 100644 index 0000000..2cd1e7d --- /dev/null +++ b/docs/cli/query.md @@ -0,0 +1,19 @@ +# heta query + +Ask a read-only question against inserted documents. + +```bash +heta query "What does the design doc say about rate limits?" +``` + +Use this when the answer should come from the wiki knowledge base, not personal +memory. + +Options: + +```bash +heta query "..." --top-k 5 +``` + +`--top-k` controls how many vector matches are offered to the query agent. + diff --git a/docs/cli/recall.md b/docs/cli/recall.md new file mode 100644 index 0000000..61df03b --- /dev/null +++ b/docs/cli/recall.md @@ -0,0 +1,19 @@ +# heta recall + +Retrieve saved memory relevant to a query. + +```bash +heta recall "what did we decide about the database?" +``` + +Use this when you want personal memory only, without searching the document +wiki. + +Options: + +```bash +heta recall "..." --top-k 10 --debug +``` + +`--debug` shows ranking and evidence details. + diff --git a/docs/cli/remember.md b/docs/cli/remember.md new file mode 100644 index 0000000..5a64846 --- /dev/null +++ b/docs/cli/remember.md @@ -0,0 +1,11 @@ +# heta remember + +Save a fact, decision, preference, or useful context into Little Heta memory. + +```bash +heta remember "We decided to use Postgres for the main store." +``` + +Use it when the user states something that should be available in later +sessions or for other agents. + diff --git a/docs/cli/skill.md b/docs/cli/skill.md new file mode 100644 index 0000000..f415ee2 --- /dev/null +++ b/docs/cli/skill.md @@ -0,0 +1,24 @@ +# heta skill + +Install or refresh the Little Heta agent skill. + +```bash +heta skill +``` + +It writes the skill to: + +```text +~/.heta/skills/heta +~/.codex/skills/heta +~/.claude/skills/heta +``` + +Codex and Claude Code can use the global skill automatically. For other agent +frameworks, copy: + +```text +~/.heta/skills/heta/SKILL.md +~/.heta/skills/heta/COMMANDS.md +``` + diff --git a/docs/cli/status.md b/docs/cli/status.md new file mode 100644 index 0000000..d1e9d7c --- /dev/null +++ b/docs/cli/status.md @@ -0,0 +1,18 @@ +# heta status + +Show the current Little Heta setup and workspace usage. + +```bash +heta status +``` + +It displays: + +- Config path. +- LLM provider. +- MinerU status. +- Vector index status. +- Insert planning status. +- Raw file and wiki page counts. +- `~/.heta` space usage. + diff --git a/docs/cli/vector.md b/docs/cli/vector.md new file mode 100644 index 0000000..fea67f5 --- /dev/null +++ b/docs/cli/vector.md @@ -0,0 +1,14 @@ +# heta vector + +Manage document vector indexing. + +```bash +heta vector status +heta vector on +heta vector off +``` + +When enabled, Little Heta syncs wiki chunks into the local sqlite-vec database +after insert. `heta query` can then retrieve relevant page sections before the +agent reads the wiki. + diff --git a/docs/i18n/README.de.md b/docs/i18n/README.de.md new file mode 100644 index 0000000..a56edff --- /dev/null +++ b/docs/i18n/README.de.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta ist eine lokale CLI-Wissensinfrastruktur für persönliche Dokumente, Agenten-Gedächtnis und Dokumentenintelligenz. Es verwandelt PDFs, Office-Dateien, Bilder, Audio, Code, HTML, Markdown und Notizen in ein stabiles Markdown-Wiki mit Vektorsuche und wiederverwendbarem Gedächtnis. + +## Installation + +Von PyPI installieren: + +```bash +pip install little-heta +``` + +Aus einem lokalen Checkout: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` benötigt einen LLM-API-Schlüssel. Für PDF- und Office-Parsing kann optional MinerU verwendet werden: https://mineru.net/apiManage/docs. + +## Core Concepts + +- **Wiki foundation**: Das Wiki dient als grundlegende Wissensschicht; Rohdateien werden zu stabilen Markdown-Seiten mit numerischen IDs. +- **Vector Wiki**: Seiten werden anhand der Markdown-Struktur in Abschnitte geteilt. +- **Memory reuse**: `heta ask` speichert nützliche Erkenntnisse aus der Wissensbasis zur späteren Wiederverwendung. +- **Agent skills**: `heta init` installiert den Little-Heta-Skill für Codex und Claude Code. + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.es.md b/docs/i18n/README.es.md new file mode 100644 index 0000000..a1ff57b --- /dev/null +++ b/docs/i18n/README.es.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta es una infraestructura local de conocimiento por CLI para documentos personales, memoria de agentes e inteligencia documental. Convierte PDFs, archivos de Office, imágenes, audio, código, HTML, Markdown y notas en una wiki Markdown estable, con búsqueda vectorial y memoria reutilizable. + +## Installation + +Instalar desde PyPI: + +```bash +pip install little-heta +``` + +Desde una copia local del repositorio: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` requiere una clave API de un LLM. El análisis de PDF y Office puede usar MinerU de forma opcional: https://mineru.net/apiManage/docs. + +## Core Concepts + +- **Wiki foundation**: la wiki funciona como capa base de conocimiento; los archivos originales se convierten en páginas Markdown estables con identificadores numéricos. +- **Vector Wiki**: las páginas se dividen según la estructura Markdown para recuperar secciones concretas. +- **Memory reuse**: `heta ask` guarda conocimientos útiles de la base de conocimiento para reutilizarlos después. +- **Agent skills**: `heta init` instala la skill de Little Heta para Codex y Claude Code. + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.fr.md b/docs/i18n/README.fr.md new file mode 100644 index 0000000..41dd6ef --- /dev/null +++ b/docs/i18n/README.fr.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta est une infrastructure locale de connaissance en ligne de commande pour les documents personnels, la mémoire d'agents et l'intelligence documentaire. Il transforme les PDF, fichiers Office, images, fichiers audio, code, HTML, Markdown et notes en wiki Markdown stable, avec recherche vectorielle et mémoire réutilisable. + +## Installation + +Installer depuis PyPI : + +```bash +pip install little-heta +``` + +Depuis une copie locale du dépôt : + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` nécessite une clé API LLM. L'analyse des PDF et fichiers Office peut utiliser MinerU en option : https://mineru.net/apiManage/docs. + +## Core Concepts + +- **Wiki foundation** : le wiki sert de couche de base de connaissance ; les fichiers bruts deviennent des pages Markdown stables avec identifiants numériques. +- **Vector Wiki** : les pages sont découpées selon la structure Markdown pour retrouver les bonnes sections. +- **Memory reuse** : `heta ask` enregistre les connaissances utiles de la base pour les réutiliser plus tard. +- **Agent skills** : `heta init` installe la skill Little Heta pour Codex et Claude Code. + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.ja.md b/docs/i18n/README.ja.md new file mode 100644 index 0000000..29004d0 --- /dev/null +++ b/docs/i18n/README.ja.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta は、個人ドキュメント、Agent メモリ、ドキュメントインテリジェンスのためのローカル CLI 知識基盤です。PDF、Office、画像、音声、コード、HTML、Markdown、ノートを安定した Markdown Wiki に変換し、ベクトル検索と再利用可能なメモリ層を提供します。 + +## Installation + +PyPI からインストール: + +```bash +pip install little-heta +``` + +ローカルリポジトリから: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` には LLM API key が必要です。PDF と Office の解析には MinerU を任意で利用できます:https://mineru.net/apiManage/docs。 + +## Core Concepts + +- **Wiki foundation**:Wiki は知識の基盤層であり、元ファイルを安定した番号付き Markdown ページに変換します。 +- **Vector Wiki**:Markdown の階層に沿ってページを分割し、必要な章へ素早く到達します。 +- **Memory reuse**:`heta ask` は高コストな検索結果を知識として保存し、後続の質問で再利用できます。 +- **Agent skills**:`heta init` は Codex と Claude Code 用の Little Heta skill を自動でインストールします。 + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.ko.md b/docs/i18n/README.ko.md new file mode 100644 index 0000000..8c49eb9 --- /dev/null +++ b/docs/i18n/README.ko.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta는 개인 문서, Agent 메모리, 문서 지능을 위한 로컬 CLI 지식 인프라입니다. PDF, Office, 이미지, 오디오, 코드, HTML, Markdown, 노트를 안정적인 Markdown Wiki로 변환하고, 벡터 검색과 재사용 가능한 메모리 계층을 제공합니다. + +## Installation + +PyPI에서 설치: + +```bash +pip install little-heta +``` + +로컬 저장소에서 설치: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init`에는 LLM API key가 필요합니다. PDF 및 Office 파싱에는 선택적으로 MinerU를 사용할 수 있습니다: https://mineru.net/apiManage/docs. + +## Core Concepts + +- **Wiki foundation**: Wiki는 지식 기반 계층이며, 원본 파일을 안정적인 번호가 있는 Markdown 페이지로 컴파일합니다. +- **Vector Wiki**: Markdown 구조에 따라 페이지를 나누어 관련 섹션을 더 빠르게 찾습니다. +- **Memory reuse**: `heta ask`는 비용이 큰 검색 결과를 메모리로 저장하고 이후 질문에서 재사용합니다. +- **Agent skills**: `heta init`은 Codex와 Claude Code용 Little Heta skill을 자동 설치합니다. + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.pt.md b/docs/i18n/README.pt.md new file mode 100644 index 0000000..ccfbd7b --- /dev/null +++ b/docs/i18n/README.pt.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta é uma infraestrutura local de conhecimento via CLI para documentos pessoais, memória de agentes e inteligência documental. Ele converte PDFs, arquivos Office, imagens, áudio, código, HTML, Markdown e notas em uma wiki Markdown estável, com busca vetorial e memória reutilizável. + +## Installation + +Instale pelo PyPI: + +```bash +pip install little-heta +``` + +A partir de um repositório local: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` exige uma chave de API de um LLM. A análise de PDF e Office pode usar o MinerU opcionalmente: https://mineru.net/apiManage/docs. + +## Core Concepts + +- **Wiki foundation**: a wiki funciona como camada base de conhecimento; arquivos originais viram páginas Markdown estáveis com identificadores numéricos. +- **Vector Wiki**: páginas são divididas pela estrutura Markdown para recuperação por seção. +- **Memory reuse**: `heta ask` salva conhecimentos úteis da base para reutilização em perguntas futuras. +- **Agent skills**: `heta init` instala a skill do Little Heta para Codex e Claude Code. + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.zh-CN.md b/docs/i18n/README.zh-CN.md new file mode 100644 index 0000000..3b215fb --- /dev/null +++ b/docs/i18n/README.zh-CN.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta 是一个本地 CLI 知识基础设施,用于个人文档、Agent 记忆和文档智能。它把 PDF、Office、图片、音频、代码、HTML、Markdown 和笔记转成稳定的 Markdown Wiki,并提供向量检索和可复用的记忆层。 + +## Installation + +从 PyPI 安装: + +```bash +pip install little-heta +``` + +从本地仓库安装: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` 需要你准备一个 LLM API key。PDF 和 Office 解析可选接入 MinerU:https://mineru.net/apiManage/docs。 + +## Core Concepts + +- **Wiki foundation**:Wiki 是知识基础层,原始文件会被编译成带稳定编号的 Markdown 页面。 +- **Vector Wiki**:按照 Markdown 层级切分页面,让查询更容易命中具体章节。 +- **Memory reuse**:`heta ask` 可以把昂贵查询得到的知识沉淀为记忆,后续问题复用。 +- **Agent skills**:`heta init` 会自动安装 Codex 和 Claude Code 可使用的 Little Heta skill。 + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/i18n/README.zh-TW.md b/docs/i18n/README.zh-TW.md new file mode 100644 index 0000000..412c9f4 --- /dev/null +++ b/docs/i18n/README.zh-TW.md @@ -0,0 +1,54 @@ +# Little Heta + +[English](../../README.md) | [简体中文](README.zh-CN.md) | [繁體中文](README.zh-TW.md) | [日本語](README.ja.md) | [한국어](README.ko.md) | [Español](README.es.md) | [Português](README.pt.md) | [Français](README.fr.md) | [Deutsch](README.de.md) + +## What is Heta + +Little Heta 是一個本地 CLI 知識基礎設施,用於個人文件、Agent 記憶與文件智慧。它會將 PDF、Office、圖片、音訊、程式碼、HTML、Markdown 和筆記轉成穩定的 Markdown Wiki,並提供向量檢索與可重複使用的記憶層。 + +## Installation + +從 PyPI 安裝: + +```bash +pip install little-heta +``` + +從本地倉庫安裝: + +```bash +pip install -e . +``` + +## Quick Start + +```bash +heta init +heta status +heta insert ./docs +heta ask "What does my knowledge base say about this?" +``` + +`heta init` 需要你準備一個 LLM API key。PDF 和 Office 解析可選接入 MinerU:https://mineru.net/apiManage/docs。 + +## Core Concepts + +- **Wiki foundation**:Wiki 是知識基礎層,原始文件會被編譯成帶穩定編號的 Markdown 頁面。 +- **Vector Wiki**:依照 Markdown 層級切分頁面,讓查詢更容易命中特定章節。 +- **Memory reuse**:`heta ask` 可以把昂貴查詢得到的知識沉澱為記憶,供後續問題重複使用。 +- **Agent skills**:`heta init` 會自動安裝 Codex 和 Claude Code 可使用的 Little Heta skill。 + +## Minimal Examples + +```bash +heta query "What does the design doc say?" +heta remember "We decided to use Postgres." +heta recall "database decision" +heta skill +``` + +## Community Links + +- GitHub: https://github.com/KnowledgeXLab/Little_Heta +- Team: https://knowledgexlab.github.io/ +- License: MIT diff --git a/docs/technical-explanation/pdf_planning_agent_flow.md b/docs/technical-explanation/pdf_planning_agent_flow.md deleted file mode 100644 index 0008b42..0000000 --- a/docs/technical-explanation/pdf_planning_agent_flow.md +++ /dev/null @@ -1,244 +0,0 @@ -# PDF Planning Agent Flow - -This document describes the Little Heta large-PDF planning flow used by `heta insert`. - -The goal is to avoid sending an entire large PDF directly into parsing and wiki merge. Little Heta first builds a lightweight PDF profile, asks a planning agent how to split the document, validates the plan, and then performs deterministic PDF splitting in code. - -## High-Level Flow - -```mermaid -flowchart TD - A["heta insert input"] --> B["collect_insert_files
validate extension"] - B --> C{"file is PDF?"} - - C -->|No| D["save file to raw/"] - D --> E["parse_document"] - E --> Z["agent merge wiki"] - - C -->|Yes| F["estimate_pdf_pages
pypdf page count"] - F --> G{"pdf planning enabled
and page_count > 80?"} - - G -->|No| H["save original PDF to raw/"] - H --> E - - G -->|Yes| I["save original PDF to raw/originals/"] - I --> J["build_pdf_profile"] - J --> K["run_pdf_planning_agent"] - K --> L{"agent plan valid?"} - - L -->|No| M["fallback plan
fixed 40-page windows"] - L -->|Yes| N["validate and normalize plan"] - - N --> O["split oversized units
max 40 pages each"] - O --> P["fill missing page ranges"] - P --> Q["remove overlaps by sorting and cropping"] - - M --> R["split_pdf_to_raw_parts"] - Q --> R - - R --> S["write part PDFs to raw/"] - R --> T["write part .meta.json"] - S --> U["parse each PDF part"] - T --> U - U --> Z -``` - -## Agent Input - -The planning agent does not read the full PDF. It receives a lightweight `PdfProfile`. - -```json -{ - "filename": "large-report.pdf", - "page_count": 200, - "metadata": { - "Title": "Example Report", - "Author": "Example Author" - }, - "outline": [ - { - "title": "Chapter 1 Introduction", - "page": 1, - "depth": 0 - } - ], - "page_samples": [ - { - "page": 1, - "text": "sampled text from page 1..." - }, - { - "page": 26, - "text": "sampled text from page 26..." - } - ], - "heading_candidates": [ - { - "page": 1, - "text": "Chapter 1 Introduction" - } - ] -} -``` - -Plain-language meaning: - -- `filename`: The PDF file name. -- `page_count`: Total number of pages. -- `metadata`: PDF metadata such as title, author, and subject. -- `outline`: Built-in PDF bookmarks or outline entries, if present. -- `page_samples`: Extracted text from selected sample pages. -- `heading_candidates`: Heading-like lines detected from sampled text. - -Sampling policy: - -- First four pages. -- Every `page_count // 8` pages. -- Last page. -- Each sampled page is truncated to about 900 characters. -- The final profile sent to the agent is capped at about 12,000 characters. - -## Agent Output - -The planning agent must return JSON only. - -```json -{ - "document_type": "textbook", - "split_strategy": "chapter", - "units": [ - { - "title": "Chapter 1: Introduction", - "start_page": 1, - "end_page": 32 - }, - { - "title": "Chapter 2: Methods", - "start_page": 33, - "end_page": 78 - } - ] -} -``` - -Plain-language meaning: - -- `document_type`: The agent's best guess of the PDF type. - - `textbook` - - `paper_collection` - - `report` - - `slides` - - `manual` - - `scanned_book` - - `mixed` -- `split_strategy`: How the agent thinks the PDF should be split. - - `outline` - - `chapter` - - `section` - - `fixed_page_window` - - `fallback` -- `units`: The proposed page ranges. Page numbers are 1-based and inclusive. - -## Validation Logic - -The agent output is never trusted directly. Little Heta validates and normalizes the plan before splitting. - -```mermaid -flowchart TD - A["agent JSON output"] --> B{"is JSON object?"} - B -->|No| F["fallback fixed windows"] - B -->|Yes| C{"has non-empty units?"} - C -->|No| F - C -->|Yes| D{"all page ranges legal?"} - D -->|No| F - D -->|Yes| E["normalize units"] - - E --> G{"unit > 40 pages?"} - G -->|Yes| H["split oversized unit
into <=40-page parts"] - G -->|No| I["keep unit"] - H --> J["sort units by start_page"] - I --> J - - J --> K["crop overlapping ranges"] - K --> L["fill missing ranges
with fixed windows"] - L --> M["validated split units"] - F --> M -``` - -Validation rules: - -- The output must be parseable JSON. -- The JSON must be an object. -- `units` must exist and must not be empty. -- Each unit must satisfy: - - `start_page >= 1` - - `end_page <= page_count` - - `start_page <= end_page` -- Oversized units are split into smaller parts. -- Missing page ranges are filled automatically. -- Overlapping ranges are cropped after sorting. -- If the plan is invalid, Little Heta falls back to fixed 40-page windows. - -## Split Output - -For a large PDF, Little Heta stores: - -```text -raw/ - originals/ - 2026-05-13_143000_big-book.pdf - - 2026-05-13_143000_big-book_part-001_intro_pages-1-40.pdf - 2026-05-13_143000_big-book_part-001_intro_pages-1-40.meta.json - 2026-05-13_143000_big-book_part-002_methods_pages-41-80.pdf - 2026-05-13_143000_big-book_part-002_methods_pages-41-80.meta.json -``` - -Each `.meta.json` records traceability: - -```json -{ - "original": "raw/originals/2026-05-13_143000_big-book.pdf", - "part": "raw/2026-05-13_143000_big-book_part-001_intro_pages-1-40.pdf", - "title": "Introduction", - "start_page": 1, - "end_page": 40, - "document_type": "report", - "split_strategy": "section" -} -``` - -## Fallback Behavior - -Fallback is intentionally simple and deterministic: - -```text -Pages 1-40 -Pages 41-80 -Pages 81-120 -... -``` - -Fallback is used when: - -- The planning agent fails. -- The agent returns non-JSON text. -- The JSON shape is invalid. -- Page ranges are illegal. -- `units` is empty. -- No LLM config is provided to the planning function. - -## Design Boundary - -The planning agent only decides how the PDF should be split. It does not parse the full PDF, edit wiki pages, create markdown pages, or write files. - -The deterministic code owns: - -- PDF page counting. -- Profile generation. -- Plan validation. -- Missing page recovery. -- Oversized unit splitting. -- Actual PDF splitting. -- Raw file and metadata writing. - diff --git a/pyproject.toml b/pyproject.toml index 63cb7f0..d99e210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,27 +1,28 @@ [build-system] -requires = ["setuptools>=68", "wheel"] +requires = ["setuptools==80.9.0", "wheel==0.45.1"] build-backend = "setuptools.build_meta" [project] name = "little-heta" -version = "0.1.0" +version = "0.2.1" description = "Little Heta first-time initialization CLI" readme = "README.md" requires-python = ">=3.10" license = "MIT" dependencies = [ - "typer>=0.12.0", - "rich>=13.7.0", - "PyYAML>=6.0.0", - "requests>=2.31.0", - "openai>=1.0.0", - "sqlite-vec>=0.1.6", - "pypdf>=4.0.0", + "typer==0.20.0", + "rich==14.2.0", + "PyYAML==6.0.3", + "requests==2.32.5", + "openai==2.30.0", + "sqlite-vec==0.1.9", + "pypdf==6.11.0", + "beautifulsoup4==4.14.3", ] [project.optional-dependencies] dev = [ - "pytest>=7.4.0", + "pytest==9.0.2", ] [project.scripts] @@ -30,6 +31,9 @@ heta = "heta.cli:app" [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +"heta.assistants" = ["templates/claude_skill/*.md"] + [tool.pytest.ini_options] pythonpath = ["src"] testpaths = ["tests"] diff --git a/src/heta/__init__.py b/src/heta/__init__.py index 7184236..0bc12a9 100644 --- a/src/heta/__init__.py +++ b/src/heta/__init__.py @@ -1,4 +1,3 @@ """Little Heta.""" -__version__ = "0.1.0" - +__version__ = "0.2.1" diff --git a/src/heta/assistants/__init__.py b/src/heta/assistants/__init__.py new file mode 100644 index 0000000..f7cf97e --- /dev/null +++ b/src/heta/assistants/__init__.py @@ -0,0 +1,105 @@ +"""Install Little Heta usage guides into AI coding assistants.""" + +from __future__ import annotations + +import importlib.resources +import os +from dataclasses import dataclass +from pathlib import Path + +from heta.config.io import CONFIG_DIR + +_CODEX_DIR = Path(os.environ.get("CODEX_HOME", Path.home() / ".codex")) +_CLAUDE_DIR = Path.home() / ".claude" + +HETA_SKILL_DIR = CONFIG_DIR / "skills" / "heta" +CODEX_SKILL_DIR = _CODEX_DIR / "skills" / "heta" +CLAUDE_SKILL_DIR = _CLAUDE_DIR / "skills" / "heta" + +_SKILL_FILES = ("SKILL.md", "COMMANDS.md") + + +@dataclass(frozen=True) +class InstalledSkill: + assistant: str + path: Path + + +def claude_code_detected() -> bool: + """True when a Claude Code config directory exists for the current user.""" + return _CLAUDE_DIR.is_dir() + + +def codex_detected() -> bool: + """True when a Codex config directory exists for the current user.""" + return _CODEX_DIR.is_dir() + + +def install_codex_skill(target_dir: Path | None = None) -> Path: + """Copy the bundled `heta` skill into the Codex global skills directory.""" + return _copy_skill(target_dir or CODEX_SKILL_DIR) + + +def install_claude_skill(target_dir: Path | None = None) -> Path: + """Copy the bundled `heta` skill into the Claude Code skills directory. + + Overwrites any existing copy so re-running `heta init` refreshes the skill. + Returns the directory the skill was written to. + """ + return _copy_skill(target_dir or CLAUDE_SKILL_DIR) + + +def install_portable_skill(target_dir: Path | None = None) -> Path: + """Copy the bundled `heta` skill into Little Heta's own config directory.""" + return _copy_skill(target_dir or HETA_SKILL_DIR) + + +def install_assistant_skills() -> list[InstalledSkill]: + """Install Little Heta skills into supported global assistant folders.""" + return [ + InstalledSkill("Little Heta", install_portable_skill()), + InstalledSkill("Codex", install_codex_skill()), + InstalledSkill("Claude Code", install_claude_skill()), + ] + + +def skill_template_files() -> tuple[str, ...]: + """Files that make up the portable Little Heta skill.""" + return _SKILL_FILES + + +def skill_template_dir() -> Path: + """Stable user-facing directory containing Little Heta's portable skill.""" + return HETA_SKILL_DIR + + +def skill_template_hint() -> str: + """Human-readable location hint for manual agent-framework installation.""" + return "copy SKILL.md and COMMANDS.md from the Little Heta skill folder" + + +def _copy_skill(dest: Path) -> Path: + dest.mkdir(parents=True, exist_ok=True) + + source = importlib.resources.files("heta.assistants") / "templates" / "claude_skill" + for name in _SKILL_FILES: + text = (source / name).read_text(encoding="utf-8") + (dest / name).write_text(text, encoding="utf-8") + return dest + + +__all__ = [ + "CLAUDE_SKILL_DIR", + "CODEX_SKILL_DIR", + "HETA_SKILL_DIR", + "InstalledSkill", + "claude_code_detected", + "codex_detected", + "install_assistant_skills", + "install_claude_skill", + "install_codex_skill", + "install_portable_skill", + "skill_template_dir", + "skill_template_files", + "skill_template_hint", +] diff --git a/src/heta/assistants/templates/claude_skill/COMMANDS.md b/src/heta/assistants/templates/claude_skill/COMMANDS.md new file mode 100644 index 0000000..c9d6aca --- /dev/null +++ b/src/heta/assistants/templates/claude_skill/COMMANDS.md @@ -0,0 +1,62 @@ +# Little Heta — full command reference + +Every `heta` command. The four core commands (`ask`, `query`, `recall`, +`remember`) are covered in `SKILL.md`; this file documents the rest. Run any +command with the Bash tool. + +## Setup + +### `heta init` +Interactive first-time setup — LLM provider + API key, MinerU document +parsing, and so on. Interactive: the user must run this themselves. Never run +it for them. + +### `heta status` +Show what is configured and how much is indexed. No arguments. + +## Indexing documents + +### `heta insert [PATHS...]` +Add files or folders to the knowledge base. Defaults to the current directory. +Supports PDF, Office, images, audio, code, HTML, and Markdown. + +```bash +heta insert ./docs +heta insert report.pdf notes.md +``` + +### `heta clean [-y]` +Remove generated wiki pages and the vector index. Original raw files are kept. +`-y` / `--yes` skips the confirmation prompt. + +## Core command options + +`ask`, `query`, `recall`, and `remember` are documented in `SKILL.md`. Their +extra options: + +- `heta ask "" [-k N] [-d]` — `-k` / `--top-k` results per layer + (default 5); `-d` / `--debug` shows agent steps and evidence. +- `heta query "" [--top-k N]` — `--top-k` initial vector matches, + 1–10 (default 5). +- `heta recall "" [-k N] [-d]` — `-k` / `--top-k` results per layer + (default 10); `-d` / `--debug` shows layer ranking, reason, and scored evidence. +- `heta remember ""` — no extra options. + +## Inspecting & clearing memory + +### `heta mem-show insights [-s SOURCE] [-q QUESTION] [-n LIMIT] [-f]` +List stored KB-insight memories, newest first. `-s` / `--source` filters by +source path, `-q` / `--question` filters by question, `-n` / `--limit` caps +rows (default 50), `-f` / `--full` shows full untruncated text. + +### `heta mem-clean [-y]` +Erase all saved memory. `-y` / `--yes` skips confirmation. Irreversible. + +## Settings + +### `heta vector on | off | status` +Turn document search vector indexing on or off, or show its current state. + +### `heta insert-planning on | off | status` +Turn smart insert planning (such as large-PDF splitting) on or off, or show +its current state. diff --git a/src/heta/assistants/templates/claude_skill/SKILL.md b/src/heta/assistants/templates/claude_skill/SKILL.md new file mode 100644 index 0000000..71f611d --- /dev/null +++ b/src/heta/assistants/templates/claude_skill/SKILL.md @@ -0,0 +1,53 @@ +--- +name: heta +description: Search and recall the user's own documents, files, and saved memory — and save new things worth remembering — using Little Heta, a local CLI knowledge base. Use whenever a task needs external knowledge from the user's own materials or earlier context, or when the user states a fact, decision, or preference worth keeping, instead of guessing or grepping files. +--- + +# Little Heta + +`heta` is a local command-line knowledge base. It indexes the user's +documents — PDF, Office, images, audio, code, Markdown — answers questions +from them and from saved memory, and can store new memories. + +Reach for `heta` whenever a task needs **external knowledge** from the user's +own documents, or needs to **recall or save memory** — instead of guessing or +grepping. Run commands with the Bash tool. + +## 1. Check Heta is set up + +Run `heta status` once at the start. + +- Shows a model provider and KB files → Heta is ready, continue. +- Shows config missing / "not configured" → tell the user to run `heta init` + themselves. It is an interactive API-key setup, so do **not** run it yourself. + +## 2. Four core commands + +These four cover retrieval and memory. **Default to `heta ask`.** + +| Command | When to use it | +|---------|----------------| +| `heta ask ""` | **Default.** Answers from saved memory and indexed documents together. | +| `heta query ""` | When the answer must come strictly from indexed documents. | +| `heta recall ""` | When you want the user's personal memory (past chats, facts), not documents. | +| `heta remember ""` | When the user states a fact, decision, or preference worth keeping for later. | + +Examples: + +```bash +heta ask "How does our auth flow refresh tokens?" +heta query "What does the design doc say about rate limits?" +heta recall "what did I decide about the database" +heta remember "We decided to use Postgres for the main store." +``` + +Show the user Heta's output, then add a short summary. + +## 3. Other commands + +Heta can also index files (`heta insert`), clean up, and toggle settings. + +- For a quick list: run `heta --help`. +- For full usage of any command: read `COMMANDS.md` in this skill's directory + (next to this file). Only read it when the user actually needs one of those + commands — do not load it ahead of time. diff --git a/src/heta/cli/__init__.py b/src/heta/cli/__init__.py index 6425447..032710f 100644 --- a/src/heta/cli/__init__.py +++ b/src/heta/cli/__init__.py @@ -4,14 +4,25 @@ import typer +from heta.cli.ask import ask_command +from heta.cli.branding import apply_typer_theme from heta.cli.clean import clean_command +from heta.cli.dynamic_insert import app as dynamic_insert_app +from heta.cli.mem_clean import mem_clean_command +from heta.cli.mem_show import app as mem_show_app from heta.cli import init as init_module from heta.cli.init import interactive_init from heta.cli.insert import insert_command +from heta.cli.insert_planning import app as insert_planning_app from heta.cli.query import query_command +from heta.cli.recall import recall_command +from heta.cli.remember import remember_command +from heta.cli.skill import skill_command from heta.cli.status import status_command from heta.cli.vector import app as vector_app +apply_typer_theme() + app = typer.Typer( name="heta", help="Little Heta command line interface.", @@ -27,7 +38,7 @@ def main() -> None: @app.command("init") def init_command() -> None: - """Run the first-time Little Heta initialization wizard.""" + """Set up Little Heta for the first time.""" try: interactive_init() except (KeyboardInterrupt, EOFError): @@ -35,8 +46,16 @@ def init_command() -> None: raise typer.Exit(130) from None +app.command("ask")(ask_command) +app.command("mem-clean")(mem_clean_command) app.command("insert")(insert_command) app.command("query")(query_command) app.command("clean")(clean_command) +app.command("remember")(remember_command) +app.command("recall")(recall_command) +app.command("skill")(skill_command) app.command("status")(status_command) +app.add_typer(dynamic_insert_app) +app.add_typer(insert_planning_app) app.add_typer(vector_app) +app.add_typer(mem_show_app) diff --git a/src/heta/cli/ask.py b/src/heta/cli/ask.py new file mode 100644 index 0000000..fea43b9 --- /dev/null +++ b/src/heta/cli/ask.py @@ -0,0 +1,115 @@ +"""CLI command: heta ask.""" + +from __future__ import annotations + +import typer +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.text import Text + +from heta.cli.branding import CYAN, ERR, HETA, MUTED, OK, WARN +from heta.config.io import load_config +from heta.query.smart_query import smart_query + +console = Console() + +_SOURCE_STYLES = { + "memory": ("● memory", f"bold {OK}"), + "kb": ("● KB", f"bold {HETA}"), + "both": ("● memory + KB", f"bold {CYAN}"), +} + + +def ask_command( + question: str = typer.Argument(..., help="Question to ask."), + top_k: int = typer.Option(5, "--top-k", "-k", help="Results per layer / KB vector match."), + debug: bool = typer.Option(False, "--debug", "-d", help="Print agent steps, memory evidence, and KB output."), +) -> None: + """Ask anything — answered from your memory and documents.""" + config = load_config() + if config is None: + console.print(f"[{ERR}]Heta is not initialised. Run `heta init` first.[/]") + raise typer.Exit(1) + + with console.status(f"[bold {HETA}]Thinking...[/]"): + result = smart_query(question, config, top_k=top_k) + + if debug: + console.print(f"\n[bold {WARN}]── DEBUG ──[/]") + console.print(f"agent steps: {' → '.join(result.agent_steps) or '(none)'}") + + if result.memory_evidence: + console.print("\n[bold]memory evidence:[/bold]") + for layer_ev in result.memory_evidence: + console.print(f" [bold]{layer_ev.layer}[/bold] ({len(layer_ev.items)} hits)") + for i, item in enumerate(layer_ev.items, 1): + score = item.get("score", 0) + console.print(f" [dim][{i}; score={score:.4f}][/dim]") + if layer_ev.layer == "raw": + console.print(f" {item.get('text_content', '')}") + elif layer_ev.layer == "episode": + console.print(f" {item.get('summary', '')}") + elif layer_ev.layer == "kb_insight": + console.print(f" [dim]source:[/dim] {item.get('source_path', '')}") + console.print(f" {item.get('insight', '')}") + else: + console.print(f" {item.get('fact_text', '')}") + + if result.kb_result: + console.print("\n[bold]kb result:[/bold]") + paths = [s.path for s in result.kb_result.sources] + console.print(f" used sources: {paths or '(empty)'}") + if result.kb_result.insights: + console.print(f" agent insights ({len(result.kb_result.insights)}):") + for i, qi in enumerate(result.kb_result.insights, 1): + console.print(f" [dim][{i}] sources:[/dim] {qi.source_paths}") + console.print(f" {qi.text}") + console.print(f" written_back: {result.written_back}") + console.print(f"[bold {WARN}]──────────[/]\n") + + label, style = _SOURCE_STYLES[result.source] + source_line = Text() + source_line.append("Source: ", style=f"bold {HETA}") + source_line.append(label, style=style) + if result.written_back: + source_line.append(f" ({result.written_back} memories written back)", style=MUTED) + + console.print( + Panel( + _AnswerRenderable(Markdown(result.answer), source_line, _kb_sources_text(result)), + title="ask", + border_style=HETA, + padding=(1, 2), + ) + ) + + +def _kb_sources_text(result) -> Text: + text = Text() + if not (result.kb_result and result.kb_result.sources): + return text + for src in result.kb_result.sources: + title = src.title or src.path + heading = f" — {src.heading_path}" if src.heading_path else "" + text.append(f"[{src.wiki_id}] ", style=MUTED) + text.append(f"{title}{heading}\n") + text.rstrip() + return text + + +class _AnswerRenderable: + def __init__(self, answer: Markdown, source: Text, kb_sources: Text) -> None: + self.answer = answer + self.source = source + self.kb_sources = kb_sources + + def __rich_console__(self, console: Console, options): + yield Text("Answer:", style=f"bold {HETA}") + yield self.answer + yield Text("") + yield self.source + if self.kb_sources.plain: + yield Text("") + yield Text("KB Sources:", style=f"bold {HETA}") + yield self.kb_sources diff --git a/src/heta/cli/branding.py b/src/heta/cli/branding.py index 0bcd959..c1cf60a 100644 --- a/src/heta/cli/branding.py +++ b/src/heta/cli/branding.py @@ -1,13 +1,21 @@ -"""Shared Little Heta CLI branding.""" +"""Shared Little Heta CLI branding and color palette. + +Single source of truth for the Heta blue color family. Every CLI command +imports its colors from here so the product keeps one consistent look. +""" from __future__ import annotations from heta import __version__ -HETA = "rgb(52,144,220)" -CYAN = "rgb(88,196,220)" -OK = "rgb(76,196,142)" -MUTED = "rgb(126,146,158)" +# --- Heta blue palette --- +HETA = "rgb(52,144,220)" # primary blue — command names, arrows, panel borders +HETA_DARK = "rgb(31,91,156)" # deep blue — secondary emphasis +CYAN = "rgb(88,196,220)" # cyan accent — blended / secondary highlights +OK = "rgb(76,196,142)" # green — success +WARN = "rgb(238,183,74)" # amber — warnings, prompts +ERR = "rgb(224,108,108)" # coral red — errors, destructive markers +MUTED = "rgb(126,146,158)" # slate gray — secondary text APP_TITLE = "Little Heta" APP_TAGLINE = "Personal knowledge, memory, and document intelligence CLI" @@ -22,4 +30,47 @@ def brand_line() -> str: ) -__all__ = ["APP_TAGLINE", "APP_TEAM", "APP_TITLE", "brand_line"] +def apply_typer_theme() -> None: + """Re-skin Typer's ``--help`` screen to the Heta blue palette. + + Typer reads these module-level style constants at render time, so + overriding them once at import keeps every command's help consistent. + """ + from typer import rich_utils as ru + + ru.STYLE_OPTION = f"bold {HETA}" + ru.STYLE_SWITCH = f"bold {OK}" + ru.STYLE_NEGATIVE_OPTION = f"bold {CYAN}" + ru.STYLE_NEGATIVE_SWITCH = f"bold {ERR}" + ru.STYLE_METAVAR = f"bold {CYAN}" + ru.STYLE_METAVAR_SEPARATOR = MUTED + ru.STYLE_USAGE = HETA + ru.STYLE_HELPTEXT = MUTED + ru.STYLE_OPTION_DEFAULT = MUTED + ru.STYLE_OPTION_ENVVAR = MUTED + ru.STYLE_REQUIRED_SHORT = ERR + ru.STYLE_REQUIRED_LONG = ERR + ru.STYLE_OPTIONS_PANEL_BORDER = HETA + ru.STYLE_COMMANDS_PANEL_BORDER = HETA + ru.STYLE_ERRORS_PANEL_BORDER = ERR + ru.STYLE_COMMANDS_TABLE_FIRST_COLUMN = f"bold {HETA}" + ru.STYLE_ABORTED = ERR + ru.STYLE_DEPRECATED = ERR + ru.STYLE_DEPRECATED_COMMAND = MUTED + ru.STYLE_ERRORS_SUGGESTION = MUTED + + +__all__ = [ + "APP_TAGLINE", + "APP_TEAM", + "APP_TITLE", + "CYAN", + "ERR", + "HETA", + "HETA_DARK", + "MUTED", + "OK", + "WARN", + "apply_typer_theme", + "brand_line", +] diff --git a/src/heta/cli/clean.py b/src/heta/cli/clean.py index 3b89f18..24bda14 100644 --- a/src/heta/cli/clean.py +++ b/src/heta/cli/clean.py @@ -10,22 +10,18 @@ from rich.prompt import Confirm from rich.table import Table +from heta.cli.branding import HETA, MUTED, OK, WARN from heta.config.io import CONFIG_PATH, load_config from heta.kb import paths from heta.kb.clean import CleanSummary, clean_knowledge_base console = Console() -HETA = "rgb(52,144,220)" -MUTED = "rgb(126,146,158)" -OK = "rgb(76,196,142)" -WARN = "rgb(238,183,74)" - def clean_command( yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation."), ) -> None: - """Clean wiki knowledge pages and vector database while keeping raw files.""" + """Remove generated wiki pages (your original files are kept).""" config = load_config() if config is None: console.print(f"[{WARN}]?[/] Little Heta is not initialized.") @@ -86,6 +82,7 @@ def _show_result(summary: CleanSummary) -> None: console.print(f"[{OK}]✓[/] Clean completed.") console.print(f"[{MUTED}]pages deleted:[/] {summary.deleted_pages}") console.print(f"[{MUTED}]vector files deleted:[/] {summary.deleted_vector_files}") + console.print(f"[{MUTED}]invalidated memories:[/] {summary.invalidated_memories}") if summary.commit_id: console.print(f"[{MUTED}]wiki commit:[/] [bold {HETA}]{summary.commit_id}[/]") else: diff --git a/src/heta/cli/dynamic_insert.py b/src/heta/cli/dynamic_insert.py new file mode 100644 index 0000000..9b750e9 --- /dev/null +++ b/src/heta/cli/dynamic_insert.py @@ -0,0 +1,77 @@ +"""`heta dynamic-insert` commands.""" + +from __future__ import annotations + +from dataclasses import replace + +import typer +from rich.console import Console + +from heta.cli.branding import HETA, MUTED, OK, WARN +from heta.config.io import CONFIG_PATH, load_config, save_config +from heta.config.schema import DynamicInsertConfig + +console = Console() + +app = typer.Typer( + name="dynamic-insert", + help="Turn dynamic LLM wiki merging on or off.", + no_args_is_help=True, + rich_markup_mode="rich", +) + + +@app.command("on") +def dynamic_insert_on() -> None: + """Enable dynamic LLM wiki merging during insert.""" + _set_dynamic_insert(True) + + +@app.command("true") +def dynamic_insert_true() -> None: + """Enable dynamic LLM wiki merging during insert.""" + _set_dynamic_insert(True) + + +@app.command("off") +def dynamic_insert_off() -> None: + """Disable dynamic LLM wiki merging during insert.""" + _set_dynamic_insert(False) + + +@app.command("false") +def dynamic_insert_false() -> None: + """Disable dynamic LLM wiki merging during insert.""" + _set_dynamic_insert(False) + + +@app.command("status") +def dynamic_insert_status() -> None: + """Show whether dynamic LLM wiki merging is enabled.""" + config = _require_config() + state = "enabled" if config.dynamic_insert.enable else "disabled" + console.print(f"[{MUTED}]dynamic insert:[/] [bold {HETA}]{state}[/]") + + +def _set_dynamic_insert(enable: bool) -> None: + config = _require_config() + updated = replace(config, dynamic_insert=DynamicInsertConfig(enable=enable)) + save_config(updated) + state = "enabled" if enable else "disabled" + console.print(f"[{OK}]✓[/] dynamic insert {state}") + + +def _require_config(): + try: + config = load_config() + except Exception as exc: + console.print(f"[{WARN}]?[/] Failed to read config: {exc}") + raise typer.Exit(1) from exc + if config is None: + console.print(f"[{WARN}]?[/] Little Heta is not initialized.") + console.print(f"[{MUTED}] Missing config:[/] {CONFIG_PATH}") + raise typer.Exit(1) + return config + + +__all__ = ["app"] diff --git a/src/heta/cli/init.py b/src/heta/cli/init.py index f3db799..0b388db 100644 --- a/src/heta/cli/init.py +++ b/src/heta/cli/init.py @@ -14,21 +14,24 @@ from rich.prompt import Confirm, IntPrompt, Prompt from rich.table import Table -from heta.cli.branding import APP_TAGLINE, brand_line +from heta.assistants import install_assistant_skills, skill_template_hint +from heta.cli.branding import APP_TAGLINE, HETA, MUTED, OK, WARN, brand_line from heta.config.io import CONFIG_PATH, save_config -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import ( + DEFAULT_LLM_PROFILES, + DynamicInsertConfig, + HetaConfig, + InsertPlanningConfig, + LLMConfig, + MinerUConfig, + VectorIndexConfig, +) from heta.providers.llm import validate_llm from heta.providers.mineru import validate_mineru_cloud, validate_mineru_local console = Console() -HETA = "rgb(52,144,220)" -HETA_DARK = "rgb(31,91,156)" -MUTED = "rgb(126,146,158)" -OK = "rgb(76,196,142)" -WARN = "rgb(238,183,74)" - -LLM_PROVIDERS = {1: "qwen", 2: "chatgpt", 3: "gemini"} +LLM_PROVIDERS = {1: "qwen", 2: "chatgpt", 3: "gemini", 4: "custom"} MINERU_OPTIONS = {1: "cloud", 2: "local", 3: "skip"} MAX_RETRIES = 3 @@ -61,6 +64,8 @@ def _run_interactive_init() -> None: llm=llm_config, mineru=MinerUConfig.disabled(), vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + dynamic_insert=DynamicInsertConfig.disabled(), ) save_config(partial_config) console.print(f"[{HETA}]→[/] wrote {CONFIG_PATH}") @@ -71,10 +76,13 @@ def _run_interactive_init() -> None: llm=llm_config, mineru=mineru_config, vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + dynamic_insert=DynamicInsertConfig.disabled(), ) save_config(final_config) console.print(f"[{OK}]✓[/] little heta is ready") + _install_assistant_skills() _show_summary(final_config) @@ -97,12 +105,7 @@ def _show_welcome() -> None: ) ) console.print() - console.print(f" [{WARN}]Tip:[/] Run this once to connect Little Heta to your providers.") - console.print() - console.print(f" [{MUTED}]Learn more:[/] https://github.com/KnowledgeXLab/Heta") - console.print() - console.print(f"[{MUTED}]$[/] [bold {HETA}]heta init[/]") - console.print(f"[bold {HETA}]little heta setup[/]") + console.print(f" [{MUTED}]Learn more:[/] https://github.com/KnowledgeXLab/Little_Heta") def _short_path(path: Path) -> str: @@ -128,10 +131,30 @@ def _configure_llm() -> LLMConfig: console.print(f" [{HETA}]1[/] qwen") console.print(f" [{HETA}]2[/] chatgpt") console.print(f" [{HETA}]3[/] gemini") + console.print(f" [{HETA}]4[/] custom") choice = _ask_choice(" Provider", LLM_PROVIDERS) provider = LLM_PROVIDERS[choice] + if provider == "custom": + _show_custom_provider_note() + chat = _configure_custom_capability("Chat", required=True) + multimodal = _configure_custom_capability("Multimodal", required=False) + embedding = _configure_custom_capability("Embedding", required=True) + return LLMConfig( + provider="custom", + api_key=chat["api_key"] or embedding["api_key"] or "", + chat_api_key=chat["api_key"], + chat_model=chat["model"], + chat_base_url=chat["base_url"], + multimodal_api_key=multimodal["api_key"], + multimodal_model=multimodal["model"], + multimodal_base_url=multimodal["base_url"], + embedding_api_key=embedding["api_key"], + embedding_model=embedding["model"], + embedding_base_url=embedding["base_url"], + ) + api_key = _retry_secret( prompt=" Paste API key", validate=lambda key: validate_llm(provider, key), @@ -142,12 +165,13 @@ def _configure_llm() -> LLMConfig: ), exhausted_message="LLM configuration failed. Initialization aborted.", ) - return LLMConfig(provider=provider, api_key=api_key) + defaults = DEFAULT_LLM_PROFILES[provider] + return LLMConfig(provider=provider, api_key=api_key, **defaults) def _configure_mineru() -> MinerUConfig: console.print() - console.print(f"[{WARN}]?[/] Enable PDF parsing with MinerU?") + console.print(f"[{WARN}]?[/] Enable PDF and Office parsing with MinerU?") console.print(f" [{HETA}]1[/] Cloud") console.print(f" [{HETA}]2[/] Local sidecar") console.print(f" [{HETA}]3[/] Skip for now") @@ -195,6 +219,55 @@ def _ask_choice(label: str, choices: dict[int, str]) -> int: console.print(f"[{WARN}]?[/] Choose one of: {', '.join(map(str, choices))}") +def _ask_required_text(prompt: str) -> str: + while True: + value = Prompt.ask(prompt).strip() + if value: + return value + console.print(f"[{WARN}]?[/] Value cannot be empty.") + + +def _ask_optional_text(prompt: str) -> str | None: + value = Prompt.ask(prompt, default="").strip() + return value or None + + +def _show_custom_provider_note() -> None: + console.print() + console.print(f"[{WARN}]?[/] Custom provider expects OpenAI-compatible APIs.") + console.print(f" [{MUTED}]Chat:[/] /chat/completions style text generation") + console.print(f" [{MUTED}]Embedding:[/] /embeddings style vectors with 1024 dimensions") + console.print(f" [{MUTED}]Multimodal:[/] optional OpenAI-style image content blocks") + + +def _configure_custom_capability(label: str, *, required: bool) -> dict[str, str | None]: + console.print() + console.print(f"[{WARN}]?[/] Configure custom {label.lower()} API") + if not required and not Confirm.ask(f" Enable {label.lower()} API?", default=False): + return {"api_key": None, "model": None, "base_url": None} + + while True: + api_key = _ask_required_secret(f" {label} API key") + model = _ask_required_text(f" {label} model") + base_url = _ask_required_text(f" {label} base URL") + with console.status(f"Checking custom {label.lower()} API", spinner="dots"): + ok = validate_llm("custom", api_key, base_url) + if ok: + console.print(f"[{OK}]✓[/] {label} API reachable") + return {"api_key": api_key, "model": model, "base_url": base_url.rstrip("/")} + console.print(f"[{WARN}]?[/] Could not connect to custom {label.lower()} API.") + if not Confirm.ask(" Retry?", default=True): + raise typer.Exit(1) + + +def _ask_required_secret(prompt: str) -> str: + while True: + value = Prompt.ask(prompt, password=True).strip() + if value: + return value + console.print(f"[{WARN}]?[/] Value cannot be empty.") + + def _retry_secret( *, prompt: str, @@ -276,13 +349,29 @@ def _retry_value( raise _RetryExhausted +def _install_assistant_skills() -> None: + """Install the Little Heta skill into supported AI coding assistants.""" + try: + installed = install_assistant_skills() + except Exception as exc: + console.print(f"[{WARN}]?[/] Could not install assistant skills: {exc}") + return + + console.print() + console.print(f"[{OK}]✓[/] assistant skills installed") + for item in installed: + console.print(f" [{MUTED}]{item.assistant}:[/] {_short_path(item.path)}") + console.print(f" [{MUTED}]Other agents:[/] {skill_template_hint()}.") + + def _show_summary(config: HetaConfig) -> None: table = Table.grid(padding=(0, 2)) table.add_column(style=f"bold {HETA}") table.add_column() table.add_row("config", str(CONFIG_PATH)) table.add_row("provider", config.llm.provider) - table.add_row("pdf", _mineru_summary(config.mineru)) + table.add_row("mineru docs", _mineru_summary(config.mineru)) + table.add_row("dynamic insert", "disabled") table.add_row("next", f"[bold {HETA}]heta insert ./notes[/] or [bold {HETA}]heta remember \"...\"[/]") console.print( diff --git a/src/heta/cli/insert.py b/src/heta/cli/insert.py index b3d3474..8e20dc1 100644 --- a/src/heta/cli/insert.py +++ b/src/heta/cli/insert.py @@ -7,34 +7,27 @@ import typer from rich.console import Console from rich.panel import Panel +from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn from rich.table import Table +from heta.cli.branding import HETA, MUTED, OK, WARN from heta.config.io import CONFIG_PATH, load_config from heta.kb import paths from heta.kb.discovery import collect_insert_files, supported_extensions from heta.kb.insert import insert_paths +from heta.kb.models import InsertProgress from heta.kb.pdf_plan import PDF_PAGE_THRESHOLD, estimate_pdf_pages console = Console() -HETA = "rgb(52,144,220)" -MUTED = "rgb(126,146,158)" -OK = "rgb(76,196,142)" -WARN = "rgb(238,183,74)" - def insert_command( targets: list[Path] = typer.Argument( None, help="File or directory paths to insert. Defaults to the current directory.", ), - pdf_planning: bool = typer.Option( - True, - "--pdf-planning/--no-pdf-planning", - help="Split large PDFs before parsing to avoid oversized agent context.", - ), ) -> None: - """Insert files into the Little Heta Markdown knowledge base.""" + """Add files to your knowledge base.""" config = load_config() if config is None: console.print(f"[{WARN}]?[/] Little Heta is not initialized.") @@ -54,11 +47,22 @@ def insert_command( console.print(f"[{MUTED}] Supported:[/] {extensions}") raise typer.Exit(1) + pdf_planning = config.insert_planning.enable _show_plan(files, config, pdf_planning=pdf_planning) try: - with console.status(f"[bold {HETA}]heta insert[/] [{MUTED}]parsing files and merging wiki[/]", spinner="dots"): - result = insert_paths(targets or [], config, enable_pdf_planning=pdf_planning) + with _insert_progress() as progress: + task_id = progress.add_task("preparing files", total=100, completed=1) + + def on_progress(event: InsertProgress) -> None: + progress.update(task_id, completed=event.percent, description=_progress_description(event)) + + result = insert_paths( + targets or [], + config, + enable_pdf_planning=pdf_planning, + on_progress=on_progress, + ) except KeyboardInterrupt: console.print(f"\n[{WARN}]Insert cancelled. Rolled back partial changes.[/]") raise typer.Exit(130) from None @@ -75,6 +79,7 @@ def _show_plan(files: list[Path], config, *, pdf_planning: bool) -> None: table.add_column(style=f"bold {HETA}") table.add_column() table.add_row("files", str(len(files))) + table.add_row("mode", "dynamic" if config.dynamic_insert.enable else "static") table.add_row("mineru", "enabled" if config.mineru.enable else "disabled") table.add_row("pdf planning", "enabled" if pdf_planning else "disabled") table.add_row("workspace", str(paths.workspace_root())) @@ -101,17 +106,17 @@ def _show_result(result) -> None: console.print(f"[{OK}]✓[/] Insert completed.") if result.added: - console.print("\n新增页面:") + console.print("\nAdded pages:") for change in result.added: console.print(f"[{OK}]+[/] {change.title} [{MUTED}]({_absolute_page_path(change.path)})[/]") if result.updated: - console.print("\n更新页面:") + console.print("\nUpdated pages:") for change in result.updated: console.print(f"[{WARN}]~[/] {change.title} [{MUTED}]({_absolute_page_path(change.path)})[/]") if result.deleted: - console.print("\n删除页面:") + console.print("\nDeleted pages:") for change in result.deleted: console.print(f"[red]-[/] {change.title} [{MUTED}]({_absolute_page_path(change.path)})[/]") @@ -123,6 +128,41 @@ def _show_result(result) -> None: if result.planned_pdf_parts: console.print(f"[{MUTED}]pdf parts:[/] {result.planned_pdf_parts}") + if result.invalidated_memories: + console.print(f"[{MUTED}]invalidated memories:[/] {result.invalidated_memories}") + + if result.vector_index_error: + console.print(f"[{WARN}]![/] Vector index sync failed; wiki pages were still committed.") + console.print(f"[{MUTED}] Reason:[/] {result.vector_index_error}") + console.print(f"[{MUTED}] Next:[/] rerun [bold {HETA}]heta vector sync[/] after fixing the issue.") + + if result.skipped_documents: + console.print(f"\n[{WARN}]Documents not organized into wiki pages:[/]") + for source_name in result.skipped_documents: + console.print(f"[{WARN}]![/] {source_name}") + + +def _insert_progress() -> Progress: + return Progress( + TextColumn(f"[bold {HETA}]heta insert[/]"), + BarColumn(bar_width=28, complete_style=HETA, finished_style=OK), + TaskProgressColumn(), + TextColumn("[dim]{task.description}[/]"), + console=console, + ) + + +def _progress_description(event: InsertProgress) -> str: + if event.phase == "prepare": + return event.label + if event.phase == "merge": + return f"merging {event.current}/{event.total} · {event.label}" + if event.phase == "finalize": + return "finalizing wiki, vector index, and commit" + if event.phase == "done": + return "done" + return event.label + def _absolute_page_path(relative_path: str) -> str: return str((paths.wiki_dir() / relative_path).resolve()) diff --git a/src/heta/cli/insert_planning.py b/src/heta/cli/insert_planning.py new file mode 100644 index 0000000..5b2666f --- /dev/null +++ b/src/heta/cli/insert_planning.py @@ -0,0 +1,77 @@ +"""`heta insert-planning` commands.""" + +from __future__ import annotations + +from dataclasses import replace + +import typer +from rich.console import Console + +from heta.cli.branding import HETA, MUTED, OK, WARN +from heta.config.io import CONFIG_PATH, load_config, save_config +from heta.config.schema import InsertPlanningConfig + +console = Console() + +app = typer.Typer( + name="insert-planning", + help="Turn smart insert planning on or off.", + no_args_is_help=True, + rich_markup_mode="rich", +) + + +@app.command("on") +def insert_planning_on() -> None: + """Enable insert planning loops such as large PDF split planning.""" + _set_insert_planning(True) + + +@app.command("true") +def insert_planning_true() -> None: + """Enable insert planning loops.""" + _set_insert_planning(True) + + +@app.command("off") +def insert_planning_off() -> None: + """Disable insert planning loops such as large PDF split planning.""" + _set_insert_planning(False) + + +@app.command("false") +def insert_planning_false() -> None: + """Disable insert planning loops.""" + _set_insert_planning(False) + + +@app.command("status") +def insert_planning_status() -> None: + """Show whether insert planning loops are enabled.""" + config = _require_config() + state = "enabled" if config.insert_planning.enable else "disabled" + console.print(f"[{MUTED}]insert planning:[/] [bold {HETA}]{state}[/]") + + +def _set_insert_planning(enable: bool) -> None: + config = _require_config() + updated = replace(config, insert_planning=InsertPlanningConfig(enable=enable)) + save_config(updated) + state = "enabled" if enable else "disabled" + console.print(f"[{OK}]✓[/] insert planning {state}") + + +def _require_config(): + try: + config = load_config() + except Exception as exc: + console.print(f"[{WARN}]?[/] Failed to read config: {exc}") + raise typer.Exit(1) from exc + if config is None: + console.print(f"[{WARN}]?[/] Little Heta is not initialized.") + console.print(f"[{MUTED}] Missing config:[/] {CONFIG_PATH}") + raise typer.Exit(1) + return config + + +__all__ = ["app"] diff --git a/src/heta/cli/mem_clean.py b/src/heta/cli/mem_clean.py new file mode 100644 index 0000000..71b99c7 --- /dev/null +++ b/src/heta/cli/mem_clean.py @@ -0,0 +1,40 @@ +"""`heta mem-clean` command — wipe all memory data.""" + +from __future__ import annotations + +import typer +from rich.console import Console +from rich.prompt import Confirm + +from heta.cli.branding import MUTED, OK +from heta.mem.clean import clean_memory +from heta.mem.db import get_connection, init_db +from heta.mem.paths import db_path, ensure_mem_dir + +console = Console() + + +def mem_clean_command( + yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation."), +) -> None: + """Erase everything Little Heta has remembered.""" + if not yes and not Confirm.ask( + "Delete all memory data? This cannot be undone.", + default=False, + ): + console.print(f"[{MUTED}]Cancelled.[/]") + raise typer.Exit(0) + + ensure_mem_dir() + conn = get_connection(db_path(), with_vec=True) + init_db(conn) + result = clean_memory(conn) + conn.close() + + console.print(f"[{OK}]✓[/] Memory cleared.") + console.print(f" [{MUTED}]sessions:[/] {result.deleted_sessions}") + console.print(f" [{MUTED}]L0 turns:[/] {result.deleted_l0_turns}") + console.print(f" [{MUTED}]L1 episodes:[/] {result.deleted_l1_episodes}") + console.print(f" [{MUTED}]L2 facts:[/] {result.deleted_l2_facts}") + console.print(f" [{MUTED}]KB insights:[/] {result.deleted_kb_insights}") + console.print(f" [{MUTED}]memory_meta rows:[/] {result.deleted_meta}") diff --git a/src/heta/cli/mem_show.py b/src/heta/cli/mem_show.py new file mode 100644 index 0000000..907bf4a --- /dev/null +++ b/src/heta/cli/mem_show.py @@ -0,0 +1,161 @@ +"""`heta mem-show` commands — inspect stored memories.""" + +from __future__ import annotations + +import sqlite3 +from datetime import datetime + +import typer +from rich.console import Console +from rich.table import Table + +from heta.cli.branding import HETA, MUTED, WARN +from heta.mem.db import get_connection, init_db +from heta.mem.paths import db_path + +console = Console() + +app = typer.Typer( + name="mem-show", + help="Browse the memories Little Heta has stored.", + no_args_is_help=True, + rich_markup_mode="rich", +) + + +@app.command("insights") +def insights_command( + source: str | None = typer.Option(None, "--source", "-s", help="Filter by source_path substring (e.g. 'pages/1-foo.md')."), + question: str | None = typer.Option(None, "--question", "-q", help="Filter by question substring."), + limit: int = typer.Option(50, "--limit", "-n", help="Max rows to show."), + full: bool = typer.Option(False, "--full", "-f", help="Show full insight text (no truncation)."), +) -> None: + """List stored kb_insight memories, newest first.""" + if not db_path().exists(): + console.print(f"[{WARN}]?[/] Memory DB does not exist yet.") + console.print(f"[{MUTED}] Run `heta ask` at least once to populate it.[/]") + raise typer.Exit(0) + + conn = get_connection(db_path(), with_vec=True) + init_db(conn) + try: + rows = _fetch_insights(conn, source=source, question=question, limit=limit) + total = _count_total(conn, source=source, question=question) + finally: + conn.close() + + if not rows: + console.print(f"[{MUTED}]No insights matched.[/]") + return + + table = Table( + title=f"kb_insights ({len(rows)} of {total} shown)", + show_lines=not full, + border_style=HETA, + ) + table.add_column("#", style="dim", justify="right", no_wrap=True) + table.add_column("created", style=MUTED, no_wrap=True) + table.add_column("sources", style=MUTED) + table.add_column("question", style=MUTED) + table.add_column("insight") + + for i, row in enumerate(rows, 1): + insight_text = row["insight"] if full else _truncate(row["insight"], 140) + question_text = row["question"] or "" + if not full: + question_text = _truncate(question_text, 50) + sources_text = "\n".join(row["source_paths"]) if full else _truncate( + ", ".join(row["source_paths"]), 40 + ) + table.add_row( + str(i), + _format_ts(row["created_at"]), + sources_text, + question_text, + insight_text, + ) + console.print(table) + + +def _fetch_insights( + conn: sqlite3.Connection, + *, + source: str | None, + question: str | None, + limit: int, +) -> list[dict]: + """Fetch insights and their full source_paths list.""" + base_sql = """ + SELECT i.memory_id, i.insight, i.question, i.created_at + FROM kb_insight i + JOIN memory_meta m ON m.memory_id = i.memory_id + WHERE m.status = 'active' + """ + clauses, params = _build_filters(source=source, question=question) + sql = f"{base_sql} {clauses} ORDER BY i.created_at DESC LIMIT ?" + params.append(max(1, limit)) + rows = conn.execute(sql, params).fetchall() + + results = [] + for r in rows: + paths = [ + row[0] + for row in conn.execute( + "SELECT source_path FROM kb_insight_source WHERE memory_id = ? ORDER BY source_path", + (r["memory_id"],), + ).fetchall() + ] + results.append({ + "insight": r["insight"], + "question": r["question"], + "source_paths": paths, + "created_at": r["created_at"], + }) + return results + + +def _count_total( + conn: sqlite3.Connection, + *, + source: str | None, + question: str | None, +) -> int: + base_sql = """ + SELECT COUNT(*) FROM kb_insight i + JOIN memory_meta m ON m.memory_id = i.memory_id + WHERE m.status = 'active' + """ + clauses, params = _build_filters(source=source, question=question) + row = conn.execute(f"{base_sql} {clauses}", params).fetchone() + return int(row[0]) + + +def _build_filters(*, source: str | None, question: str | None) -> tuple[str, list]: + clauses: list[str] = [] + params: list = [] + if source: + clauses.append( + "AND i.memory_id IN (SELECT memory_id FROM kb_insight_source WHERE source_path LIKE ?)" + ) + params.append(f"%{source}%") + if question: + clauses.append("AND i.question LIKE ?") + params.append(f"%{question}%") + return " ".join(clauses), params + + +def _truncate(text: str, max_len: int) -> str: + if text is None: + return "" + if len(text) <= max_len: + return text + return text[: max_len - 1] + "…" + + +def _format_ts(ts: int | None) -> str: + if not ts: + return "" + return datetime.fromtimestamp(int(ts)).strftime("%Y-%m-%d %H:%M") + + +__all__ = ["app"] diff --git a/src/heta/cli/query.py b/src/heta/cli/query.py index 07cb8ca..59affe3 100644 --- a/src/heta/cli/query.py +++ b/src/heta/cli/query.py @@ -8,21 +8,18 @@ from rich.panel import Panel from rich.text import Text +from heta.cli.branding import HETA, MUTED, WARN from heta.config.io import CONFIG_PATH, load_config from heta.query import QueryResult, run_wiki_query console = Console() -HETA = "rgb(52,144,220)" -MUTED = "rgb(126,146,158)" -WARN = "rgb(238,183,74)" - def query_command( question: str = typer.Argument(..., help="Question to answer from the Little Heta wiki."), top_k: int = typer.Option(5, "--top-k", min=1, max=10, help="Initial vector matches to include."), ) -> None: - """Ask a read-only question against the Little Heta wiki.""" + """Ask a question about your inserted documents.""" config = load_config() if config is None: console.print(f"[{WARN}]?[/] Little Heta is not initialized.") diff --git a/src/heta/cli/recall.py b/src/heta/cli/recall.py new file mode 100644 index 0000000..ec7535f --- /dev/null +++ b/src/heta/cli/recall.py @@ -0,0 +1,128 @@ +"""CLI command: heta recall.""" + +from __future__ import annotations + +import typer +from rich.console import Console +from rich.padding import Padding +from rich.panel import Panel +from rich.text import Text + +from heta.cli.branding import ERR, HETA, MUTED, WARN +from heta.config.io import load_config +from heta.mem.recall import recall + +console = Console() + +# Technical layer names — shown in --debug output. +_LAYER_LABELS = { + "raw": "L0 Raw", + "episode": "L1 Episode", + "atomic_fact": "L2 Atomic Fact", + "kb_insight": "KB Insight", +} + +# User-facing layer names — shown in the recall result box. +_SOURCE_LABELS = { + "raw": "Conversation", + "episode": "Episodes", + "atomic_fact": "Facts", + "kb_insight": "Documents", +} + + +def recall_command( + query: str = typer.Argument(..., help="What to recall."), + top_k: int = typer.Option(10, "--top-k", "-k", help="Results per layer."), + debug: bool = typer.Option( + False, "--debug", "-d", help="Show layer ranking, reason, and scored evidence." + ), +) -> None: + """Look up what Little Heta remembers.""" + config = load_config() + if config is None: + console.print(f"[{ERR}]Heta is not initialised. Run `heta init` first.[/]") + raise typer.Exit(1) + + with console.status(f"[bold {HETA}]Searching memories...[/]"): + result = recall(query, config, top_k=top_k) + + if debug: + _show_debug(result) + + _show_result(result) + + +def _show_result(result) -> None: + lines = Text() + lines.append("Query: ", style=f"bold {HETA}") + lines.append(f'"{result.query}"\n\n') + + lines.append("Answer:\n", style=f"bold {HETA}") + if result.answer: + lines.append(result.answer) + else: + lines.append( + "I couldn't find a confident answer in your memories yet — " + "the most relevant pieces are listed below.", + style=MUTED, + ) + + source = _source_text(result) + if source.plain: + lines.append("\n\n") + lines.append("Source:\n", style=f"bold {HETA}") + lines.append(source) + + console.print(Panel(lines, title="recall", border_style=HETA, padding=(1, 2))) + + +def _source_text(result) -> Text: + text = Text() + for layer_ev in result.evidence: + if not layer_ev.items: + continue + label = _SOURCE_LABELS.get(layer_ev.layer, layer_ev.layer) + text.append(f"{label}\n", style=HETA) + for item in layer_ev.items: + text.append(" · ", style=HETA) + text.append(f"{_item_text(layer_ev.layer, item)}\n", style=MUTED) + text.rstrip() + return text + + +def _show_debug(result) -> None: + ranking_str = " > ".join(_LAYER_LABELS.get(r, r) for r in result.ranking) + console.print(f"\n[bold {WARN}]── DEBUG ──[/]\n") + + console.print(f"[bold {HETA}]Ranking[/]") + console.print(f" {ranking_str}\n") + + console.print(f"[bold {HETA}]Reason[/]") + console.print(Padding(Text(result.reason), (0, 0, 0, 2))) + console.print() + + console.print(f"[bold {HETA}]Evidence[/]") + for layer_ev in result.evidence: + if not layer_ev.items: + continue + label = _LAYER_LABELS.get(layer_ev.layer, layer_ev.layer) + console.print(f" [{HETA}]{label}[/]") + for item in layer_ev.items: + score = item.get("score", 0) + line = Text(" ") + line.append(f"{score:.3f} · ", style=MUTED) + line.append(_item_text(layer_ev.layer, item)) + console.print(line) + console.print() + console.print(f"[bold {WARN}]──────────[/]\n") + + +def _item_text(layer: str, item: dict) -> str: + if layer == "raw": + return item.get("text_content", "") + if layer == "episode": + return item.get("summary", "") + if layer == "kb_insight": + return item.get("insight", "") + return item.get("fact_text", "") diff --git a/src/heta/cli/remember.py b/src/heta/cli/remember.py new file mode 100644 index 0000000..127ff59 --- /dev/null +++ b/src/heta/cli/remember.py @@ -0,0 +1,37 @@ +"""CLI command: heta remember.""" + +from __future__ import annotations + +import typer +from rich.console import Console +from rich.panel import Panel + +from heta.cli.branding import ERR, HETA, MUTED, OK +from heta.config.io import load_config +from heta.mem.pipeline import remember + +console = Console() + + +def remember_command( + text: str = typer.Argument(..., help="Text to remember."), +) -> None: + """Save something for Little Heta to remember.""" + config = load_config() + if config is None: + console.print(f"[{ERR}]Heta is not initialised. Run `heta init` first.[/]") + raise typer.Exit(1) + + with console.status(f"[bold {HETA}]Extracting memories...[/]"): + result = remember(text, config) + + console.print( + Panel( + f"[bold {HETA}]L1 episodes:[/] {result.l1_count}\n" + f"[bold {HETA}]L2 facts:[/] {result.l2_count}\n" + f"[{MUTED}]session: {result.session_id}[/]\n" + f"[{MUTED}]elapsed: {result.elapsed_s}s[/]", + title="remember", + border_style=OK, + ) + ) diff --git a/src/heta/cli/skill.py b/src/heta/cli/skill.py new file mode 100644 index 0000000..26496e1 --- /dev/null +++ b/src/heta/cli/skill.py @@ -0,0 +1,52 @@ +"""`heta skill` command.""" + +from __future__ import annotations + +from pathlib import Path + +import typer +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +from heta.assistants import install_assistant_skills, skill_template_dir, skill_template_files, skill_template_hint +from heta.cli.branding import HETA, MUTED, OK, WARN + +console = Console() + + +def skill_command() -> None: + """Install the Little Heta skill into supported agent frameworks.""" + try: + installed = install_assistant_skills() + except Exception as exc: + console.print(f"[{WARN}]?[/] Could not install assistant skills: {exc}") + raise typer.Exit(1) from exc + + table = Table.grid(padding=(0, 2)) + table.add_column(style=f"bold {HETA}") + table.add_column(overflow="fold") + for item in installed: + table.add_row(item.assistant, _short_path(item.path)) + + template_dir = skill_template_dir() + for filename in skill_template_files(): + table.add_row(filename, _short_path(template_dir / filename)) + table.add_row("Manual use", f"{skill_template_hint()}.") + + console.print( + Panel( + table, + title="skill", + border_style=OK, + padding=(1, 2), + ) + ) + + +def _short_path(path: Path) -> str: + home = Path.home() + try: + return "~/" + str(path.expanduser().resolve().relative_to(home.resolve())) + except (OSError, ValueError): + return str(path) diff --git a/src/heta/cli/status.py b/src/heta/cli/status.py index 14b91e1..5c5bc53 100644 --- a/src/heta/cli/status.py +++ b/src/heta/cli/status.py @@ -11,16 +11,13 @@ from rich.panel import Panel from rich.table import Table -from heta.cli.branding import brand_line +from heta.cli.branding import HETA, MUTED, WARN, brand_line from heta.config.io import CONFIG_PATH, load_config from heta.config.schema import HetaConfig, MinerUConfig from heta.kb import paths console = Console() -HETA = "rgb(52,144,220)" -MUTED = "rgb(126,146,158)" -WARN = "rgb(238,183,74)" BAR_FULL = "█" BAR_EMPTY = "░" BAR_WIDTH = 20 @@ -30,6 +27,8 @@ class StatusSummary: llm_provider: str mineru: str + insert_planning: str + dynamic_insert: str kb_files: int wiki_pages: int heta_space: Path @@ -38,7 +37,7 @@ class StatusSummary: def status_command() -> None: - """Show the current Little Heta status.""" + """Show what's set up and how much is stored.""" try: config = load_config() except Exception as exc: @@ -55,6 +54,8 @@ def build_status_summary(config: HetaConfig | None, base_dir: Path | None = None return StatusSummary( llm_provider=config.llm.provider if config else "not configured", mineru=_mineru_summary(config.mineru) if config else "not configured", + insert_planning=_enabled_summary(config.insert_planning.enable) if config else "not configured", + dynamic_insert=_enabled_summary(config.dynamic_insert.enable) if config else "not configured", kb_files=_count_files(paths.raw_dir(base_dir)), wiki_pages=_count_markdown_pages(paths.pages_dir(base_dir)), heta_space=heta_space, @@ -70,6 +71,8 @@ def _show_status(summary: StatusSummary, has_config: bool) -> None: table.add_row("Heta space:", f"{_display_path(summary.heta_space).rstrip('/')}/") table.add_row("Model provider:", summary.llm_provider) table.add_row("MinerU:", summary.mineru) + table.add_row("Insert planning:", summary.insert_planning) + table.add_row("Dynamic insert:", summary.dynamic_insert) table.add_row("KB files:", str(summary.kb_files)) table.add_row("Wiki pages:", str(summary.wiki_pages)) @@ -98,6 +101,10 @@ def _mineru_summary(config: MinerUConfig) -> str: return "cloud" +def _enabled_summary(enable: bool) -> str: + return "enabled" if enable else "disabled" + + def _status_content(table: Table) -> Table: layout = Table.grid() layout.add_column() diff --git a/src/heta/cli/vector.py b/src/heta/cli/vector.py index f687285..d6caa0d 100644 --- a/src/heta/cli/vector.py +++ b/src/heta/cli/vector.py @@ -7,19 +7,18 @@ import typer from rich.console import Console +from heta.cli.branding import HETA, MUTED, OK, WARN from heta.config.io import CONFIG_PATH, load_config, save_config from heta.config.schema import VectorIndexConfig +from heta.kb import paths +from heta.kb.models import FileChange +from heta.kb.vector_index import sync_wiki_vector_index console = Console() -HETA = "rgb(52,144,220)" -MUTED = "rgb(126,146,158)" -OK = "rgb(76,196,142)" -WARN = "rgb(238,183,74)" - app = typer.Typer( name="vector", - help="Manage Little Heta wiki vector indexing.", + help="Turn document search vector indexing on or off.", no_args_is_help=True, rich_markup_mode="rich", ) @@ -45,6 +44,23 @@ def vector_status() -> None: console.print(f"[{MUTED}]vector index:[/] [bold {HETA}]{state}[/]") +@app.command("sync") +def vector_sync() -> None: + """Rebuild the wiki vector index from current wiki pages.""" + config = _require_config() + page_files = sorted(paths.pages_dir().glob("*.md")) + changes = [ + FileChange("updated", page.stem, str(page.relative_to(paths.wiki_dir()))) + for page in page_files + ] + try: + sync_wiki_vector_index(changes=changes, config=config) + except Exception as exc: + console.print(f"[{WARN}]?[/] Vector index sync failed: {exc}") + raise typer.Exit(1) from exc + console.print(f"[{OK}]✓[/] vector index synced ({len(changes)} pages)") + + def _set_vector_index(enable: bool) -> None: config = _require_config() updated = replace(config, vector_index=VectorIndexConfig(enable=enable)) diff --git a/src/heta/config/__init__.py b/src/heta/config/__init__.py index 0fbcb10..117b4fb 100644 --- a/src/heta/config/__init__.py +++ b/src/heta/config/__init__.py @@ -1,14 +1,15 @@ """Configuration helpers for Little Heta.""" from heta.config.io import CONFIG_PATH, load_config, save_config -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig +from heta.config.schema import DynamicInsertConfig, HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig __all__ = [ "CONFIG_PATH", + "DynamicInsertConfig", "HetaConfig", + "InsertPlanningConfig", "LLMConfig", "MinerUConfig", "load_config", "save_config", ] - diff --git a/src/heta/config/schema.py b/src/heta/config/schema.py index 7a2f397..f604fa1 100644 --- a/src/heta/config/schema.py +++ b/src/heta/config/schema.py @@ -2,27 +2,135 @@ from __future__ import annotations -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from typing import Any, Literal -LLMProvider = Literal["qwen", "chatgpt", "gemini"] +LLMProvider = Literal["qwen", "chatgpt", "gemini", "custom"] MinerUProvider = Literal["cloud", "local"] +DEFAULT_LLM_PROFILES: dict[str, dict[str, str | None]] = { + "qwen": { + "chat_model": "qwen3.5-flash", + "chat_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "multimodal_model": "qwen3.5-omni-flash", + "multimodal_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "embedding_model": "text-embedding-v4", + "embedding_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + }, + "chatgpt": { + "chat_model": "gpt-5.4-nano", + "chat_base_url": None, + "multimodal_model": "gpt-5.4-nano", + "multimodal_base_url": None, + "embedding_model": "text-embedding-3-small", + "embedding_base_url": None, + }, + "gemini": { + "chat_model": "gemini-2.5-flash", + "chat_base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "multimodal_model": "gemini-2.5-flash", + "multimodal_base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "embedding_model": "text-embedding-004", + "embedding_base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + }, +} + @dataclass(frozen=True) class LLMConfig: provider: LLMProvider api_key: str + chat_api_key: str | None = None + chat_model: str | None = None + chat_base_url: str | None = None + chat_extra_body: dict[str, Any] | None = None + multimodal_api_key: str | None = None + multimodal_model: str | None = None + multimodal_base_url: str | None = None + embedding_api_key: str | None = None + embedding_model: str | None = None + embedding_base_url: str | None = None + audio_api_key: str | None = None + audio_model: str | None = None + audio_base_url: str | None = None + + def __post_init__(self) -> None: + defaults = DEFAULT_LLM_PROFILES.get(self.provider, {}) + for field in ( + "chat_api_key", + "chat_model", + "chat_base_url", + "multimodal_api_key", + "multimodal_model", + "multimodal_base_url", + "embedding_api_key", + "embedding_model", + "embedding_base_url", + "audio_api_key", + "audio_model", + "audio_base_url", + ): + if getattr(self, field) is None and field in defaults: + object.__setattr__(self, field, defaults[field]) + if self.provider != "custom": + if self.chat_api_key is None: + object.__setattr__(self, "chat_api_key", self.api_key) + if self.multimodal_api_key is None: + object.__setattr__(self, "multimodal_api_key", self.api_key) + if self.embedding_api_key is None: + object.__setattr__(self, "embedding_api_key", self.api_key) @classmethod def from_dict(cls, data: dict[str, Any]) -> "LLMConfig": provider = data.get("provider") api_key = data.get("api_key") - if provider not in {"qwen", "chatgpt", "gemini"}: + if provider not in {"qwen", "chatgpt", "gemini", "custom"}: raise ValueError("Invalid LLM provider in config.") if not isinstance(api_key, str) or not api_key.strip(): raise ValueError("Invalid LLM api_key in config.") - return cls(provider=provider, api_key=api_key) + + defaults = DEFAULT_LLM_PROFILES.get(provider, {}) + values: dict[str, str | None] = {} + for field in ( + "chat_api_key", + "chat_model", + "chat_base_url", + "multimodal_api_key", + "multimodal_model", + "multimodal_base_url", + "embedding_api_key", + "embedding_model", + "embedding_base_url", + "audio_api_key", + "audio_model", + "audio_base_url", + ): + raw = data.get(field, defaults.get(field)) + if raw is not None and not isinstance(raw, str): + raise ValueError(f"Invalid LLM {field} in config.") + values[field] = raw.strip() if isinstance(raw, str) and raw.strip() else None + + chat_extra_body = data.get("chat_extra_body") + if chat_extra_body is not None and not isinstance(chat_extra_body, dict): + raise ValueError("Invalid LLM chat_extra_body in config.") + + if provider == "custom": + missing = [ + field + for field in ( + "chat_api_key", + "chat_model", + "chat_base_url", + "embedding_api_key", + "embedding_model", + "embedding_base_url", + ) + if values[field] is None + ] + if missing: + raise ValueError(f"Custom LLM config requires: {', '.join(missing)}.") + + return cls(provider=provider, api_key=api_key.strip(), chat_extra_body=chat_extra_body, **values) @dataclass(frozen=True) @@ -73,12 +181,46 @@ def from_dict(cls, data: dict[str, Any]) -> "VectorIndexConfig": return cls(enable=enable) +@dataclass(frozen=True) +class InsertPlanningConfig: + enable: bool + + @classmethod + def enabled(cls) -> "InsertPlanningConfig": + return cls(enable=True) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertPlanningConfig": + enable = data.get("enable") + if not isinstance(enable, bool): + raise ValueError("Invalid insert_planning enable flag in config.") + return cls(enable=enable) + + +@dataclass(frozen=True) +class DynamicInsertConfig: + enable: bool + + @classmethod + def disabled(cls) -> "DynamicInsertConfig": + return cls(enable=False) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DynamicInsertConfig": + enable = data.get("enable") + if not isinstance(enable, bool): + raise ValueError("Invalid dynamic_insert enable flag in config.") + return cls(enable=enable) + + @dataclass(frozen=True) class HetaConfig: version: int llm: LLMConfig mineru: MinerUConfig vector_index: VectorIndexConfig + insert_planning: InsertPlanningConfig + dynamic_insert: DynamicInsertConfig = field(default_factory=DynamicInsertConfig.disabled) @classmethod def from_dict(cls, data: dict[str, Any]) -> "HetaConfig": @@ -88,17 +230,25 @@ def from_dict(cls, data: dict[str, Any]) -> "HetaConfig": llm = data.get("llm") mineru = data.get("mineru") vector_index = data.get("vector_index") + insert_planning = data.get("insert_planning") + dynamic_insert = data.get("dynamic_insert", {"enable": False}) if not isinstance(llm, dict): raise ValueError("Missing LLM config.") if not isinstance(mineru, dict): raise ValueError("Missing MinerU config.") if not isinstance(vector_index, dict): raise ValueError("Missing vector_index config.") + if not isinstance(insert_planning, dict): + raise ValueError("Missing insert_planning config.") + if not isinstance(dynamic_insert, dict): + raise ValueError("Invalid dynamic_insert config.") return cls( version=1, llm=LLMConfig.from_dict(llm), mineru=MinerUConfig.from_dict(mineru), vector_index=VectorIndexConfig.from_dict(vector_index), + insert_planning=InsertPlanningConfig.from_dict(insert_planning), + dynamic_insert=DynamicInsertConfig.from_dict(dynamic_insert), ) def to_dict(self) -> dict[str, Any]: diff --git a/src/heta/kb/agent.py b/src/heta/kb/agent.py index ee392d1..0a0967c 100644 --- a/src/heta/kb/agent.py +++ b/src/heta/kb/agent.py @@ -17,15 +17,10 @@ from heta.kb.models import FileChange, ParsedDocument from heta.kb.text import slugify from heta.kb.wiki import detect_wiki_changes +from heta.providers.clients import build_chat_client, extra_body logger = logging.getLogger(__name__) -FAST_AGENT_MODELS = { - "qwen": "qwen3.5-flash", - "chatgpt": "gpt-5.4-nano", - "gemini": "gemini-2.5-flash", -} - AGENT_TOOLS = [ { "type": "function", @@ -241,37 +236,19 @@ def run_merge_agent( } +_LLM_TIMEOUT_SECONDS = 900 +_LLM_MAX_RETRIES = 3 + + def _get_client(config: HetaConfig) -> tuple[OpenAI, str]: # API keys are intentionally read only from ~/.heta/heta.yaml, which is # created by `heta init`. Model choice stays fixed to fast defaults here. - provider = config.llm.provider - if provider == "qwen": - return ( - OpenAI( - api_key=config.llm.api_key, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - timeout=300, - ), - FAST_AGENT_MODELS["qwen"], - ) - if provider == "chatgpt": - return OpenAI(api_key=config.llm.api_key, timeout=300), FAST_AGENT_MODELS["chatgpt"] - if provider == "gemini": - return ( - OpenAI( - api_key=config.llm.api_key, - base_url="https://generativelanguage.googleapis.com/v1beta/openai/", - timeout=300, - ), - FAST_AGENT_MODELS["gemini"], - ) - raise ValueError(f"Unsupported LLM provider: {provider}") + resolved = build_chat_client(config, timeout=_LLM_TIMEOUT_SECONDS, max_retries=_LLM_MAX_RETRIES) + return resolved.client, resolved.model def _extra_body(config: HetaConfig) -> dict[str, Any] | None: - if config.llm.provider == "qwen": - return {"enable_thinking": False} - return None + return extra_body(config) def _chat_completion( @@ -298,28 +275,39 @@ def _chat_completion( def _system_prompt() -> str: - return """You are running Little Heta KB merge ingest. + return """You are running Little Heta KB ingest. + +You turn parsed source documents into pages of a Markdown wiki working copy. +You have TWO distinct jobs and you must not confuse them: + + 1. TRANSCRIBE — `## Content` is a faithful reproduction of the source. + 2. ORGANIZE — `## Summary`, `## Related Pages`, `[[Wiki Links]]`, and the + `index.md` entry are where you synthesize, condense, and connect. + +Synthesis belongs ONLY in job 2. Inside `## Content` you reproduce; you do not +summarize, paraphrase, or "clean up". -Your job is to absorb parsed source documents into the Markdown wiki working copy. You must use tools to inspect and edit files. Do not claim completion until the -working copy contains the final wiki changes. +working copy contains the final wiki changes. In normal ingest you receive one +source document at a time; treat that document as the current unit of truth. Required workflow: 1. read_page("index.md") to understand the current wiki. 2. Identify up to 5 related pages from index.md for each source document. 3. read_page each related page before deciding whether it is genuinely related. -4. For each source document, either create one complete new page or edit one - existing page that already covers the same topic. -5. Update index.md with one entry per created or substantially updated page. - The entry must use exactly this format: +4. For each source document, either create one new page or edit one existing + page that already covers the same topic (see "Editing an existing page" below). +5. Update index.md with one entry per created or substantially updated page, + in exactly this format: - [id] [[Title]] (pages/file-name.md) — one-line summary - If a page does not have a numeric filename prefix yet, omit the id and use - the semantic path you created. The system will assign stable numeric - filename prefixes and normalize index.md after you finish. + If a page has no numeric filename prefix yet, omit the id and use the + semantic path you created. The system assigns stable numeric prefixes and + normalizes index.md afterwards. 6. Maintain bidirectional [[Wiki Links]] only when the relationship is real. 7. append_log with a concise summary of created, updated, linked, or deleted pages. -Page format: +Page format — every page MUST have exactly these four level-2 sections, in +this order: --- title: Title sources: [source_filename] @@ -327,26 +315,81 @@ def _system_prompt() -> str: --- ## Summary -One paragraph. + +One short paragraph that overviews the page. This is the only place where you +condense the document. The vector index uses Summary as embedding context for +every chunk on the page, so keep it tight and informative. ## Content -Full self-contained content. + +A faithful transcription of the source document. Apply the Content rules and +the Heading rules below. ## Related Pages + - [[Related Title]] +(Write "- None yet" if there are no related pages.) + ## Source + - source_filename -If there are no related pages, write "- None yet". +Content rules — transcription, not summary: +- Reproduce the source's body text. Keep its wording. Do NOT paraphrase, + condense, rewrite, or "clean up" sentences. +- Keep every table, every list, every formula, every code block, every + image link (`![...](...)`), every number, every named entity — as they + appear in the source. +- Preserve the source's section order and hierarchy. +- Lines beginning with `Source:` that follow an image or a table are + provenance annotations injected at parse time (they include the original + filename, page number, and bbox). Reproduce them verbatim on the line(s) + immediately after the figure or table they describe; never drop, edit, or + re-order them. +- Wiki pages may be long. Prefer completeness over brevity. Never truncate, and + never use "...", "etc.", or "(omitted)" to stand in for source content. +- You have a large output budget; do not self-shorten because this is a "wiki". +- Before finishing, verify every section, table, list, image, and `Source:` + annotation from the source is present inside `## Content`. + +Heading rules inside `## Content` — REQUIRED for the vector index: +The vector index splits each page into chunks at level-3-or-deeper headings +(`###`, `####`, ...) inside `## Content`, and uses the breadcrumb of those +headings as the chunk's `heading_path`. Correct headings = correct retrieval. +- Represent the source document's internal structure using `###` and deeper + headings. Reuse the source's own section/subsection titles as the heading + text whenever possible. +- NEVER emit a level-1 (`#`) or level-2 (`##`) heading inside `## Content`. + A stray `##` inside the body would also truncate the Content section. If + the source has top-level headings, demote them: the source's shallowest + heading becomes `###`, the next level `####`, and so on. This preserves the + source's relative hierarchy while staying at level 3+. +- Keep headings outside fenced code blocks. Never let a body line accidentally + start with `#` outside a code fence. +- Faithful sub-headings serve BOTH fidelity and retrieval: they preserve + structure AND give the vector index cleaner, better-scoped chunks. Use them. + +Editing an existing page: +- Only edit an existing page when it genuinely covers the same topic as the + new source. +- Append the new source's content as its own clearly-titled `###` subsection, + transcribed under the Content and Heading rules above. +- Do NOT rewrite, trim, summarize, or re-order the page's existing content in + order to "merge" the new source in. Add; never digest. +- Add the new source filename to the `sources:` frontmatter list and to the + `## Source` section. +- Use exact old_str when calling edit_page. Rules: - Paths are limited to index.md, log.md, and pages/*.md. -- One source document becomes one complete wiki page unless it clearly belongs in an existing page. +- Every source document must be represented inside `## Content` and listed in + `## Source`. +- Two sources may share a page only when they describe the same thing; even + then, keep each source's content under its own `###` subsection. - Do not invent or maintain wiki ids, chunk ids, or numeric page prefixes. - Keep [[Wiki Links]] semantic, e.g. [[HetaGen]], never [[1-HetaGen]]. - Always read a page before editing it. -- Use exact old_str when calling edit_page. - Keep log.md append-only. - Every page must include frontmatter fields: title, sources, updated. - index.md must include every created page with its pages/*.md path and summary. @@ -388,20 +431,25 @@ def _execute_tools( except json.JSONDecodeError as exc: output = f"error: invalid tool arguments: {exc}" else: - if name == "read_page": - output = read_page(root_dir, **arguments) - if not output.startswith("error:"): - read_paths.add(_normalize_path(arguments.get("path", ""))) - elif name == "create_page": - output = create_page(root_dir, written_paths=written_paths, **arguments) - elif name == "edit_page": - output = edit_page(root_dir, written_paths=written_paths, **arguments) - elif name == "delete_page": - output = delete_page(root_dir, written_paths=written_paths, **arguments) - elif name == "append_log": - output = append_log(root_dir, **arguments) - else: - output = f"error: unknown tool {name}" + try: + if name == "read_page": + output = read_page(root_dir, **arguments) + if not output.startswith("error:"): + read_paths.add(_normalize_path(arguments.get("path", ""))) + elif name == "create_page": + output = create_page(root_dir, written_paths=written_paths, **arguments) + elif name == "edit_page": + output = edit_page(root_dir, written_paths=written_paths, **arguments) + elif name == "delete_page": + output = delete_page(root_dir, written_paths=written_paths, **arguments) + elif name == "append_log": + output = append_log(root_dir, **arguments) + else: + output = f"error: unknown tool {name}" + except TypeError as exc: + output = f"error: invalid tool arguments for {name}: {exc}" + except Exception as exc: + output = f"error: tool {name} failed: {exc}" results.append({"role": "tool", "tool_call_id": tool_call.id, "content": output}) return results diff --git a/src/heta/kb/audio_parser.py b/src/heta/kb/audio_parser.py new file mode 100644 index 0000000..10339a3 --- /dev/null +++ b/src/heta/kb/audio_parser.py @@ -0,0 +1,327 @@ +"""Audio and video parsing for Little Heta KB inserts.""" + +from __future__ import annotations + +import base64 +import json +from datetime import date +from pathlib import Path +from typing import Any + +import requests +from openai import OpenAI + +from heta.config.schema import HetaConfig +from heta.kb.agent import _chat_completion, _get_client +from heta.providers.clients import ModelClient, build_multimodal_client + +AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".webm", ".mp4"} + +OPENAI_TRANSCRIBE_MODEL = "gpt-4o-transcribe" + +_MIME_TYPES = { + ".mp3": "audio/mp3", + ".wav": "audio/wav", + ".m4a": "audio/mp4", + ".webm": "audio/webm", + ".mp4": "video/mp4", +} + +_QWEN_FORMATS = { + ".mp3": "mp3", + ".wav": "wav", + ".m4a": "m4a", + ".webm": "webm", +} + + +def parse_audio_markdown(source_path: Path, archived_path: Path, config: HetaConfig) -> str: + """Transcribe audio/video and return stable wiki-flavored Markdown.""" + description = transcribe_media(source_path=source_path, config=config) + media_kind = "Video" if source_path.suffix.lower() == ".mp4" else "Audio" + return build_audio_markdown( + title=f"{media_kind} - {source_path.stem}", + source_name=archived_path.name, + media_path=f"../../raw/{archived_path.name}", + media_kind=media_kind, + summary=description["summary"], + transcript=description["transcript"], + key_points_metadata=description["key_points_metadata"], + interpretation_keywords=description["interpretation_keywords"], + ) + + +def transcribe_media(*, source_path: Path, config: HetaConfig) -> dict[str, str]: + if config.llm.provider == "chatgpt": + transcript = _transcribe_with_openai(source_path, config) + return _structure_transcript(source_path=source_path, transcript=transcript, config=config) + if config.llm.provider == "qwen": + _require_multimodal(config, "Audio/video parsing") + return _transcribe_with_qwen_omni(source_path, config) + if config.llm.provider == "custom": + _require_custom_audio(config) + return _transcribe_with_custom_audio(source_path, config) + if config.llm.provider == "gemini": + _require_multimodal(config, "Audio/video parsing") + return _transcribe_with_gemini(source_path, config) + raise ValueError(f"Unsupported audio provider: {config.llm.provider}") + + +def build_audio_markdown( + *, + title: str, + source_name: str, + media_path: str, + media_kind: str, + summary: str, + transcript: str, + key_points_metadata: str, + interpretation_keywords: str, +) -> str: + link_label = f"{media_kind} file" + return ( + "---\n" + f"title: {title}\n" + f"sources: [{source_name}]\n" + f"updated: {date.today().isoformat()}\n" + "---\n\n" + "## Summary\n\n" + f"{summary.strip()}\n\n" + "## Content\n\n" + f"[{link_label}](<{media_path}>)\n\n" + "### Transcript\n\n" + f"{transcript.strip() or 'No transcript extracted.'}\n\n" + "### Key Points and Metadata\n\n" + f"{key_points_metadata.strip()}\n\n" + "### Interpretation and Keywords\n\n" + f"{interpretation_keywords.strip()}\n\n" + "## Related Pages\n\n" + "- None yet\n\n" + "## Source\n\n" + f"- {source_name}\n" + ) + + +def _transcribe_with_openai(path: Path, config: HetaConfig) -> str: + client = OpenAI(api_key=config.llm.api_key, timeout=300) + with path.open("rb") as file: + result = client.audio.transcriptions.create( + model=OPENAI_TRANSCRIBE_MODEL, + file=file, + response_format="text", + ) + return str(result).strip() + + +def _structure_transcript(*, source_path: Path, transcript: str, config: HetaConfig) -> dict[str, str]: + client, model = _get_client(config) + response = _chat_completion( + client=client, + model=model, + messages=[ + {"role": "system", "content": _media_json_system_prompt()}, + { + "role": "user", + "content": _structure_user_prompt(filename=source_path.name, transcript=transcript), + }, + ], + tools=None, + temperature=0.1, + config=config, + ) + raw = response.choices[0].message.content or "" + data = _normalize_description(_extract_json_object(raw)) + if not data["transcript"].strip(): + data["transcript"] = transcript + return data + + +def _transcribe_with_qwen_omni(path: Path, config: HetaConfig) -> dict[str, str]: + return _transcribe_with_openai_compatible_multimodal(path, config, extra_body={"enable_thinking": False}) + + +def _transcribe_with_custom_audio(path: Path, config: HetaConfig) -> dict[str, str]: + resolved = ModelClient( + client=OpenAI( + api_key=config.llm.audio_api_key or "", + base_url=config.llm.audio_base_url, + timeout=300, + ), + model=config.llm.audio_model or "", + ) + return _transcribe_with_openai_compatible_multimodal(path, config, resolved=resolved) + + +def _transcribe_with_openai_compatible_multimodal( + path: Path, + config: HetaConfig, + *, + extra_body: dict[str, Any] | None = None, + resolved: ModelClient | None = None, +) -> dict[str, str]: + resolved = resolved or build_multimodal_client(config) + suffix = path.suffix.lower() + content: list[dict[str, Any]] = [{"type": "text", "text": _media_prompt(path.name)}] + if suffix == ".mp4": + content.append({"type": "video_url", "video_url": {"url": _data_url(path)}}) + else: + audio_format = _QWEN_FORMATS.get(suffix) + if audio_format is None: + raise ValueError(f"Unsupported Qwen audio type: {suffix}") + content.append( + { + "type": "input_audio", + "input_audio": { + "data": _data_url(path), + "format": audio_format, + }, + } + ) + + response = resolved.client.chat.completions.create( + model=resolved.model, + messages=[{"role": "user", "content": content}], + temperature=0.1, + **({"extra_body": extra_body} if extra_body else {}), + ) + raw = response.choices[0].message.content or "" + return _normalize_description(_extract_json_object(raw)) + + +def _transcribe_with_gemini(path: Path, config: HetaConfig) -> dict[str, str]: + mime = _mime_type(path) + model = config.llm.multimodal_model + if not model: + raise ValueError("Missing LLM multimodal_model in config.") + payload = { + "contents": [ + { + "role": "user", + "parts": [ + {"text": _media_prompt(path.name)}, + { + "inline_data": { + "mime_type": mime, + "data": base64.b64encode(path.read_bytes()).decode("ascii"), + } + }, + ], + } + ], + "generationConfig": {"temperature": 0.1}, + } + response = requests.post( + f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent", + params={"key": config.llm.api_key}, + json=payload, + timeout=300, + ) + if response.status_code != 200: + raise RuntimeError(f"Gemini audio transcription failed: HTTP {response.status_code} {response.text[:300]}") + parts = response.json().get("candidates", [{}])[0].get("content", {}).get("parts", []) + raw = "\n".join(str(part.get("text", "")) for part in parts).strip() + return _normalize_description(_extract_json_object(raw)) + + +def _media_json_system_prompt() -> str: + return """You are an audio/video-to-Markdown parser for Little Heta KB inserts. +Return only one valid JSON object. Do not wrap it in Markdown fences. +Keep the transcript faithful. Do not invent details not present in the transcript or media.""" + + +def _media_prompt(filename: str) -> str: + return f"""Transcribe and describe this audio/video for semantic retrieval. + +Filename: {filename} + +Return JSON with exactly these string fields: +- summary: one concise paragraph describing the media content. +- transcript: full transcript. Preserve speaker labels and timestamps if available. +- key_points_metadata: important facts, decisions, tasks, names, dates, places, speaker count, duration, language, and media type. +- interpretation_keywords: likely meaning or purpose with uncertainty if needed, ending with compact search keywords. +""" + + +def _structure_user_prompt(*, filename: str, transcript: str) -> str: + return f"""Structure this transcript for semantic retrieval. + +Filename: {filename} + +Transcript: +{transcript} + +Return JSON with exactly these string fields: +- summary +- transcript +- key_points_metadata +- interpretation_keywords +""" + + +def _data_url(path: Path) -> str: + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:{_mime_type(path)};base64,{encoded}" + + +def _mime_type(path: Path) -> str: + suffix = path.suffix.lower() + mime = _MIME_TYPES.get(suffix) + if mime is None: + raise ValueError(f"Unsupported audio/video type: {suffix}") + return mime + + +def _require_multimodal(config: HetaConfig, feature: str) -> None: + if not ( + config.llm.multimodal_api_key + and config.llm.multimodal_model + and config.llm.multimodal_base_url + ): + raise ValueError( + f"{feature} requires a multimodal model. Run `heta init` and enable custom multimodal API, " + "or skip this file." + ) + + +def _require_custom_audio(config: HetaConfig) -> None: + if not (config.llm.audio_api_key and config.llm.audio_model and config.llm.audio_base_url): + raise ValueError( + "Audio/video parsing is not enabled for custom providers because audio APIs vary by vendor. " + "Use qwen or gemini for built-in audio support, or enable a custom audio adapter later." + ) + + +def _extract_json_object(text: str) -> dict[str, Any]: + stripped = text.strip() + if stripped.startswith("```"): + stripped = stripped.strip("`") + if stripped.lower().startswith("json"): + stripped = stripped[4:].strip() + try: + value = json.loads(stripped) + except json.JSONDecodeError: + start = stripped.find("{") + end = stripped.rfind("}") + if start == -1 or end == -1 or end <= start: + raise ValueError("Audio model did not return JSON.") + value = json.loads(stripped[start : end + 1]) + if not isinstance(value, dict): + raise ValueError("Audio model JSON must be an object.") + return value + + +def _normalize_description(data: dict[str, Any]) -> dict[str, str]: + fields = { + "summary": "Imported audio or video.", + "transcript": "No transcript extracted.", + "key_points_metadata": "No key points extracted.", + "interpretation_keywords": "Audio or video media; transcript.", + } + normalized: dict[str, str] = {} + for key, fallback in fields.items(): + value = data.get(key) + normalized[key] = str(value).strip() if value else fallback + return normalized + + +__all__ = ["AUDIO_EXTENSIONS", "build_audio_markdown", "parse_audio_markdown", "transcribe_media"] diff --git a/src/heta/kb/clean.py b/src/heta/kb/clean.py index ff9782a..9de483c 100644 --- a/src/heta/kb/clean.py +++ b/src/heta/kb/clean.py @@ -14,6 +14,7 @@ class CleanSummary: deleted_pages: int deleted_vector_files: int commit_id: str | None + invalidated_memories: int = 0 def clean_knowledge_base(*, base_dir: Path | None = None) -> CleanSummary: @@ -26,10 +27,14 @@ def clean_knowledge_base(*, base_dir: Path | None = None) -> CleanSummary: deleted_vector_files = _clear_vector_db(base_dir) commit_id = commit_wiki("chore: clean wiki knowledge base", base_dir) + from heta.mem.kb_invalidate import invalidate_all + invalidated = invalidate_all() + return CleanSummary( deleted_pages=deleted_pages, deleted_vector_files=deleted_vector_files, commit_id=commit_id, + invalidated_memories=invalidated, ) diff --git a/src/heta/kb/code_parser.py b/src/heta/kb/code_parser.py new file mode 100644 index 0000000..b2e46fa --- /dev/null +++ b/src/heta/kb/code_parser.py @@ -0,0 +1,375 @@ +"""Static code-file parsing for Little Heta KB inserts.""" + +from __future__ import annotations + +import ast +import json +import re +from dataclasses import dataclass +from datetime import date +from pathlib import Path +from typing import Any + +CODE_EXTENSIONS = { + ".py", + ".js", + ".ts", + ".tsx", + ".jsx", + ".java", + ".go", + ".rs", + ".cpp", + ".c", + ".h", + ".hpp", + ".sh", + ".sql", + ".yaml", + ".yml", + ".json", + ".toml", +} + +SMALL_CODE_LINE_LIMIT = 200 + +LANGUAGES = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "typescript", + ".jsx": "javascript", + ".java": "java", + ".go": "go", + ".rs": "rust", + ".cpp": "cpp", + ".c": "c", + ".h": "c/cpp header", + ".hpp": "cpp header", + ".sh": "shell", + ".sql": "sql", + ".yaml": "yaml", + ".yml": "yaml", + ".json": "json", + ".toml": "toml", +} + + +@dataclass(frozen=True) +class CodeSymbol: + name: str + kind: str + signature: str + start_line: int + end_line: int + summary: str + + +def parse_code_markdown(source_path: Path, archived_path: Path) -> str: + text = source_path.read_text(encoding="utf-8", errors="replace") + suffix = source_path.suffix.lower() + language = LANGUAGES.get(suffix, suffix.lstrip(".") or "text") + lines = text.splitlines() + symbols = extract_code_symbols(source_path, text) + + return build_code_markdown( + title=f"Code - {source_path.name}", + source_name=archived_path.name, + raw_path=f"../../raw/{archived_path.name}", + language=language, + line_count=len(lines), + symbols=symbols, + code=text if len(lines) <= SMALL_CODE_LINE_LIMIT else None, + ) + + +def build_code_markdown( + *, + title: str, + source_name: str, + raw_path: str, + language: str, + line_count: int, + symbols: list[CodeSymbol], + code: str | None, +) -> str: + summary = _summary(language, line_count, symbols) + body = [ + "---", + f"title: {title}", + f"sources: [{source_name}]", + f"updated: {date.today().isoformat()}", + "---", + "", + "## Summary", + summary, + "", + "## Content", + "", + f"[Raw source](<{raw_path}>)", + "", + "### File Overview", + f"- language: {language}", + f"- lines: {line_count}", + ] + if symbols: + names = ", ".join(symbol.name for symbol in symbols[:20]) + suffix = "" if len(symbols) <= 20 else f", ... ({len(symbols)} total)" + body.append(f"- symbols: {names}{suffix}") + else: + body.append("- symbols: none detected") + + if code is not None: + body.extend(["", "### Code", f"```{_fence_language(language)}", code.rstrip(), "```"]) + else: + body.extend(["", "### Symbol Index"]) + if symbols: + for symbol in symbols: + body.extend( + [ + "", + f"#### {symbol.name}", + f"Lines: {symbol.start_line}-{symbol.end_line}", + f"Type: {symbol.kind}", + ] + ) + if symbol.signature: + body.append(f"Signature: `{symbol.signature}`") + body.append(f"Summary: {symbol.summary}") + else: + body.extend(["", "#### Lines 1-" + str(line_count), "Summary: Full source is available in raw."]) + + body.extend(["", "## Source", f"- {source_name}", ""]) + return "\n".join(body) + + +def extract_code_symbols(path: Path, text: str) -> list[CodeSymbol]: + suffix = path.suffix.lower() + if suffix == ".py": + return _python_symbols(text) + if suffix in {".yaml", ".yml", ".json", ".toml"}: + return _config_symbols(suffix, text) + if suffix == ".sql": + return _sql_symbols(text) + return _regex_symbols(suffix, text) + + +def _python_symbols(text: str) -> list[CodeSymbol]: + try: + tree = ast.parse(text) + except SyntaxError: + return [] + + symbols: list[CodeSymbol] = [] + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + symbols.append( + CodeSymbol( + name=node.name, + kind="class", + signature=f"class {node.name}", + start_line=node.lineno, + end_line=getattr(node, "end_lineno", node.lineno), + summary=_doc_summary(ast.get_docstring(node), f"Defines class `{node.name}`."), + ) + ) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + prefix = "async function" if isinstance(node, ast.AsyncFunctionDef) else "function" + symbols.append( + CodeSymbol( + name=node.name, + kind=prefix, + signature=_python_signature(node), + start_line=node.lineno, + end_line=getattr(node, "end_lineno", node.lineno), + summary=_doc_summary(ast.get_docstring(node), f"Defines {prefix} `{node.name}`."), + ) + ) + return sorted(symbols, key=lambda symbol: (symbol.start_line, symbol.name)) + + +def _python_signature(node: ast.FunctionDef | ast.AsyncFunctionDef) -> str: + args = [arg.arg for arg in node.args.posonlyargs + node.args.args] + if node.args.vararg: + args.append("*" + node.args.vararg.arg) + args.extend(arg.arg for arg in node.args.kwonlyargs) + if node.args.kwarg: + args.append("**" + node.args.kwarg.arg) + prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def" + return f"{prefix} {node.name}({', '.join(args)})" + + +def _regex_symbols(suffix: str, text: str) -> list[CodeSymbol]: + patterns = { + ".js": [ + ("class", re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)", re.MULTILINE)), + ("function", re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(", re.MULTILINE)), + ("function", re.compile(r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(", re.MULTILINE)), + ], + ".ts": [], + ".tsx": [], + ".jsx": [], + ".java": [ + ("class", re.compile(r"^\s*(?:public\s+)?(?:final\s+)?class\s+([A-Za-z_]\w*)", re.MULTILINE)), + ("method", re.compile(r"^\s*(?:public|private|protected)\s+[\w<>\[\], ?]+\s+([A-Za-z_]\w*)\s*\(", re.MULTILINE)), + ], + ".go": [ + ("function", re.compile(r"^\s*func\s+(?:\([^)]+\)\s*)?([A-Za-z_]\w*)\s*\(", re.MULTILINE)), + ("type", re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+", re.MULTILINE)), + ], + ".rs": [ + ("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+([A-Za-z_]\w*)\s*\(", re.MULTILINE)), + ("type", re.compile(r"^\s*(?:pub\s+)?(?:struct|enum|trait)\s+([A-Za-z_]\w*)", re.MULTILINE)), + ], + ".cpp": [], + ".c": [], + ".h": [], + ".hpp": [], + ".sh": [ + ("function", re.compile(r"^\s*(?:function\s+)?([A-Za-z_][\w-]*)\s*\(\)\s*\{?", re.MULTILINE)), + ], + } + patterns[".ts"] = patterns[".js"] + patterns[".tsx"] = patterns[".js"] + patterns[".jsx"] = patterns[".js"] + c_like = [ + ( + "function", + re.compile( + r"^\s*(?:static\s+|inline\s+|extern\s+)?[\w:*&<>\[\]\s]+\s+([A-Za-z_]\w*)\s*\([^;{}]*\)\s*\{", + re.MULTILINE, + ), + ) + ] + patterns[".cpp"] = c_like + patterns[".c"] = c_like + patterns[".h"] = c_like + patterns[".hpp"] = c_like + + lines = text.splitlines() + symbols: list[CodeSymbol] = [] + seen: set[tuple[str, int]] = set() + for kind, pattern in patterns.get(suffix, []): + for match in pattern.finditer(text): + name = match.group(1) + start_line = text[: match.start(1)].count("\n") + 1 + if (name, start_line) in seen: + continue + seen.add((name, start_line)) + signature = lines[start_line - 1].strip() if start_line - 1 < len(lines) else name + symbols.append( + CodeSymbol( + name=name, + kind=kind, + signature=signature.rstrip("{").strip(), + start_line=start_line, + end_line=_next_symbol_end(start_line, lines), + summary=f"Defines {kind} `{name}`.", + ) + ) + return sorted(symbols, key=lambda symbol: (symbol.start_line, symbol.name)) + + +def _config_symbols(suffix: str, text: str) -> list[CodeSymbol]: + names: list[tuple[str, int]] = [] + if suffix == ".json": + try: + data = json.loads(text) + except json.JSONDecodeError: + data = None + if isinstance(data, dict): + line_lookup = text.splitlines() + for key in data: + line = _find_key_line(str(key), line_lookup) + names.append((str(key), line)) + else: + pattern = re.compile(r"^([A-Za-z0-9_.-]+)\s*[:=]", re.MULTILINE) + names = [(match.group(1), text[: match.start()].count("\n") + 1) for match in pattern.finditer(text)] + + lines = text.splitlines() + symbols = [ + CodeSymbol( + name=name, + kind="config block", + signature=name, + start_line=line, + end_line=_next_symbol_end(line, lines), + summary=f"Configuration block `{name}`.", + ) + for name, line in names + ] + return sorted(symbols, key=lambda symbol: (symbol.start_line, symbol.name)) + + +def _sql_symbols(text: str) -> list[CodeSymbol]: + pattern = re.compile( + r"^\s*(CREATE\s+(?:TABLE|VIRTUAL\s+TABLE|INDEX|VIEW)|SELECT|INSERT|UPDATE|DELETE)\s+(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][\w.]*)?", + re.IGNORECASE | re.MULTILINE, + ) + lines = text.splitlines() + symbols: list[CodeSymbol] = [] + for index, match in enumerate(pattern.finditer(text), start=1): + start_line = text[: match.start()].count("\n") + 1 + op = " ".join(match.group(1).upper().split()) + target = match.group(2) or f"statement-{index}" + name = f"{op} {target}" + symbols.append( + CodeSymbol( + name=name, + kind="sql statement", + signature=lines[start_line - 1].strip() if start_line - 1 < len(lines) else name, + start_line=start_line, + end_line=_next_symbol_end(start_line, lines), + summary=f"SQL statement `{name}`.", + ) + ) + return symbols + + +def _next_symbol_end(start_line: int, lines: list[str]) -> int: + return min(len(lines), start_line + 80) + + +def _find_key_line(key: str, lines: list[str]) -> int: + quoted = re.compile(rf'^\s*"{re.escape(key)}"\s*:') + plain = re.compile(rf"^\s*{re.escape(key)}\s*[:=]") + for index, line in enumerate(lines, start=1): + if quoted.search(line) or plain.search(line): + return index + return 1 + + +def _doc_summary(docstring: str | None, fallback: str) -> str: + if not docstring: + return fallback + first = " ".join(docstring.strip().splitlines()[0].split()) + return first.rstrip(".") + "." + + +def _summary(language: str, line_count: int, symbols: list[CodeSymbol]) -> str: + if symbols: + names = ", ".join(symbol.name for symbol in symbols[:8]) + suffix = "" if len(symbols) <= 8 else f", and {len(symbols) - 8} more" + return f"{language} source file with {line_count} lines. Main indexed symbols: {names}{suffix}." + return f"{language} source file with {line_count} lines. Full source is available through the raw file link." + + +def _fence_language(language: str) -> str: + return { + "python": "python", + "javascript": "javascript", + "typescript": "typescript", + "java": "java", + "go": "go", + "rust": "rust", + "cpp": "cpp", + "c": "c", + "shell": "bash", + "sql": "sql", + "yaml": "yaml", + "json": "json", + "toml": "toml", + }.get(language, "") + + +__all__ = ["CODE_EXTENSIONS", "CodeSymbol", "build_code_markdown", "extract_code_symbols", "parse_code_markdown"] diff --git a/src/heta/kb/discovery.py b/src/heta/kb/discovery.py index e15a1c7..2a99014 100644 --- a/src/heta/kb/discovery.py +++ b/src/heta/kb/discovery.py @@ -5,13 +5,17 @@ from pathlib import Path from heta.config.schema import HetaConfig +from heta.kb.code_parser import CODE_EXTENSIONS +from heta.kb.html_parser import HTML_EXTENSIONS PLAIN_EXTENSIONS = {".md", ".markdown", ".txt"} -MINERU_EXTENSIONS = {".pdf"} +MINERU_EXTENSIONS = {".pdf", ".doc", ".docx", ".ppt", ".pptx", ".xls", ".xlsx"} +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"} +AUDIO_EXTENSIONS = {".mp3", ".wav", ".m4a", ".webm", ".mp4"} def supported_extensions(config: HetaConfig) -> set[str]: - extensions = set(PLAIN_EXTENSIONS) + extensions = set(PLAIN_EXTENSIONS) | IMAGE_EXTENSIONS | AUDIO_EXTENSIONS | CODE_EXTENSIONS | HTML_EXTENSIONS if config.mineru.enable: extensions |= MINERU_EXTENSIONS return extensions @@ -58,4 +62,3 @@ def _add_supported_file(files: list[Path], path: Path, extensions: set[str]) -> def _is_ignored_path(path: Path) -> bool: ignored = {".git", ".worktrees", "__pycache__", ".pytest_cache", "workspace"} return any(part in ignored for part in path.parts) - diff --git a/src/heta/kb/html_parser.py b/src/heta/kb/html_parser.py new file mode 100644 index 0000000..3ba74a9 --- /dev/null +++ b/src/heta/kb/html_parser.py @@ -0,0 +1,470 @@ +"""Structure-preserving HTML parsing for Little Heta KB inserts.""" + +from __future__ import annotations + +import json +import re +import shutil +from dataclasses import asdict, dataclass +from datetime import date +from html import unescape +from pathlib import Path +from urllib.parse import urljoin, urlparse + +from bs4 import BeautifulSoup +from bs4.element import Comment, NavigableString, Tag + +HTML_EXTENSIONS = {".html", ".htm"} +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff"} +NOISE_TAGS = {"script", "style", "nav", "footer", "header", "aside", "iframe", "button", "noscript"} +NOISE_SELECTORS = ( + "[role='navigation']", + "[role='banner']", + "[role='contentinfo']", + "[role='search']", + ".mw-editsection", + ".mw-indicators", + ".mw-jump-link", + ".mw-portlet", + ".mw-sidebar", + ".ambox", + ".docNav", + ".docSearch", + ".navfooter", + ".navheader", + ".metadata", + ".menu", + ".nosearch", + ".noprint", + ".navbox", + ".navbar", + ".printfooter", + ".shortdescription", + ".sidebar", + ".toc", + ".topicon", + "#catlinks", + "#docSearchForm", + "#footer", + "#mw-navigation", + "#p-lang-btn", + "#siteNotice", +) + + +@dataclass(frozen=True) +class HtmlAsset: + id: str + raw_path: str | None + original_src: str + alt: str + title: str + section: str + near_text_before: str + near_text_after: str + + +def parse_html_markdown(source_path: Path, archived_path: Path) -> str: + html = source_path.read_text(encoding="utf-8", errors="replace") + soup = BeautifulSoup(html, "html.parser") + _remove_noise(soup) + + title = _page_title(soup, source_path) + description = _description(soup) + body = _main_content(soup) + asset_dir = archived_path.parent / "assets" / archived_path.stem + converter = _HtmlMarkdownConverter(source_path=source_path, asset_dir=asset_dir, asset_stem=archived_path.stem) + content = converter.convert(body).strip() + content = _ensure_content_title(content, title) + summary = _html_summary(body, title, description) or _summary(title, description, content) + + if converter.assets: + asset_dir.mkdir(parents=True, exist_ok=True) + manifest = { + "source_html": archived_path.name, + "assets": [asdict(asset) for asset in converter.assets], + } + (asset_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8") + + return build_html_markdown( + title=f"Web Page - {title}", + source_name=archived_path.name, + summary=summary, + content=content, + ) + + +def build_html_markdown(*, title: str, source_name: str, summary: str, content: str) -> str: + return ( + "---\n" + f"title: {title}\n" + f"sources: [{source_name}]\n" + f"updated: {date.today().isoformat()}\n" + "---\n\n" + "## Summary\n" + f"{summary.strip()}\n\n" + "## Content\n\n" + f"{content.strip() or 'No main HTML content extracted.'}\n\n" + "## Source\n" + f"- {source_name}\n" + ) + + +class _HtmlMarkdownConverter: + def __init__(self, *, source_path: Path, asset_dir: Path, asset_stem: str) -> None: + self.source_path = source_path + self.asset_dir = asset_dir + self.asset_stem = asset_stem + self.assets: list[HtmlAsset] = [] + self.section_stack: list[str] = [] + self.recent_text: str = "" + + def convert(self, node: Tag) -> str: + parts = [self._convert_child(child) for child in node.children] + return _compact_blocks("\n".join(part for part in parts if part.strip())) + + def _convert_child(self, node) -> str: + if isinstance(node, Comment): + return "" + if isinstance(node, NavigableString): + text = _clean_text(str(node)) + self._remember_text(text) + return text + if not isinstance(node, Tag): + return "" + + name = node.name.lower() if node.name else "" + if name in NOISE_TAGS: + return "" + if name in {"h1", "h2", "h3", "h4", "h5", "h6"}: + return self._heading(node, int(name[1])) + if name == "p": + return self._paragraph(node) + if name in {"ul", "ol"}: + return self._list(node, ordered=name == "ol") + if name == "li": + return self._inline_children(node) + if name == "blockquote": + text = self.convert(node) + return "\n".join(f"> {line}" if line else ">" for line in text.splitlines()) + if name == "pre": + return self._pre(node) + if name == "code": + return f"`{_clean_text(node.get_text(' ', strip=True))}`" + if name == "table": + return self._table(node) + if name == "img": + return self._image(node) + if name == "br": + return "\n" + if name in {"strong", "b"}: + return f"**{self._inline_children(node)}**" + if name in {"em", "i"}: + return f"*{self._inline_children(node)}*" + if name == "a": + return self._link(node) + + return self.convert(node) + + def _heading(self, node: Tag, html_level: int) -> str: + text = _clean_text(node.get_text(" ", strip=True)) + if not text: + return "" + markdown_level = min(6, html_level + 2) + depth = markdown_level - 2 + self.section_stack = self.section_stack[: max(0, depth - 1)] + [text] + self._remember_text(text) + return f"{'#' * markdown_level} {text}" + + def _paragraph(self, node: Tag) -> str: + text = self._inline_children(node) + self._remember_text(text) + return text + + def _list(self, node: Tag, *, ordered: bool) -> str: + lines: list[str] = [] + index = 1 + for child in node.find_all("li", recursive=False): + text = _compact_inline(self._inline_children(child)) + if not text: + continue + marker = f"{index}." if ordered else "-" + lines.append(f"{marker} {text}") + index += 1 + return "\n".join(lines) + + def _pre(self, node: Tag) -> str: + code = node.get_text("\n", strip=False).strip("\n") + language = "" + code_tag = node.find("code") + if code_tag: + classes = " ".join(code_tag.get("class", [])) + match = re.search(r"language-([\w+-]+)", classes) + if match: + language = match.group(1) + return f"```{language}\n{code}\n```" + + def _table(self, node: Tag) -> str: + rows: list[list[str]] = [] + for tr in node.find_all("tr"): + cells = tr.find_all(["th", "td"], recursive=False) + if cells: + rows.append([_compact_inline(cell.get_text(" ", strip=True)) for cell in cells]) + if not rows: + return "" + width = max(len(row) for row in rows) + normalized = [row + [""] * (width - len(row)) for row in rows] + header = normalized[0] + separator = ["---"] * width + body = normalized[1:] + table_lines = [_markdown_row(header), _markdown_row(separator), *[_markdown_row(row) for row in body]] + text = "\n".join(table_lines) + self._remember_text(" ".join(" ".join(row) for row in normalized)) + return text + + def _image(self, node: Tag) -> str: + src = _img_src(node) + if not src: + return "" + alt = _clean_text(str(node.get("alt") or "")) + title = _clean_text(str(node.get("title") or "")) + markdown_src, raw_path = self._image_path(src) + label = alt or title or Path(urlparse(src).path).name or "HTML image" + section = self.section_stack[-1] if self.section_stack else "" + asset = HtmlAsset( + id=f"img-{len(self.assets) + 1:03d}", + raw_path=raw_path, + original_src=src, + alt=alt, + title=title, + section=section, + near_text_before=self.recent_text, + near_text_after="", + ) + self.assets.append(asset) + note = alt or title + if note: + return f"![{_escape_brackets(label)}](<{markdown_src}>)\n\nImage note: {note}." + return f"![{_escape_brackets(label)}](<{markdown_src}>)" + + def _image_path(self, src: str) -> tuple[str, str | None]: + parsed = urlparse(src) + if parsed.scheme in {"http", "https", "data"} or src.startswith("//"): + return src, None + local = (self.source_path.parent / src).resolve() + if not local.exists() or local.suffix.lower() not in IMAGE_EXTENSIONS: + return src, None + self.asset_dir.mkdir(parents=True, exist_ok=True) + target = self.asset_dir / f"img-{len(self.assets) + 1:03d}{local.suffix.lower()}" + shutil.copy2(local, target) + raw_path = f"raw/assets/{self.asset_stem}/{target.name}" + markdown_path = f"../../raw/assets/{self.asset_stem}/{target.name}" + return markdown_path, raw_path + + def _link(self, node: Tag) -> str: + if node.find("img"): + return self._inline_children(node) + text = self._inline_children(node) or _clean_text(node.get_text(" ", strip=True)) + href = str(node.get("href") or "").strip() + if not href: + return text + return f"[{_escape_brackets(text)}](<{href}>)" + + def _inline_children(self, node: Tag) -> str: + parts = [self._convert_child(child) for child in node.children] + return _compact_inline(" ".join(part for part in parts if part.strip())) + + def _remember_text(self, text: str) -> None: + cleaned = _compact_inline(text) + if cleaned: + self.recent_text = cleaned[-240:] + + +def _remove_noise(soup: BeautifulSoup) -> None: + for tag in soup.find_all(list(NOISE_TAGS)): + tag.decompose() + for selector in NOISE_SELECTORS: + for tag in soup.select(selector): + tag.decompose() + + +def _main_content(soup: BeautifulSoup) -> Tag: + selectors = ( + ".mw-parser-output", + "#docContent", + "#main-content", + "#maincontent", + ".document", + ".documentwrapper", + ".body", + "article", + "main", + "[role='main']", + "#bodyContent", + "#mw-content-text", + "#content", + "body", + ) + for selector in selectors: + tag = soup.select_one(selector) + if tag and _clean_text(tag.get_text(" ", strip=True)): + return tag + return soup + + +def _ensure_content_title(content: str, title: str) -> str: + lines = [line.strip() for line in content.splitlines() if line.strip()] + if lines and lines[0] == f"### {title}": + return content + return f"### {title}\n{content}" if content else f"### {title}" + + +def _page_title(soup: BeautifulSoup, source_path: Path) -> str: + for selector in ("h1", "title"): + tag = soup.find(selector) + if tag: + text = _clean_text(tag.get_text(" ", strip=True)) + if text: + return text + return source_path.stem.replace("_", " ").replace("-", " ").title() + + +def _description(soup: BeautifulSoup) -> str: + for attrs in ({"name": "description"}, {"property": "og:description"}): + tag = soup.find("meta", attrs=attrs) + if tag and tag.get("content"): + return _clean_text(str(tag.get("content"))) + return "" + + +def _html_summary(body: Tag, title: str, description: str) -> str: + if _description_candidate(description): + return description + for paragraph in body.find_all("p"): + if _is_non_content_node(paragraph): + continue + text = _lead_summary_text(_clean_text(paragraph.get_text(" ", strip=True))) + if _summary_candidate(text): + return text[:240].rstrip() + ("..." if len(text) > 240 else "") + return description + + +def _description_candidate(text: str) -> bool: + if not _summary_candidate(text): + return False + lowered = text.lower() + if "…" in text or lowered.endswith("..."): + return False + bad_fragments = ("table of contents", "part i.", "newpp limit report") + return not any(fragment in lowered for fragment in bad_fragments) + + +def _is_non_content_node(tag: Tag) -> bool: + blocked_tags = {"table", "figure", "aside", "nav", "footer", "header"} + blocked_classes = { + "ambox", + "hatnote", + "infobox", + "metadata", + "navbox", + "noprint", + "shortdescription", + "sidebar", + } + for parent in [tag, *tag.parents]: + if not isinstance(parent, Tag): + continue + if parent.name and parent.name.lower() in blocked_tags: + return True + classes = set(parent.get("class", [])) + if classes & blocked_classes: + return True + return False + + +def _summary(title: str, description: str, content: str) -> str: + if description: + return description + for block in re.split(r"\n\s*\n", content): + text = _lead_summary_text(_strip_markdown(block)) + if _summary_candidate(text): + return text[:240].rstrip() + ("..." if len(text) > 240 else "") + return f"HTML page about {title}." + + +def _lead_summary_text(text: str) -> str: + text = re.sub(r"^For other uses, see .*?\.\s*", "", text) + match = re.search(r"\bIn\s+[A-Za-z]", text) + if match and 0 < match.start() < 420: + return text[match.start() :] + return text + + +def _summary_candidate(text: str) -> bool: + if len(text) < 80: + return False + lowered = text.lower() + skipped_prefixes = ( + "for other uses", + "image note:", + "jump to content", + "main article:", + "see also:", + ) + if lowered.startswith(skipped_prefixes): + return False + return any(char.isalpha() for char in text) + + +def _img_src(node: Tag) -> str: + for key in ("data-src", "data-original", "data-lazy-src", "src"): + value = node.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + srcset = node.get("srcset") + if isinstance(srcset, str) and srcset.strip(): + return srcset.split(",")[-1].strip().split()[0] + return "" + + +def _markdown_row(row: list[str]) -> str: + return "| " + " | ".join(cell.replace("|", "\\|") for cell in row) + " |" + + +def _clean_text(text: str) -> str: + return re.sub(r"\s+", " ", unescape(text)).strip() + + +def _compact_inline(text: str) -> str: + return re.sub(r"[ \t]+", " ", text).strip() + + +def _compact_blocks(text: str) -> str: + lines = [line.rstrip() for line in text.splitlines()] + compact: list[str] = [] + blank = False + for line in lines: + if not line.strip(): + if not blank: + compact.append("") + blank = True + continue + compact.append(line) + blank = False + return "\n".join(compact).strip() + + +def _strip_markdown(text: str) -> str: + cleaned = re.sub(r"!\[[^\]]*]\(<[^>]*>\)", " ", text) + cleaned = re.sub(r"!\[[^\]]*]\([^)]+\)", " ", cleaned) + cleaned = re.sub(r"\[([^\]]+)]\(<[^>]*>\)", r"\1", cleaned) + cleaned = re.sub(r"\[([^\]]+)]\([^)]+\)", r"\1", cleaned) + cleaned = re.sub(r"[#*_`>|-]+", " ", cleaned) + return _clean_text(cleaned) + + +def _escape_brackets(text: str) -> str: + return text.replace("[", "\\[").replace("]", "\\]") + + +__all__ = ["HTML_EXTENSIONS", "HtmlAsset", "build_html_markdown", "parse_html_markdown"] diff --git a/src/heta/kb/image_parser.py b/src/heta/kb/image_parser.py new file mode 100644 index 0000000..d37059b --- /dev/null +++ b/src/heta/kb/image_parser.py @@ -0,0 +1,177 @@ +"""Image parsing for Little Heta KB inserts.""" + +from __future__ import annotations + +import base64 +import json +from datetime import date +from pathlib import Path +from typing import Any + +from heta.config.schema import HetaConfig +from heta.kb.agent import _chat_completion +from heta.providers.clients import build_multimodal_client + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"} + +_MIME_TYPES = { + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".webp": "image/webp", +} + + +def parse_image_markdown(source_path: Path, archived_path: Path, config: HetaConfig) -> str: + """Describe an image with a VLM and return stable wiki-flavored Markdown.""" + _require_multimodal(config, "Image parsing") + description = describe_image(source_path=source_path, config=config) + return build_image_markdown( + title=f"Image - {source_path.stem}", + source_name=archived_path.name, + image_path=f"../../raw/{archived_path.name}", + summary=description["summary"], + visual_facts=description["visual_facts"], + visible_text=description["visible_text"], + interpretation_keywords=description["interpretation_keywords"], + ) + + +def describe_image(*, source_path: Path, config: HetaConfig) -> dict[str, str]: + _require_multimodal(config, "Image parsing") + resolved = build_multimodal_client(config) + response = _chat_completion( + client=resolved.client, + model=resolved.model, + messages=[ + {"role": "system", "content": _image_system_prompt()}, + { + "role": "user", + "content": [ + {"type": "text", "text": _image_user_prompt(source_path.name)}, + { + "type": "image_url", + "image_url": { + "url": _data_url(source_path), + }, + }, + ], + }, + ], + tools=None, + temperature=0.1, + config=config, + ) + raw = response.choices[0].message.content or "" + return _normalize_description(_extract_json_object(raw)) + + +def build_image_markdown( + *, + title: str, + source_name: str, + image_path: str, + summary: str, + visual_facts: str, + visible_text: str, + interpretation_keywords: str, +) -> str: + return ( + "---\n" + f"title: {title}\n" + f"sources: [{source_name}]\n" + f"updated: {date.today().isoformat()}\n" + "---\n\n" + "## Summary\n\n" + f"{summary.strip()}\n\n" + "## Content\n\n" + f"![{source_name}](<{image_path}>)\n\n" + "### Visual Facts\n\n" + f"{visual_facts.strip()}\n\n" + "### Visible Text\n\n" + f"{visible_text.strip() or 'None detected.'}\n\n" + "### Interpretation and Keywords\n\n" + f"{interpretation_keywords.strip()}\n\n" + "## Related Pages\n\n" + "- None yet\n\n" + "## Source\n\n" + f"- {source_name}\n" + ) + + +def _image_system_prompt() -> str: + return """You are an image-to-Markdown parser for Little Heta KB inserts. +Return only one valid JSON object. Do not wrap it in Markdown fences. +Be detailed, factual, and efficient. Do not invent hidden context. +If visible text exists, transcribe it faithfully. If there is no visible text, +write "None detected.".""" + + +def _image_user_prompt(filename: str) -> str: + return f"""Describe this image for semantic retrieval. + +Filename: {filename} + +Return JSON with exactly these string fields: +- summary: one concise paragraph describing what the image is and why it matters. +- visual_facts: detailed factual description of scene type, main subject, objects, people, layout, colors, labels, numbers, and spatial relations. +- visible_text: visible text transcription, or "None detected." +- interpretation_keywords: likely meaning or purpose with uncertainty if needed, ending with compact search keywords. +""" + + +def _data_url(path: Path) -> str: + suffix = path.suffix.lower() + mime = _MIME_TYPES.get(suffix) + if mime is None: + raise ValueError(f"Unsupported image type: {suffix}") + encoded = base64.b64encode(path.read_bytes()).decode("ascii") + return f"data:{mime};base64,{encoded}" + + +def _require_multimodal(config: HetaConfig, feature: str) -> None: + if not ( + config.llm.multimodal_api_key + and config.llm.multimodal_model + and config.llm.multimodal_base_url + ): + raise ValueError( + f"{feature} requires a multimodal model. Run `heta init` and enable custom multimodal API, " + "or skip this file." + ) + + +def _extract_json_object(text: str) -> dict[str, Any]: + stripped = text.strip() + if stripped.startswith("```"): + stripped = stripped.strip("`") + if stripped.lower().startswith("json"): + stripped = stripped[4:].strip() + try: + value = json.loads(stripped) + except json.JSONDecodeError: + start = stripped.find("{") + end = stripped.rfind("}") + if start == -1 or end == -1 or end <= start: + raise ValueError("Image model did not return JSON.") + value = json.loads(stripped[start : end + 1]) + if not isinstance(value, dict): + raise ValueError("Image model JSON must be an object.") + return value + + +def _normalize_description(data: dict[str, Any]) -> dict[str, str]: + fields = { + "summary": "Imported image.", + "visual_facts": "No visual facts extracted.", + "visible_text": "None detected.", + "interpretation_keywords": "Image; visual document.", + } + normalized: dict[str, str] = {} + for key, fallback in fields.items(): + value = data.get(key) + normalized[key] = str(value).strip() if value else fallback + return normalized + + +__all__ = ["IMAGE_EXTENSIONS", "build_image_markdown", "describe_image", "parse_image_markdown"] diff --git a/src/heta/kb/insert.py b/src/heta/kb/insert.py index 400f7d5..c62c598 100644 --- a/src/heta/kb/insert.py +++ b/src/heta/kb/insert.py @@ -2,19 +2,22 @@ from __future__ import annotations +from collections.abc import Callable from datetime import datetime from pathlib import Path from uuid import uuid4 from heta.config.schema import HetaConfig -from heta.kb.discovery import collect_insert_files from heta.kb.agent import run_merge_agent -from heta.kb.models import InsertResult, ParsedDocument +from heta.kb.code_parser import CODE_EXTENSIONS +from heta.kb.discovery import collect_insert_files +from heta.kb.models import FileChange, InsertProgress, InsertResult, ParsedDocument from heta.kb.parser import parse_document from heta.kb.pdf_plan import plan_insert_files +from heta.kb.static_insert import write_static_page from heta.kb.store import commit_wiki, ensure_wiki_layout, reset_wiki from heta.kb.vector_index import sync_wiki_vector_index -from heta.kb.wiki import apply_path_map, normalize_wiki_pages, validate_wiki +from heta.kb.wiki import apply_path_map, normalize_wiki_pages, repair_broken_wiki_links, validate_wiki from heta.kb.workspace import cleanup_working_copy, create_working_copy, promote_working_copy @@ -24,7 +27,9 @@ def insert_paths( *, base_dir: Path | None = None, enable_pdf_planning: bool = True, + on_progress: Callable[[InsertProgress], None] | None = None, ) -> InsertResult: + _emit_progress(on_progress, "prepare", 1, 0, 0, "preparing files") files = collect_insert_files(targets, config) if not files: raise ValueError("No supported files found.") @@ -54,24 +59,87 @@ def insert_paths( ) parsed_documents: list[ParsedDocument] = [] for source in prepared_sources: - parsed_documents.append(parse_document(source.source_path, source.archived_path, config)) + parsed_documents.append( + parse_document( + source.source_path, + source.archived_path, + config, + original_name=source.original_name or source.source_path.name, + page_offset=(source.page_start - 1) if source.page_start else 0, + base_dir=base_dir, + ) + ) working_wiki = create_working_copy(task_id, base_dir) - agent_result = run_merge_agent( - task_id=task_id, - documents=parsed_documents, - root_dir=working_wiki, - config=config, + total_documents = len(parsed_documents) + _emit_progress(on_progress, "merge", 1, 0, total_documents, "ready to merge documents") + added = [] + updated = [] + deleted = [] + skipped_documents: list[str] = [] + for index, document in enumerate(parsed_documents, start=1): + _emit_progress( + on_progress, + "merge", + _merge_percent(index - 1, total_documents), + index - 1, + total_documents, + document.source_name, + ) + if config.dynamic_insert.enable: + agent_result = run_merge_agent( + task_id=f"{task_id}_{index}", + documents=[document], + root_dir=working_wiki, + config=config, + ) + else: + agent_result = write_static_page( + root_dir=working_wiki, + document=document, + config=config, + ) + if not (agent_result["added"] or agent_result["updated"] or agent_result["deleted"]): + skipped_documents.append(document.source_name) + _emit_progress( + on_progress, + "merge", + _merge_percent(index, total_documents), + index, + total_documents, + document.source_name, + ) + continue + normalize_result = normalize_wiki_pages(working_wiki) + repair_broken_wiki_links(working_wiki) + normalized_added = apply_path_map(agent_result["added"], normalize_result.path_map) + normalized_updated = apply_path_map(agent_result["updated"], normalize_result.path_map) + normalized_deleted = apply_path_map(agent_result["deleted"], normalize_result.path_map) + _ensure_code_raw_links(working_wiki, document, [*normalized_added, *normalized_updated]) + validate_wiki(working_wiki) + added.extend(normalized_added) + updated.extend(normalized_updated) + deleted.extend(normalized_deleted) + _emit_progress( + on_progress, + "merge", + _merge_percent(index, total_documents), + index, + total_documents, + document.source_name, + ) + + _emit_progress( + on_progress, + "finalize", + 99, + total_documents, + total_documents, + "finalizing wiki and vector index", ) - if not (agent_result["added"] or agent_result["updated"] or agent_result["deleted"]): - raise RuntimeError("Agent completed without changing the wiki.") - normalize_result = normalize_wiki_pages(working_wiki) - validate_wiki(working_wiki) promote_working_copy(task_id, base_dir) commit_id = commit_wiki(f"ingest: {', '.join(file.name for file in files)}", base_dir) - added = apply_path_map(agent_result["added"], normalize_result.path_map) - updated = apply_path_map(agent_result["updated"], normalize_result.path_map) - deleted = agent_result["deleted"] + vector_index_error = None if config.vector_index.enable: try: sync_wiki_vector_index( @@ -79,10 +147,15 @@ def insert_paths( config=config, base_dir=base_dir, ) - except Exception: - pass + except Exception as exc: + vector_index_error = str(exc) or exc.__class__.__name__ cleanup_working_copy(task_id, base_dir) + from heta.mem.kb_invalidate import invalidate_by_paths + invalidated = invalidate_by_paths(c.path for c in (*updated, *deleted)) + + _emit_progress(on_progress, "done", 100, total_documents, total_documents, "insert completed") + return InsertResult( commit_id=commit_id, added=added, @@ -90,6 +163,9 @@ def insert_paths( deleted=deleted, raw_files=raw_files, planned_pdf_parts=sum(plan.parts for plan in pdf_plans if plan.enabled), + invalidated_memories=invalidated, + skipped_documents=skipped_documents, + vector_index_error=vector_index_error, ) except BaseException: for raw in raw_files: @@ -98,3 +174,49 @@ def insert_paths( cleanup_working_copy(task_id, base_dir) reset_wiki(base_dir) raise + + +def _merge_percent(done: int, total: int) -> int: + if total <= 0: + return 99 + return min(99, 1 + int(done / total * 98)) + + +def _ensure_code_raw_links(wiki_root: Path, document: ParsedDocument, changes: list[FileChange]) -> None: + if document.metadata.get("extension") not in CODE_EXTENSIONS: + return + raw_link = f"[Raw source](<../../raw/{document.source_name}>)" + for change in changes: + if not change.path.startswith("pages/") or not change.path.endswith(".md"): + continue + page = wiki_root / change.path + if not page.exists(): + continue + text = page.read_text(encoding="utf-8") + if raw_link in text: + continue + if "## Content" not in text: + continue + updated = text.replace("## Content\n", f"## Content\n\n{raw_link}\n", 1) + page.write_text(updated, encoding="utf-8") + + +def _emit_progress( + callback: Callable[[InsertProgress], None] | None, + phase: str, + percent: int, + current: int, + total: int, + label: str, +) -> None: + if callback is None: + return + callback( + InsertProgress( + phase=phase, + percent=max(0, min(100, percent)), + current=current, + total=total, + label=label, + ) + ) diff --git a/src/heta/kb/models.py b/src/heta/kb/models.py index 733afe0..b01ed65 100644 --- a/src/heta/kb/models.py +++ b/src/heta/kb/models.py @@ -31,3 +31,15 @@ class InsertResult: deleted: list[FileChange] raw_files: list[Path] planned_pdf_parts: int = 0 + invalidated_memories: int = 0 + skipped_documents: list[str] = field(default_factory=list) + vector_index_error: str | None = None + + +@dataclass(frozen=True) +class InsertProgress: + phase: str + percent: int + current: int + total: int + label: str diff --git a/src/heta/kb/parser.py b/src/heta/kb/parser.py index c3d25a8..67a6c44 100644 --- a/src/heta/kb/parser.py +++ b/src/heta/kb/parser.py @@ -2,22 +2,56 @@ from __future__ import annotations +import hashlib +import json +import mimetypes +import re import time +import zipfile +from io import BytesIO from pathlib import Path import requests from heta.config.schema import HetaConfig +from heta.kb import paths +from heta.kb.audio_parser import AUDIO_EXTENSIONS, parse_audio_markdown +from heta.kb.code_parser import CODE_EXTENSIONS, parse_code_markdown +from heta.kb.discovery import MINERU_EXTENSIONS +from heta.kb.html_parser import HTML_EXTENSIONS, parse_html_markdown +from heta.kb.image_parser import IMAGE_EXTENSIONS, parse_image_markdown from heta.kb.models import ParsedDocument from heta.kb.text import extract_title -def parse_document(source_path: Path, archived_path: Path, config: HetaConfig) -> ParsedDocument: +def parse_document( + source_path: Path, + archived_path: Path, + config: HetaConfig, + *, + original_name: str | None = None, + page_offset: int = 0, + base_dir: Path | None = None, +) -> ParsedDocument: suffix = source_path.suffix.lower() if suffix in {".md", ".markdown", ".txt"}: markdown = source_path.read_text(encoding="utf-8") - elif suffix == ".pdf": - markdown = _parse_pdf_with_mineru(archived_path, config) + elif suffix in MINERU_EXTENSIONS: + markdown = _parse_with_mineru( + archived_path, + config, + original_name=original_name or source_path.name, + page_offset=page_offset, + base_dir=base_dir, + ) + elif suffix in IMAGE_EXTENSIONS: + markdown = parse_image_markdown(source_path, archived_path, config) + elif suffix in AUDIO_EXTENSIONS: + markdown = parse_audio_markdown(source_path, archived_path, config) + elif suffix in HTML_EXTENSIONS: + markdown = parse_html_markdown(source_path, archived_path) + elif suffix in CODE_EXTENSIONS: + markdown = parse_code_markdown(source_path, archived_path) else: raise ValueError(f"Unsupported file type: {suffix}") @@ -32,50 +66,111 @@ def parse_document(source_path: Path, archived_path: Path, config: HetaConfig) - ) -def _parse_pdf_with_mineru(path: Path, config: HetaConfig) -> str: +def _parse_with_mineru( + path: Path, + config: HetaConfig, + *, + original_name: str | None = None, + page_offset: int = 0, + base_dir: Path | None = None, +) -> str: if not config.mineru.enable: - raise ValueError(f"PDF parsing requires MinerU: {path.name}") + raise ValueError(f"Document parsing requires MinerU: {path.name}") if config.mineru.provider == "local": - return _parse_pdf_with_local_mineru(path, config.mineru.endpoint or "") + return _parse_with_local_mineru( + path, + config.mineru.endpoint or "", + original_name=original_name or path.name, + page_offset=page_offset, + base_dir=base_dir, + ) if config.mineru.provider == "cloud": - return _parse_pdf_with_cloud_mineru(path) + return _parse_with_cloud_mineru( + path, + config.mineru.api_key or "", + original_name=original_name or path.name, + page_offset=page_offset, + base_dir=base_dir, + ) raise ValueError("Invalid MinerU configuration.") -def _parse_pdf_with_local_mineru(path: Path, endpoint: str) -> str: +def _parse_with_local_mineru( + path: Path, + endpoint: str, + *, + original_name: str, + page_offset: int, + base_dir: Path | None, +) -> str: url = endpoint.rstrip("/") + "/file_parse" + content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" with path.open("rb") as file: - response = requests.post(url, files={"file": (path.name, file, "application/pdf")}, timeout=300) + response = _resilient_request( + "POST", + url, + files={"files": (path.name, file, content_type)}, + data={ + "lang_list": "ch", + "backend": "hybrid-auto-engine", + "parse_method": "auto", + "formula_enable": "true", + "table_enable": "true", + "return_md": "true", + "return_middle_json": "true", + "return_model_output": "true", + "return_content_list": "true", + "return_images": "true", + "return_original_file": "true", + "response_format_zip": "true", + }, + timeout=300, + ) if response.status_code != 200: raise RuntimeError(f"MinerU local parse failed: HTTP {response.status_code}") + if _looks_like_zip_response(response): + return _finalize_mineru_artifacts( + zip_content=response.content, + empty_error="MinerU local returned empty markdown.", + path=path, + original_name=original_name, + page_offset=page_offset, + base_dir=base_dir, + ) + content_type = response.headers.get("content-type", "") if "application/json" in content_type: - payload = response.json() - for key in ("markdown", "content", "text", "md"): - value = payload.get(key) - if isinstance(value, str) and value.strip(): - return value - data = payload.get("data") - if isinstance(data, dict): - for key in ("markdown", "content", "text", "md"): - value = data.get(key) - if isinstance(value, str) and value.strip(): - return value - raise RuntimeError("MinerU local response did not include markdown content.") + return _local_markdown_from_json(response.json()) return response.text -def _parse_pdf_with_cloud_mineru(path: Path) -> str: - create_response = requests.post( - "https://mineru.net/api/v1/agent/parse/file", +def _parse_with_cloud_mineru( + path: Path, + api_key: str, + *, + original_name: str, + page_offset: int, + base_dir: Path | None, +) -> str: + if not api_key.strip(): + raise ValueError("MinerU cloud parsing requires api_key.") + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "*/*", + } + create_response = _resilient_request( + "POST", + "https://mineru.net/api/v4/file-urls/batch", + headers=headers, json={ - "file_name": path.name, + "files": [{"name": path.name, "data_id": _safe_mineru_data_id(path.stem)}], "language": "ch", "enable_table": True, - "is_ocr": False, "enable_formula": True, + "model_version": "vlm", }, timeout=30, ) @@ -85,44 +180,334 @@ def _parse_pdf_with_cloud_mineru(path: Path) -> str: payload = create_response.json() if payload.get("code") != 0: raise RuntimeError(f"MinerU cloud task creation failed: {payload.get('msg')}") - task_id = payload.get("data", {}).get("task_id") - file_url = payload.get("data", {}).get("file_url") - if not task_id or not file_url: - raise RuntimeError("MinerU cloud did not return task_id and file_url.") + batch_id = payload.get("data", {}).get("batch_id") + file_urls = payload.get("data", {}).get("file_urls") + if not batch_id or not isinstance(file_urls, list) or not file_urls: + raise RuntimeError("MinerU cloud did not return batch_id and file_urls.") - with path.open("rb") as file: - upload_response = requests.put(file_url, data=file, timeout=120) - if upload_response.status_code not in {200, 204}: + upload_payload = path.read_bytes() # buffer once so retries reuse the same bytes + upload_response = _resilient_request("PUT", file_urls[0], data=upload_payload, timeout=120) + if upload_response.status_code not in {200, 201, 204}: raise RuntimeError(f"MinerU cloud upload failed: HTTP {upload_response.status_code}") - markdown_url = _poll_mineru_markdown_url(task_id) - markdown_response = requests.get(markdown_url, timeout=60) - if markdown_response.status_code != 200: - raise RuntimeError(f"MinerU markdown download failed: HTTP {markdown_response.status_code}") - markdown = markdown_response.text.strip() + zip_url = _poll_mineru_zip_url(batch_id, headers=headers, file_name=path.name) + zip_response = _resilient_request("GET", zip_url, timeout=120) + if zip_response.status_code != 200: + raise RuntimeError(f"MinerU zip download failed: HTTP {zip_response.status_code}") + + return _finalize_mineru_artifacts( + zip_content=zip_response.content, + empty_error="MinerU cloud returned empty markdown.", + path=path, + original_name=original_name, + page_offset=page_offset, + base_dir=base_dir, + ) + + +def _finalize_mineru_artifacts( + *, + zip_content: bytes, + empty_error: str, + path: Path, + original_name: str, + page_offset: int, + base_dir: Path | None, +) -> str: + artifacts = _extract_mineru_artifacts(zip_content) + markdown = artifacts["markdown"].strip() if not markdown: - raise RuntimeError("MinerU cloud returned empty markdown.") - return markdown + raise RuntimeError(empty_error) + + if artifacts["content_list"] or artifacts["images"]: + parsed_dir = paths.raw_dir(base_dir) / "parsed" / path.stem + _persist_mineru_artifacts(parsed_dir, markdown, artifacts["content_list"], artifacts["images"]) + + annotated = _annotate_mineru_markdown( + markdown=markdown, + content_list=artifacts["content_list"], + original_name=original_name, + page_offset=page_offset, + part_stem=path.stem, + ) + return annotated.strip() + +def _looks_like_zip_response(response) -> bool: + content_type = response.headers.get("content-type", "").lower() + return "zip" in content_type or response.content.startswith(b"PK\x03\x04") -def _poll_mineru_markdown_url(task_id: str, *, timeout_seconds: int = 180) -> str: + +def _local_markdown_from_json(payload: dict) -> str: + for key in ("markdown", "content", "text", "md"): + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value + data = payload.get("data") + if isinstance(data, dict): + for key in ("markdown", "content", "text", "md"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + return value + raise RuntimeError("MinerU local response did not include markdown content.") + + +def _poll_mineru_zip_url( + batch_id: str, + *, + headers: dict[str, str], + file_name: str, + timeout_seconds: int = 300, +) -> str: deadline = time.time() + timeout_seconds - url = f"https://mineru.net/api/v1/agent/parse/{task_id}" + url = f"https://mineru.net/api/v4/extract-results/batch/{batch_id}" while time.time() < deadline: - response = requests.get(url, timeout=30) + try: + response = requests.get(url, headers=headers, timeout=30) + except requests.exceptions.RequestException: + time.sleep(5) + continue if response.status_code != 200: raise RuntimeError(f"MinerU cloud polling failed: HTTP {response.status_code}") payload = response.json() if payload.get("code") != 0: raise RuntimeError(f"MinerU cloud polling failed: {payload.get('msg')}") - data = payload.get("data", {}) - state = data.get("state") + result = _mineru_batch_result(payload, file_name=file_name) + state = result.get("state") if state == "done": - markdown_url = data.get("markdown_url") - if not markdown_url: - raise RuntimeError("MinerU cloud result did not include markdown_url.") - return markdown_url + zip_url = result.get("full_zip_url") + if not zip_url: + raise RuntimeError("MinerU cloud result did not include full_zip_url.") + return zip_url if state == "failed": - raise RuntimeError(f"MinerU cloud parsing failed: {data.get('err_msg') or data.get('err_code')}") + raise RuntimeError(f"MinerU cloud parsing failed: {result.get('err_msg') or result.get('err_code')}") time.sleep(2) - raise TimeoutError(f"MinerU cloud parsing timed out after {timeout_seconds}s: {task_id}") + raise TimeoutError(f"MinerU cloud parsing timed out after {timeout_seconds}s: {batch_id}") + + +def _mineru_batch_result(payload: dict, *, file_name: str) -> dict: + results = payload.get("data", {}).get("extract_result") + if isinstance(results, list): + for result in results: + if isinstance(result, dict) and result.get("file_name") == file_name: + return result + if results and isinstance(results[0], dict): + return results[0] + return {} + + +_MINERU_RETRY_BACKOFFS = (3, 5, 10, 20) + + +def _resilient_request(method: str, url: str, **kwargs): + """HTTP request with retry on transient network/SSL/DNS/proxy errors. + + Retries on any `requests.RequestException`. Backoff series is + `_MINERU_RETRY_BACKOFFS` (4 intervals → 5 attempts total). + """ + funcs = {"GET": requests.get, "POST": requests.post, "PUT": requests.put} + func = funcs[method.upper()] + last_exc: Exception | None = None + backoffs = (*_MINERU_RETRY_BACKOFFS, None) + for backoff in backoffs: + try: + return func(url, **kwargs) + except requests.exceptions.RequestException as exc: + last_exc = exc + if backoff is None: + raise + time.sleep(backoff) + assert last_exc is not None # for type checker; loop always raises or returns + raise last_exc + + +def _safe_mineru_data_id(stem: str, *, max_bytes: int = 120) -> str: + """Keep MinerU's data_id under its 128-byte cap. + + CJK chars are 3 bytes in UTF-8, so a long part stem with Chinese section + titles easily blows past the cap. Truncate the byte form and append a + short content hash so two parts with the same prefix still get distinct ids. + """ + encoded = stem.encode("utf-8") + if len(encoded) <= max_bytes: + return stem + digest = hashlib.sha1(encoded).hexdigest()[:12] + head = encoded[: max_bytes - len(digest) - 1].decode("utf-8", errors="ignore") + return f"{head}_{digest}" + + +def _extract_mineru_artifacts(zip_content: bytes) -> dict: + """Pull markdown, content_list.json, and image bytes out of MinerU's zip. + + The cloud zip layout (per inspected sample): + full.md + _content_list.json ← flat per-block list with bbox/page_idx + _content_list_v2.json ← richer schema, not used + images/.jpg ← all images + ...other debug files... + """ + markdown = "" + content_list: list[dict] = [] + images: dict[str, bytes] = {} + + with zipfile.ZipFile(BytesIO(zip_content)) as archive: + names = archive.namelist() + + # Markdown: prefer full.md, fall back to any .md + md_name = next((n for n in names if n.endswith("full.md")), None) + if md_name is None: + md_name = next((n for n in names if n.endswith(".md")), None) + if md_name is None: + raise RuntimeError("MinerU zip did not include markdown output.") + markdown = archive.read(md_name).decode("utf-8") + + # Content list (v1, flat): the file ending in `_content_list.json` (not _v2) + cl_name = next( + (n for n in names if n.endswith("_content_list.json") and not n.endswith("_v2.json")), + None, + ) + if cl_name is not None: + try: + data = json.loads(archive.read(cl_name).decode("utf-8")) + if isinstance(data, list): + content_list = [item for item in data if isinstance(item, dict)] + except (json.JSONDecodeError, UnicodeDecodeError): + content_list = [] + + # Images + for name in names: + # Match "images/.ext" anywhere in the archive path. + idx = name.find("images/") + if idx == -1: + continue + rel = name[idx:] # "images/.ext" + if rel == "images/" or rel.endswith("/"): + continue + images[rel] = archive.read(name) + + return {"markdown": markdown, "content_list": content_list, "images": images} + + +def _persist_mineru_artifacts( + parsed_dir: Path, + markdown: str, + content_list: list[dict], + images: dict[str, bytes], +) -> None: + """Write MinerU outputs to disk so external agents can read them later.""" + parsed_dir.mkdir(parents=True, exist_ok=True) + (parsed_dir / "full.md").write_text(markdown, encoding="utf-8") + (parsed_dir / "content_list.json").write_text( + json.dumps(content_list, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + if images: + (parsed_dir / "images").mkdir(parents=True, exist_ok=True) + for rel_path, data in images.items(): + # rel_path is like "images/.jpg" + target = parsed_dir / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(data) + + +# Image markdown line: ![alt](images/) +_MINERU_IMG_RE = re.compile(r"^!\[[^\]]*\]\(images/([^)]+)\)\s*$") +# Table block markers in MinerU markdown. +_MINERU_TABLE_OPEN_RE = re.compile(r"]", re.IGNORECASE) +_MINERU_TABLE_CLOSE_RE = re.compile(r"", re.IGNORECASE) + + +def _annotate_mineru_markdown( + *, + markdown: str, + content_list: list[dict], + original_name: str, + page_offset: int, + part_stem: str, +) -> str: + """Inject `Source:` lines after each image and table, and rewrite image paths. + + Image paths in MinerU markdown are `images/.jpg`, relative to the parser's + own `full.md`. We rewrite them so they resolve from a wiki page at + `/wiki/pages/-.md` to the persisted artifact at + `/raw/parsed//images/.jpg`. + """ + image_path_prefix = f"../../raw/parsed/{part_stem}/images/" + + # Queue figure/table provenance in document order. content_list emits items + # in document order; we keep image and table queues separate so positional + # matching is robust to missing items on either side. + image_provenance = [ + (item.get("bbox"), item.get("page_idx")) + for item in content_list + if isinstance(item, dict) and item.get("type") == "image" + ] + table_provenance = [ + (item.get("bbox"), item.get("page_idx")) + for item in content_list + if isinstance(item, dict) and item.get("type") == "table" + ] + img_cursor = 0 + tbl_cursor = 0 + + out_lines: list[str] = [] + in_table = False + lines = markdown.splitlines() + + for line in lines: + # Inside a table block: just pass lines through; check for close marker. + if in_table: + out_lines.append(line) + if _MINERU_TABLE_CLOSE_RE.search(line): + in_table = False + src = _source_line_at(table_provenance, tbl_cursor, original_name, page_offset) + if src is not None: + out_lines.append("") + out_lines.append(src) + tbl_cursor += 1 + continue + + # Image line + img_match = _MINERU_IMG_RE.match(line.strip()) + if img_match is not None: + out_lines.append(f"![]({image_path_prefix}{img_match.group(1)})") + src = _source_line_at(image_provenance, img_cursor, original_name, page_offset) + if src is not None: + out_lines.append("") + out_lines.append(src) + img_cursor += 1 + continue + + # Table start + if _MINERU_TABLE_OPEN_RE.search(line): + out_lines.append(line) + if _MINERU_TABLE_CLOSE_RE.search(line): + # Single-line table block — close immediately. + src = _source_line_at(table_provenance, tbl_cursor, original_name, page_offset) + if src is not None: + out_lines.append("") + out_lines.append(src) + tbl_cursor += 1 + else: + in_table = True + continue + + out_lines.append(line) + + return "\n".join(out_lines) + + +def _source_line_at( + provenance: list[tuple], + cursor: int, + original_name: str, + page_offset: int, +) -> str | None: + if cursor >= len(provenance): + return None + bbox, page_idx = provenance[cursor] + if not isinstance(bbox, list) or len(bbox) != 4 or not isinstance(page_idx, int): + return None + page = page_offset + page_idx + 1 + bbox_str = f"[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]" + return f"Source: {original_name}, page {page}, bbox {bbox_str}" diff --git a/src/heta/kb/pdf_plan.py b/src/heta/kb/pdf_plan.py index 979804a..b3484e3 100644 --- a/src/heta/kb/pdf_plan.py +++ b/src/heta/kb/pdf_plan.py @@ -18,8 +18,8 @@ from heta.kb.text import slugify PDF_PAGE_THRESHOLD = 80 -PDF_PART_MAX_PAGES = 40 -PDF_PROFILE_MAX_CHARS = 12000 +PDF_PART_MAX_PAGES = 20 +PDF_PROFILE_MAX_CHARS = 60000 @dataclass(frozen=True) @@ -30,6 +30,7 @@ class PreparedSource: page_start: int | None = None page_end: int | None = None metadata_path: Path | None = None + original_name: str | None = None @dataclass(frozen=True) @@ -71,13 +72,25 @@ def plan_insert_files( for file in files: if file.suffix.lower() != ".pdf": - prepared.append(PreparedSource(source_path=file, archived_path=_save_raw_file(file, base_dir))) + prepared.append( + PreparedSource( + source_path=file, + archived_path=_save_raw_file(file, base_dir), + original_name=file.name, + ) + ) continue page_count = estimate_pdf_pages(file) should_split = enable_pdf_planning and page_count > PDF_PAGE_THRESHOLD if not should_split: - prepared.append(PreparedSource(source_path=file, archived_path=_save_raw_file(file, base_dir))) + prepared.append( + PreparedSource( + source_path=file, + archived_path=_save_raw_file(file, base_dir), + original_name=file.name, + ) + ) plans.append(PdfPlan(source_path=file, page_count=page_count, enabled=False, parts=1)) continue @@ -133,7 +146,7 @@ def build_pdf_profile(path: Path, *, page_count: int | None = None) -> PdfProfil filename=path.name, page_count=total_pages, metadata=_metadata(reader), - outline=outline[:80], + outline=outline, page_samples=samples, heading_candidates=heading_candidates[:80], ) @@ -219,6 +232,7 @@ def split_pdf_to_raw_parts( page_start=unit.start_page, page_end=unit.end_page, metadata_path=metadata, + original_name=source.name, ) ) @@ -228,26 +242,52 @@ def split_pdf_to_raw_parts( def _planning_system_prompt() -> str: return """You are Little Heta's PDF split planning agent. -You do not read the full PDF. You only receive a lightweight profile containing -metadata, outline/bookmarks, sampled page text, heading-like lines, and page -count. Decide how to split the PDF into smaller source units. +You do not read the full PDF. You receive a lightweight profile containing +metadata, outline/bookmarks (titled leaf entries with real page numbers), +sampled page text, heading-like lines, and page count. Decide how to split +the PDF into smaller source units. + +Each unit MUST be at most 20 pages. + +Critical: do NOT propose oversized units expecting the system to "split them +later". The system's fallback splitter chops oversized units into mechanical +20-page windows that all share your title, which destroys the outline's +semantic boundaries you saw in the profile. Pick the right granularity +yourself. Return JSON only with this shape: { "document_type": "textbook | paper_collection | report | slides | manual | scanned_book | mixed", "split_strategy": "outline | fixed_page_window | chapter | section | fallback", "units": [ - {"title": "Chapter 1: Introduction", "start_page": 1, "end_page": 32} + {"title": "Section 1.2: Introduction", "start_page": 12, "end_page": 19} ] } -Rules: +Granularity rules — use the FINEST outline level whose units fit ≤20 pages: +- Compute the average page span between consecutive outline entries: + `avg = page_count / len(outline)`. +- If `avg ≤ 20`: use outline entries as unit boundaries directly. Each unit + spans from one entry's page to the page before the next entry's page (or + to the document end for the last entry). Inherit the entry's title. +- If individual entries are very small (`avg < 5`): you may merge 2–4 + consecutive entries into one unit to reach a more useful size, but the + merged title MUST reflect the range (e.g. include the first and last + entry's identifiers, or the chapter they share). Never let the merged + unit exceed 20 pages. +- If `avg > 20`: the outline is too coarse. Subdivide each outline entry + into 20-page fixed windows that inherit the entry's title plus a part + suffix (e.g. "Chapter 3 (pages 60-80)"). +- For paper collections: each paper is one unit (title = paper title). If a + paper exceeds 20 pages, subdivide that paper alone into 20-page windows. +- For slides, scanned books, or empty/unreliable outline: use fixed 20-page + windows. Set split_strategy to "fixed_page_window". + +Hard rules: - Page numbers are 1-based and inclusive. -- Prefer outline/chapter/section boundaries when reliable. -- For paper collections, try title/reference-like boundaries only if samples are clear. -- For reports, prefer top-level sections. -- For slides, scanned books, or weak evidence, use fixed page windows. -- Keep each unit small enough for downstream parsing. The system will further split oversized units. +- Every unit MUST be ≤20 pages. A unit larger than 20 pages will be + rejected and re-split mechanically. +- Cover [1, page_count] without gaps and without overlap. - Do not invent details that are absent from the profile. """ @@ -352,10 +392,17 @@ def _fixed_range_units(start_page: int, end_page: int, *, max_pages: int) -> lis def _extract_outline(reader: PdfReader) -> list[dict[str, Any]]: + """Collect leaf outline entries (titled bookmarks with a real page target). + + Folder/group bookmarks (e.g. collapsed parents like "正文前资料") have a null + page destination and are useless for split planning; we skip them so they + don't crowd out the real leaves under the OUTLINE_MAX cap. + """ outline: list[dict[str, Any]] = [] + OUTLINE_MAX = 500 def visit(items: Any, depth: int = 0) -> None: - if len(outline) >= 120: + if len(outline) >= OUTLINE_MAX: return if isinstance(items, list): for item in items: @@ -364,10 +411,13 @@ def visit(items: Any, depth: int = 0) -> None: title = getattr(items, "title", None) if title: try: - page = reader.get_destination_page_number(items) + 1 + page_number = reader.get_destination_page_number(items) except Exception: - page = None - outline.append({"title": str(title), "page": page, "depth": depth}) + page_number = None + if page_number is None: + # Folder/group bookmark — keep walking siblings but do not record. + return + outline.append({"title": str(title), "page": page_number + 1, "depth": depth}) return try: for child in items: diff --git a/src/heta/kb/static_insert.py b/src/heta/kb/static_insert.py new file mode 100644 index 0000000..3f4f5ff --- /dev/null +++ b/src/heta/kb/static_insert.py @@ -0,0 +1,195 @@ +"""Static wiki page generation for `heta insert`.""" + +from __future__ import annotations + +import re +import time +from datetime import date, datetime +from pathlib import Path +from typing import Any + +from openai import APIError + +from heta.config.schema import HetaConfig +from heta.kb.models import FileChange, ParsedDocument +from heta.kb.text import slugify +from heta.providers.clients import build_chat_client, extra_body + +SUMMARY_MAX_CHARS = 12000 +SUMMARY_MAX_TOKENS = 512 +SUMMARY_RETRIES = 3 + +SUMMARY_PROMPT = """Write a concise Little Heta wiki Summary for one parsed source document. +Return only the summary paragraph, normally 1-3 sentences. +Do not use Markdown headings or bullets. +Mention the main object/topic, document purpose, and important identifiers if visible. +""" + + +def write_static_page( + *, + root_dir: Path, + document: ParsedDocument, + config: HetaConfig, +) -> dict[str, Any]: + """Write exactly one static wiki page for a parsed document.""" + pages = root_dir / "pages" + pages.mkdir(parents=True, exist_ok=True) + + summary = generate_summary(document=document, config=config) + page_name = _available_page_name(pages, document.title) + page_rel = f"pages/{page_name}" + page = _build_page(document=document, summary=summary) + (pages / page_name).write_text(page, encoding="utf-8") + _append_index_entry(root_dir / "index.md", document.title, page_rel, summary) + _append_log(root_dir / "log.md", f"Created static page: {document.title} from {document.source_name}") + + change = FileChange("added", document.title, page_rel) + return {"added": [change], "updated": [], "deleted": []} + + +def generate_summary(*, document: ParsedDocument, config: HetaConfig) -> str: + resolved = build_chat_client(config, timeout=300, max_retries=2) + request_extra_body = extra_body(config) + prompt = _summary_user_prompt(document) + last_exc: Exception | None = None + for attempt in range(1, SUMMARY_RETRIES + 1): + try: + kwargs: dict[str, Any] = { + "model": resolved.model, + "messages": [ + {"role": "system", "content": SUMMARY_PROMPT}, + {"role": "user", "content": prompt}, + ], + "temperature": 0.1, + "max_tokens": SUMMARY_MAX_TOKENS, + } + if request_extra_body is not None: + kwargs["extra_body"] = request_extra_body + response = resolved.client.chat.completions.create(**kwargs) + summary = _normalize_summary(response.choices[0].message.content or "") + if summary: + return summary + last_exc = RuntimeError("LLM returned an empty summary.") + except APIError as exc: + last_exc = exc + if attempt < SUMMARY_RETRIES: + time.sleep(min(2**attempt, 20)) + assert last_exc is not None + raise last_exc + + +def normalize_content_for_static_page(document: ParsedDocument) -> str: + text = _normalize_model_markdown(document.markdown_content) + lines: list[str] = [] + has_level3 = False + in_code = False + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("```"): + in_code = not in_code + lines.append(line) + continue + if not in_code and (match := re.match(r"^(#{1,6})\s+(.*)$", line)): + level = len(match.group(1)) + title = match.group(2).strip() + new_level = max(3, level + 2) + if new_level >= 3: + has_level3 = True + line = f"{'#' * min(new_level, 6)} {title}" + lines.append(line) + + content = "\n".join(lines).strip() + if not content: + content = "No content." + if not has_level3: + content = f"### {document.title}\n\n{content}" + return content + + +def _summary_user_prompt(document: ParsedDocument) -> str: + sample = document.markdown_content[:SUMMARY_MAX_CHARS] + return f"""Title: {document.title} +Source file: {document.source_name} + +Parsed markdown excerpt: +```markdown +{sample} +``` +""" + + +def _build_page(*, document: ParsedDocument, summary: str) -> str: + content = normalize_content_for_static_page(document) + today = date.today().isoformat() + return ( + "---\n" + f"title: {document.title}\n" + f"sources: [{document.source_name}]\n" + f"updated: {today}\n" + "---\n\n" + "## Summary\n\n" + f"{summary.strip() or document.title}\n\n" + "## Content\n\n" + f"{content.strip()}\n\n" + "## Related Pages\n\n" + "- None yet\n\n" + "## Source\n\n" + f"- {document.source_name}\n" + ) + + +def _available_page_name(pages: Path, title: str) -> str: + next_id = _next_wiki_id(pages) + slug = slugify(title) + for wiki_id in range(next_id, next_id + 10000): + candidate = f"{wiki_id}-{slug}.md" + if not (pages / candidate).exists(): + return candidate + raise RuntimeError(f"Too many wiki pages while creating: {title}") + + +def _next_wiki_id(pages: Path) -> int: + ids = [] + for page in pages.glob("*.md"): + match = re.match(r"^(\d+)-.+\.md$", page.name) + if match: + ids.append(int(match.group(1))) + return max(ids, default=0) + 1 + + +def _append_index_entry(index_path: Path, title: str, page_rel: str, summary: str) -> None: + index = index_path.read_text(encoding="utf-8") if index_path.exists() else "# Wiki Index\n" + wiki_id = _wiki_id_from_page_rel(page_rel) + prefix = f"- [{wiki_id}] " if wiki_id is not None else "- " + entry = f"{prefix}[[{title}]] ({page_rel}) — {summary.strip() or title}" + index_path.write_text(index.rstrip() + "\n" + entry + "\n", encoding="utf-8") + + +def _append_log(log_path: Path, message: str) -> None: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + existing = log_path.read_text(encoding="utf-8") if log_path.exists() else "# Wiki Log\n\n" + log_path.write_text(existing.rstrip() + f"\n- [{timestamp}] {message}\n", encoding="utf-8") + + +def _wiki_id_from_page_rel(page_rel: str) -> int | None: + match = re.match(r"^pages/(\d+)-", page_rel) + return int(match.group(1)) if match else None + + +def _normalize_summary(markdown: str) -> str: + text = _normalize_model_markdown(markdown) + text = re.sub(r"^#+\s*", "", text).strip() + text = " ".join(line.strip().lstrip("-*").strip() for line in text.splitlines() if line.strip()) + return re.sub(r"\s+", " ", text).strip() + + +def _normalize_model_markdown(markdown: str) -> str: + text = markdown.strip() + if text.startswith("```"): + text = re.sub(r"^```(?:markdown)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + return text.strip() + + +__all__ = ["normalize_content_for_static_page", "write_static_page"] diff --git a/src/heta/kb/text.py b/src/heta/kb/text.py index a4a183d..1b74613 100644 --- a/src/heta/kb/text.py +++ b/src/heta/kb/text.py @@ -14,8 +14,16 @@ def slugify(value: str) -> str: def extract_title(markdown: str, fallback: str) -> str: + in_frontmatter = False for line in markdown.splitlines(): stripped = line.strip() + if stripped == "---": + in_frontmatter = not in_frontmatter + continue + if in_frontmatter and stripped.startswith("title:"): + title = stripped.split(":", 1)[1].strip() + if title: + return title if stripped.startswith("#"): title = stripped.lstrip("#").strip() if title: diff --git a/src/heta/kb/vector_index.py b/src/heta/kb/vector_index.py index 1c1bc03..bb10408 100644 --- a/src/heta/kb/vector_index.py +++ b/src/heta/kb/vector_index.py @@ -11,23 +11,14 @@ from typing import Iterable import sqlite_vec -from openai import OpenAI from heta.config.schema import HetaConfig from heta.kb import paths from heta.kb.models import FileChange +from heta.providers.clients import EMBEDDING_DIM, build_embedding_client -EMBEDDING_DIM = 1024 -EMBEDDING_MODELS = { - "qwen": "text-embedding-v4", - "chatgpt": "text-embedding-3-small", - "gemini": "text-embedding-004", -} EMBEDDING_BATCH_SIZE = 10 -EMBEDDING_BASE_URLS = { - "qwen": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "gemini": "https://generativelanguage.googleapis.com/v1beta/openai/", -} +MAX_CHUNK_CHARS = 4096 PAGE_NAME_RE = re.compile(r"^(?P\d+)-.+\.md$") HEADING_RE = re.compile(r"^(#{1,6})\s+(.+?)\s*$") @@ -51,6 +42,7 @@ class WikiChunkSearchResult: content: str distance: float score: float + retrieval: str = "vector" def sync_wiki_vector_index( @@ -66,25 +58,25 @@ def sync_wiki_vector_index( """ db_path = paths.vector_db_path(base_dir) db_path.parent.mkdir(parents=True, exist_ok=True) - conn = sqlite3.connect(db_path) - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) + conn = _connect_index_db(db_path) try: _ensure_schema(conn) changed = list(changes) - pages_to_embed: list[Path] = [] + deleted_wiki_ids: set[int] = set() + pages_to_embed: dict[int, Path] = {} for change in changed: wiki_id = _wiki_id_from_path(change.path) if wiki_id is not None: - _delete_page_chunks(conn, wiki_id) + if wiki_id not in deleted_wiki_ids: + _delete_page_chunks(conn, wiki_id) + deleted_wiki_ids.add(wiki_id) if change.kind == "deleted": continue page = paths.wiki_dir(base_dir) / change.path - if page.exists(): - pages_to_embed.append(page) + if page.exists() and wiki_id is not None: + pages_to_embed[wiki_id] = page - chunks = [chunk for page in pages_to_embed for chunk in chunk_wiki_page(page)] + chunks = [chunk for page in pages_to_embed.values() for chunk in chunk_wiki_page(page)] if chunks: embeddings = _embed_texts([chunk.content for chunk in chunks], config) for chunk, embedding in zip(chunks, embeddings, strict=True): @@ -109,11 +101,10 @@ def search_wiki_vector_index( if not db_path.exists(): return [] - conn = sqlite3.connect(db_path) - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) + conn = _connect_index_db(db_path) try: + _ensure_schema(conn) + conn.commit() embedding = _embed_texts([query], config)[0] rows = conn.execute( """ @@ -143,11 +134,91 @@ def search_wiki_vector_index( content=str(row[4]), distance=float(row[5]), score=1.0 / (1.0 + float(row[5])), + retrieval="vector", + ) + for row in rows + ] + + +def search_wiki_fts_index( + *, + query: str, + top_k: int = 5, + base_dir: Path | None = None, +) -> list[WikiChunkSearchResult]: + """Return lexical wiki chunk matches from SQLite FTS5 trigram search.""" + db_path = paths.vector_db_path(base_dir) + if not db_path.exists(): + return [] + + match_query = _fts_match_query(query) + if not match_query: + return [] + + conn = _connect_index_db(db_path) + try: + _ensure_schema(conn) + conn.commit() + rows = conn.execute( + """ + SELECT + c.wiki_id, + c.page_name, + c.chunk_id, + c.heading_path, + c.content, + bm25(wiki_chunk_fts) AS rank + FROM wiki_chunk_fts f + JOIN wiki_chunks c ON c.id = f.rowid + WHERE wiki_chunk_fts MATCH ? + ORDER BY rank + LIMIT ? + """, + (match_query, max(1, top_k)), + ).fetchall() + finally: + conn.close() + + return [ + WikiChunkSearchResult( + wiki_id=int(row[0]), + page_name=str(row[1]), + chunk_id=str(row[2]), + heading_path=str(row[3]), + content=str(row[4]), + distance=float(row[5]), + score=1.0 / (1.0 + max(float(row[5]), 0.0)), + retrieval="fts", ) for row in rows ] +def search_wiki_hybrid_index( + *, + query: str, + config: HetaConfig, + top_k: int = 5, + candidate_k: int | None = None, + base_dir: Path | None = None, +) -> list[WikiChunkSearchResult]: + """Return wiki chunk matches fused from vector and lexical retrieval.""" + limit = max(1, top_k) + candidates = max(limit, candidate_k or limit * 3) + vector_results = search_wiki_vector_index( + query=query, + config=config, + top_k=candidates, + base_dir=base_dir, + ) + fts_results = search_wiki_fts_index( + query=query, + top_k=candidates, + base_dir=base_dir, + ) + return _rrf_fuse(vector_results=vector_results, fts_results=fts_results, top_k=limit) + + def chunk_wiki_page(page: Path) -> list[WikiChunk]: wiki_id = _wiki_id_from_path(f"pages/{page.name}") if wiki_id is None: @@ -161,21 +232,29 @@ def chunk_wiki_page(page: Path) -> list[WikiChunk]: return [] chunks: list[WikiChunk] = [] + seen_hashes: set[str] = set() sections = _content_sections(content) for index, (heading_path, body) in enumerate(sections): - chunk_text = _chunk_text(title=title, summary=summary, heading_path=heading_path, body=body) - content_hash = _hash_text(chunk_text) - chunk_id = f"{wiki_id}:{content_hash[:16]}" - chunks.append( - WikiChunk( - wiki_id=wiki_id, - page_name=page.name, - chunk_id=chunk_id, - heading_path=heading_path or "Content", - content=chunk_text, - content_hash=content_hash, + prefix_overhead = len(_chunk_text(title=title, summary=summary, heading_path=heading_path, body="")) + body_budget = max(MAX_CHUNK_CHARS - prefix_overhead - 16, 256) + body_pieces = _split_text(body, body_budget) or [body] + for piece in body_pieces: + chunk_text = _chunk_text(title=title, summary=summary, heading_path=heading_path, body=piece) + content_hash = _hash_text(chunk_text) + if content_hash in seen_hashes: + continue + seen_hashes.add(content_hash) + chunk_id = f"{wiki_id}:{content_hash[:16]}" + chunks.append( + WikiChunk( + wiki_id=wiki_id, + page_name=page.name, + chunk_id=chunk_id, + heading_path=heading_path or "Content", + content=chunk_text, + content_hash=content_hash, + ) ) - ) return chunks @@ -201,12 +280,32 @@ def _ensure_schema(conn: sqlite3.Connection) -> None: ) """ ) + conn.execute( + """ + CREATE VIRTUAL TABLE IF NOT EXISTS wiki_chunk_fts USING fts5( + page_name, + heading_path, + content, + tokenize='trigram' + ) + """ + ) + _backfill_fts_if_needed(conn) + + +def _connect_index_db(db_path: Path) -> sqlite3.Connection: + conn = sqlite3.connect(db_path, timeout=30) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + return conn def _delete_page_chunks(conn: sqlite3.Connection, wiki_id: int) -> None: rowids = [row[0] for row in conn.execute("SELECT id FROM wiki_chunks WHERE wiki_id = ?", (wiki_id,))] for rowid in rowids: conn.execute("DELETE FROM wiki_chunk_vec WHERE rowid = ?", (rowid,)) + conn.execute("DELETE FROM wiki_chunk_fts WHERE rowid = ?", (rowid,)) conn.execute("DELETE FROM wiki_chunks WHERE wiki_id = ?", (wiki_id,)) @@ -233,15 +332,113 @@ def _insert_chunk(conn: sqlite3.Connection, chunk: WikiChunk, embedding: list[fl "INSERT INTO wiki_chunk_vec(rowid, embedding) VALUES (?, ?)", (cursor.lastrowid, sqlite_vec.serialize_float32(embedding)), ) + conn.execute( + """ + INSERT INTO wiki_chunk_fts(rowid, page_name, heading_path, content) + VALUES (?, ?, ?, ?) + """, + (cursor.lastrowid, chunk.page_name, chunk.heading_path, chunk.content), + ) + + +def _backfill_fts_if_needed(conn: sqlite3.Connection) -> None: + chunks_count = conn.execute("SELECT count(*) FROM wiki_chunks").fetchone()[0] + fts_count = conn.execute("SELECT count(*) FROM wiki_chunk_fts").fetchone()[0] + if chunks_count == fts_count: + return + conn.execute("DELETE FROM wiki_chunk_fts") + conn.execute( + """ + INSERT INTO wiki_chunk_fts(rowid, page_name, heading_path, content) + SELECT id, page_name, heading_path, content + FROM wiki_chunks + """ + ) + + +def _fts_match_query(query: str) -> str: + terms = _fts_terms(query) + return " OR ".join(f'"{term}"' for term in terms) + + +def _fts_terms(query: str) -> list[str]: + normalized = _normalize_fts_query(query) + raw_terms = re.findall(r"[\w\u4e00-\u9fff][\w\u4e00-\u9fff./\-]*", normalized) + terms: list[str] = [] + seen: set[str] = set() + for term in raw_terms: + term = term.strip("./-") + if len(term) < 3 or term in seen: + continue + seen.add(term) + terms.append(term) + return terms + + +def _normalize_fts_query(query: str) -> str: + table = str.maketrans( + { + "-": "-", + "–": "-", + "—": "-", + "−": "-", + ":": ":", + "/": "/", + "(": " ", + ")": " ", + ",": " ", + "。": " ", + ";": " ", + "、": " ", + } + ) + return re.sub(r"\s+", " ", query.translate(table).upper()).strip() + + +def _rrf_fuse( + *, + vector_results: list[WikiChunkSearchResult], + fts_results: list[WikiChunkSearchResult], + top_k: int, + rrf_k: int = 60, +) -> list[WikiChunkSearchResult]: + by_chunk: dict[str, WikiChunkSearchResult] = {} + scores: dict[str, float] = {} + retrievals: dict[str, set[str]] = {} + + for results, retrieval in ((vector_results, "vector"), (fts_results, "fts")): + for rank, result in enumerate(results, start=1): + by_chunk.setdefault(result.chunk_id, result) + scores[result.chunk_id] = scores.get(result.chunk_id, 0.0) + 1.0 / (rrf_k + rank) + retrievals.setdefault(result.chunk_id, set()).add(retrieval) + + ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k] + max_score = ranked[0][1] if ranked else 1.0 + fused: list[WikiChunkSearchResult] = [] + for chunk_id, score in ranked: + result = by_chunk[chunk_id] + fused.append( + WikiChunkSearchResult( + wiki_id=result.wiki_id, + page_name=result.page_name, + chunk_id=result.chunk_id, + heading_path=result.heading_path, + content=result.content, + distance=result.distance, + score=score / max_score, + retrieval="+".join(sorted(retrievals[chunk_id])), + ) + ) + return fused def _embed_texts(texts: list[str], config: HetaConfig) -> list[list[float]]: - client = _embedding_client(config) + resolved = build_embedding_client(config) embeddings: list[list[float]] = [] for start in range(0, len(texts), EMBEDDING_BATCH_SIZE): batch = texts[start : start + EMBEDDING_BATCH_SIZE] - response = client.embeddings.create( - model=EMBEDDING_MODELS[config.llm.provider], + response = resolved.client.embeddings.create( + model=resolved.model, input=batch, dimensions=EMBEDDING_DIM, ) @@ -249,14 +446,6 @@ def _embed_texts(texts: list[str], config: HetaConfig) -> list[list[float]]: return embeddings -def _embedding_client(config: HetaConfig) -> OpenAI: - kwargs = {"api_key": config.llm.api_key, "timeout": 120} - base_url = EMBEDDING_BASE_URLS.get(config.llm.provider) - if base_url: - kwargs["base_url"] = base_url - return OpenAI(**kwargs) - - def _wiki_id_from_path(path: str) -> int | None: name = Path(path).name match = PAGE_NAME_RE.match(name) @@ -313,6 +502,41 @@ def flush() -> None: return sections or [("Content", content.strip())] +def _split_text(text: str, max_chars: int) -> list[str]: + """Split text into pieces of at most max_chars, preferring paragraph then line then sentence boundaries.""" + text = text.strip() + if not text: + return [] + if len(text) <= max_chars: + return [text] + + for sep in ("\n\n", "\n", "。", "!", "?", ". ", " "): + if sep not in text: + continue + parts = text.split(sep) + pieces: list[str] = [] + buf = "" + for part in parts: + candidate = (buf + sep + part) if buf else part + if len(candidate) <= max_chars: + buf = candidate + else: + if buf: + pieces.append(buf) + buf = part + if buf: + pieces.append(buf) + result: list[str] = [] + for piece in pieces: + if len(piece) <= max_chars: + result.append(piece) + else: + result.extend(_split_text(piece, max_chars)) + return result + + return [text[i:i + max_chars] for i in range(0, len(text), max_chars)] + + def _chunk_text(*, title: str, summary: str, heading_path: str, body: str) -> str: parts = [f"Page: {title}"] if summary: @@ -332,6 +556,8 @@ def _hash_text(text: str) -> str: "WikiChunk", "WikiChunkSearchResult", "chunk_wiki_page", + "search_wiki_fts_index", + "search_wiki_hybrid_index", "search_wiki_vector_index", "sync_wiki_vector_index", ] diff --git a/src/heta/kb/wiki.py b/src/heta/kb/wiki.py index 9e1e80a..469b6a6 100644 --- a/src/heta/kb/wiki.py +++ b/src/heta/kb/wiki.py @@ -103,6 +103,33 @@ def validate_wiki(wiki_root: Path) -> None: raise ValueError(f"broken wiki link in {page.name}: {link}") +def repair_broken_wiki_links(wiki_root: Path) -> None: + """Downgrade broken wiki links to plain text before validation.""" + pages = wiki_root / "pages" + if not pages.exists(): + return + + page_titles = set() + for page in pages.glob("*.md"): + text = page.read_text(encoding="utf-8") + title = _frontmatter_value(text, "title") or page.stem + page_titles.add(title.strip().lower()) + page_titles.add(page.stem.strip().lower()) + + for page in pages.glob("*.md"): + text = page.read_text(encoding="utf-8") + + def replace(match: re.Match[str]) -> str: + link = match.group(1).strip() + if link.lower() in page_titles: + return match.group(0) + return link + + repaired = re.sub(r"\[\[([^\]]+)\]\]", replace, text) + if repaired != text: + page.write_text(repaired, encoding="utf-8") + + def normalize_wiki_pages(wiki_root: Path) -> NormalizeResult: """Assign numeric page filename prefixes and rewrite index paths. diff --git a/src/heta/mem/__init__.py b/src/heta/mem/__init__.py new file mode 100644 index 0000000..4a9bf41 --- /dev/null +++ b/src/heta/mem/__init__.py @@ -0,0 +1 @@ +"""Little Heta memory module.""" diff --git a/src/heta/mem/clean.py b/src/heta/mem/clean.py new file mode 100644 index 0000000..9bf8d17 --- /dev/null +++ b/src/heta/mem/clean.py @@ -0,0 +1,66 @@ +"""Wipe all memory data while preserving the schema.""" + +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass + + +@dataclass +class CleanMemoryResult: + deleted_sessions: int + deleted_l0_turns: int + deleted_l1_episodes: int + deleted_l2_facts: int + deleted_kb_insights: int + deleted_meta: int + + +def clean_memory(conn: sqlite3.Connection) -> CleanMemoryResult: + """Delete every row from all memory tables. Schema is preserved.""" + sessions = _count(conn, "session") + turns = _count(conn, "l0_turn") + episodes = _count(conn, "l1_episodic") + facts = _count(conn, "l2_semantic") + insights = _count(conn, "kb_insight") + meta = _count(conn, "memory_meta") + + # vec0 and FTS5 virtual tables must be cleared before main tables + conn.execute("DELETE FROM l2_fact_vec") + conn.execute("DELETE FROM l1_episode_vec") + conn.execute("DELETE FROM l0_turn_fts") + conn.execute("DELETE FROM kb_insight_vec") + # legacy vec tables (may exist on older DBs) + for t in ("kb_chunk_vec", "kb_qa_vec"): + try: + conn.execute(f"DELETE FROM {t}") + except Exception: + pass + + # delete leaf tables first to avoid FK ordering issues + conn.execute("DELETE FROM kb_insight") + # legacy tables + for t in ("kb_qa_chunk", "kb_qa", "kb_chunk", "kb_source"): + try: + conn.execute(f"DELETE FROM {t}") + except Exception: + pass + conn.execute("DELETE FROM l2_semantic") + conn.execute("DELETE FROM l1_episodic") + conn.execute("DELETE FROM memory_meta") + conn.execute("DELETE FROM l0_turn") + conn.execute("DELETE FROM session") + conn.commit() + + return CleanMemoryResult( + deleted_sessions=sessions, + deleted_l0_turns=turns, + deleted_l1_episodes=episodes, + deleted_l2_facts=facts, + deleted_kb_insights=insights, + deleted_meta=meta, + ) + + +def _count(conn: sqlite3.Connection, table: str) -> int: + return conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] diff --git a/src/heta/mem/client.py b/src/heta/mem/client.py new file mode 100644 index 0000000..9806bbc --- /dev/null +++ b/src/heta/mem/client.py @@ -0,0 +1,25 @@ +"""LLM and embedding client factories for the memory module.""" + +from __future__ import annotations + +from openai import OpenAI + +from heta.config.schema import HetaConfig +from heta.providers.clients import ( + EMBEDDING_DIM, + build_chat_client, + build_embedding_client as build_provider_embedding_client, + extra_body, +) + + +def build_client(config: HetaConfig) -> tuple[OpenAI, str]: + """Return (client, model) for text generation.""" + resolved = build_chat_client(config, timeout=60) + return resolved.client, resolved.model + + +def build_embedding_client(config: HetaConfig) -> tuple[OpenAI, str]: + """Return (client, model) for embedding generation.""" + resolved = build_provider_embedding_client(config, timeout=120) + return resolved.client, resolved.model diff --git a/src/heta/mem/db.py b/src/heta/mem/db.py new file mode 100644 index 0000000..55d39c0 --- /dev/null +++ b/src/heta/mem/db.py @@ -0,0 +1,210 @@ +"""SQLite connection factory and schema initialisation.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import sqlite_vec + +from heta.mem.client import EMBEDDING_DIM + + +def get_connection(path: Path, *, with_vec: bool = False) -> sqlite3.Connection: + conn = sqlite3.connect(str(path)) + if with_vec: + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + conn.row_factory = sqlite3.Row + return conn + + +def init_db(conn: sqlite3.Connection) -> None: + conn.executescript(""" + CREATE TABLE IF NOT EXISTS session ( + session_id TEXT PRIMARY KEY, + started_at INTEGER NOT NULL, + ended_at INTEGER, + consolidated INTEGER NOT NULL DEFAULT 0, + consolidated_at INTEGER + ); + + CREATE TABLE IF NOT EXISTS l0_turn ( + session_id TEXT NOT NULL REFERENCES session(session_id), + turn_index INTEGER NOT NULL, + role TEXT NOT NULL, + modality TEXT NOT NULL DEFAULT 'text', + text_content TEXT NOT NULL, + created_at INTEGER NOT NULL, + UNIQUE(session_id, turn_index) + ); + + CREATE TABLE IF NOT EXISTS memory_meta ( + memory_id TEXT PRIMARY KEY, + memory_type TEXT NOT NULL, + session_id TEXT REFERENCES session(session_id), + origin TEXT NOT NULL, + kb_uid TEXT, + status TEXT NOT NULL DEFAULT 'active', + deprecated_by TEXT REFERENCES memory_meta(memory_id), + recency_score REAL NOT NULL DEFAULT 1.0, + access_freq INTEGER NOT NULL DEFAULT 0, + user_emphasis REAL NOT NULL DEFAULT 0.0, + importance REAL NOT NULL DEFAULT 0.5, + confidence REAL NOT NULL DEFAULT 0.9, + created_at INTEGER NOT NULL, + last_access_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS l1_episodic ( + memory_id TEXT PRIMARY KEY REFERENCES memory_meta(memory_id) ON DELETE CASCADE, + who TEXT NOT NULL, + what TEXT NOT NULL, + where_loc TEXT, + when_ts INTEGER, + when_text TEXT, + when_resolved TEXT, + when_precision TEXT, + why TEXT, + summary TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_l1_when ON l1_episodic(when_ts); + + CREATE TABLE IF NOT EXISTS l2_semantic ( + memory_id TEXT PRIMARY KEY REFERENCES memory_meta(memory_id) ON DELETE CASCADE, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + object_type TEXT NOT NULL DEFAULT 'literal', + fact_text TEXT NOT NULL DEFAULT '', + t_valid_start INTEGER NOT NULL, + t_valid_end INTEGER + ); + + CREATE INDEX IF NOT EXISTS idx_l2_predicate ON l2_semantic(predicate); + + CREATE TABLE IF NOT EXISTS kb_insight ( + memory_id TEXT PRIMARY KEY REFERENCES memory_meta(memory_id) ON DELETE CASCADE, + insight TEXT NOT NULL, + question TEXT, + source_path TEXT NOT NULL, + wiki_id INTEGER, + heading_path TEXT, + created_at INTEGER NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_kb_insight_source ON kb_insight(source_path); + CREATE INDEX IF NOT EXISTS idx_kb_insight_wiki ON kb_insight(wiki_id); + + CREATE TABLE IF NOT EXISTS kb_insight_source ( + memory_id TEXT NOT NULL REFERENCES kb_insight(memory_id) ON DELETE CASCADE, + source_path TEXT NOT NULL, + PRIMARY KEY (memory_id, source_path) + ); + + CREATE INDEX IF NOT EXISTS idx_kb_insight_source_path + ON kb_insight_source(source_path); + """) + _migrate(conn) + _ensure_vec_table(conn) + conn.commit() + + +def _migrate(conn: sqlite3.Connection) -> None: + """Add columns introduced after initial schema creation.""" + tables = {row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")} + + l2_cols = {row[1] for row in conn.execute("PRAGMA table_info(l2_semantic)")} + if "fact_text" not in l2_cols: + conn.execute("ALTER TABLE l2_semantic ADD COLUMN fact_text TEXT NOT NULL DEFAULT ''") + if "when_text" not in l2_cols: + conn.execute("ALTER TABLE l2_semantic ADD COLUMN when_text TEXT") + if "when_resolved" not in l2_cols: + conn.execute("ALTER TABLE l2_semantic ADD COLUMN when_resolved TEXT") + if "when_precision" not in l2_cols: + conn.execute("ALTER TABLE l2_semantic ADD COLUMN when_precision TEXT") + + l1_cols = {row[1] for row in conn.execute("PRAGMA table_info(l1_episodic)")} + if "when_resolved" not in l1_cols: + conn.execute("ALTER TABLE l1_episodic ADD COLUMN when_resolved TEXT") + if "when_precision" not in l1_cols: + conn.execute("ALTER TABLE l1_episodic ADD COLUMN when_precision TEXT") + + # Backfill kb_insight_source from kb_insight.source_path for pre-existing rows. + # Idempotent: PRIMARY KEY (memory_id, source_path) prevents duplicates on rerun. + try: + conn.execute(""" + INSERT OR IGNORE INTO kb_insight_source (memory_id, source_path) + SELECT memory_id, source_path FROM kb_insight + WHERE source_path IS NOT NULL AND source_path != '' + """) + except Exception: + pass + + # legacy tables from earlier design — kept so existing DBs don't break + if "kb_source" not in tables: + conn.execute("""CREATE TABLE kb_source ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + wiki_id INTEGER, page_title TEXT, + page_path TEXT NOT NULL, heading_path TEXT, + synced_at INTEGER NOT NULL)""") + if "kb_chunk" not in tables: + conn.execute("""CREATE TABLE kb_chunk ( + memory_id TEXT PRIMARY KEY, wiki_id INTEGER, page_title TEXT, + page_path TEXT NOT NULL, heading_path TEXT, + chunk_text TEXT NOT NULL, synced_at INTEGER NOT NULL)""") + if "kb_qa" not in tables: + conn.execute("""CREATE TABLE kb_qa ( + memory_id TEXT PRIMARY KEY, + question TEXT NOT NULL, answer TEXT NOT NULL, + created_at INTEGER NOT NULL)""") + if "kb_qa_chunk" not in tables: + conn.execute("""CREATE TABLE kb_qa_chunk ( + qa_memory_id TEXT NOT NULL, chunk_memory_id TEXT NOT NULL, + PRIMARY KEY (qa_memory_id, chunk_memory_id))""") + + +def _ensure_vec_table(conn: sqlite3.Connection) -> None: + conn.execute( + f"""CREATE VIRTUAL TABLE IF NOT EXISTS l2_fact_vec USING vec0( + memory_id TEXT PRIMARY KEY, + embedding FLOAT[{EMBEDDING_DIM}] + )""" + ) + conn.execute( + f"""CREATE VIRTUAL TABLE IF NOT EXISTS l1_episode_vec USING vec0( + memory_id TEXT PRIMARY KEY, + embedding FLOAT[{EMBEDDING_DIM}] + )""" + ) + conn.execute( + """CREATE VIRTUAL TABLE IF NOT EXISTS l0_turn_fts USING fts5( + session_id UNINDEXED, + turn_index UNINDEXED, + text_content + )""" + ) + conn.execute( + f"""CREATE VIRTUAL TABLE IF NOT EXISTS kb_insight_vec USING vec0( + memory_id TEXT PRIMARY KEY, + embedding FLOAT[{EMBEDDING_DIM}] + )""" + ) + # legacy vec tables — kept so existing DBs don't break + conn.execute( + f"""CREATE VIRTUAL TABLE IF NOT EXISTS kb_chunk_vec USING vec0( + memory_id TEXT PRIMARY KEY, + embedding FLOAT[{EMBEDDING_DIM}] + )""" + ) + conn.execute( + f"""CREATE VIRTUAL TABLE IF NOT EXISTS kb_qa_vec USING vec0( + memory_id TEXT PRIMARY KEY, + embedding FLOAT[{EMBEDDING_DIM}] + )""" + ) diff --git a/src/heta/mem/embedder.py b/src/heta/mem/embedder.py new file mode 100644 index 0000000..007afea --- /dev/null +++ b/src/heta/mem/embedder.py @@ -0,0 +1,21 @@ +"""Embedding calls for the memory module.""" + +from __future__ import annotations + +from openai import OpenAI + +from heta.mem.client import EMBEDDING_DIM + + +def embed_text(client: OpenAI, model: str, text: str) -> list[float]: + response = client.embeddings.create( + model=model, + input=[text], + dimensions=EMBEDDING_DIM, + ) + return response.data[0].embedding + + +def fact_text(subject: str, predicate: str, object_: str) -> str: + """Convert a triple to a natural language string for embedding.""" + return f"{subject} {predicate.replace('_', ' ')} {object_}" diff --git a/src/heta/mem/kb_invalidate.py b/src/heta/mem/kb_invalidate.py new file mode 100644 index 0000000..29ac9b6 --- /dev/null +++ b/src/heta/mem/kb_invalidate.py @@ -0,0 +1,99 @@ +"""Invalidate kb_insight memories whose source wiki pages changed.""" + +from __future__ import annotations + +import logging +import sqlite3 +from collections.abc import Iterable + +from heta.mem.db import get_connection, init_db +from heta.mem.paths import db_path + +logger = logging.getLogger(__name__) + + +def invalidate_by_paths(paths: Iterable[str]) -> int: + """Delete kb_insight memories whose source_path is in `paths`. Returns count deleted. + + Silently returns 0 when the memory DB does not exist yet, so KB operations + succeed even if the user never initialised memory. + """ + path_list = [p for p in paths if p] + if not path_list: + return 0 + if not db_path().exists(): + return 0 + try: + conn = get_connection(db_path(), with_vec=True) + except Exception: + logger.warning("memory DB open failed; skip kb_insight invalidation", exc_info=True) + return 0 + try: + init_db(conn) + return delete_insights_by_paths(conn, path_list) + except Exception: + logger.warning("kb_insight invalidation failed", exc_info=True) + return 0 + finally: + conn.close() + + +def invalidate_all() -> int: + """Delete every kb_insight memory. Returns count deleted. + + Silently returns 0 when the memory DB does not exist yet. + """ + if not db_path().exists(): + return 0 + try: + conn = get_connection(db_path(), with_vec=True) + except Exception: + logger.warning("memory DB open failed; skip kb_insight invalidation", exc_info=True) + return 0 + try: + init_db(conn) + return delete_all_insights(conn) + except Exception: + logger.warning("kb_insight invalidation (all) failed", exc_info=True) + return 0 + finally: + conn.close() + + +def delete_insights_by_paths(conn: sqlite3.Connection, paths: list[str]) -> int: + """Connection-level helper. Exposed for tests and callers with an open conn. + + An insight is invalidated if ANY of its source_paths matches a changed page. + """ + if not paths: + return 0 + placeholders = ",".join("?" for _ in paths) + ids = [ + r[0] + for r in conn.execute( + f"SELECT DISTINCT memory_id FROM kb_insight_source WHERE source_path IN ({placeholders})", + paths, + ).fetchall() + ] + if not ids: + return 0 + id_placeholders = ",".join("?" for _ in ids) + # vec0 virtual table does not honour FK cascade; delete explicitly. + conn.execute(f"DELETE FROM kb_insight_vec WHERE memory_id IN ({id_placeholders})", ids) + # memory_meta delete cascades to kb_insight via ON DELETE CASCADE, + # which in turn cascades to kb_insight_source. + conn.execute(f"DELETE FROM memory_meta WHERE memory_id IN ({id_placeholders})", ids) + conn.commit() + return len(ids) + + +def delete_all_insights(conn: sqlite3.Connection) -> int: + """Connection-level helper to wipe all kb_insight rows. Exposed for tests.""" + ids = [r[0] for r in conn.execute("SELECT memory_id FROM kb_insight").fetchall()] + if not ids: + return 0 + conn.execute("DELETE FROM kb_insight_vec") + placeholders = ",".join("?" for _ in ids) + conn.execute(f"DELETE FROM memory_meta WHERE memory_id IN ({placeholders})", ids) + conn.commit() + return len(ids) diff --git a/src/heta/mem/kb_store.py b/src/heta/mem/kb_store.py new file mode 100644 index 0000000..7015cf7 --- /dev/null +++ b/src/heta/mem/kb_store.py @@ -0,0 +1,70 @@ +"""CRUD and search operations for kb_insight.""" + +from __future__ import annotations + +import sqlite3 + +import sqlite_vec + +from heta.mem.models import KBInsight + + +def insert_kb_insight(conn: sqlite3.Connection, insight: KBInsight) -> None: + """Insert insight row plus one row per source_path into the join table.""" + conn.execute( + """INSERT INTO kb_insight + (memory_id, insight, question, source_path, wiki_id, heading_path, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + (insight.memory_id, insight.insight, insight.question, insight.source_path, + insight.wiki_id, insight.heading_path, insight.created_at), + ) + for path in insight.source_paths: + conn.execute( + "INSERT OR IGNORE INTO kb_insight_source (memory_id, source_path) VALUES (?, ?)", + (insight.memory_id, path), + ) + + +def insert_insight_embedding( + conn: sqlite3.Connection, memory_id: str, embedding: list[float] +) -> None: + conn.execute( + "INSERT INTO kb_insight_vec (memory_id, embedding) VALUES (?, ?)", + (memory_id, sqlite_vec.serialize_float32(embedding)), + ) + + +def get_source_paths(conn: sqlite3.Connection, memory_id: str) -> list[str]: + rows = conn.execute( + "SELECT source_path FROM kb_insight_source WHERE memory_id = ? ORDER BY source_path", + (memory_id,), + ).fetchall() + return [r[0] for r in rows] + + +def search_kb_insights( + conn: sqlite3.Connection, + embedding: list[float], + top_k: int = 5, +) -> list[dict]: + rows = conn.execute( + """SELECT i.memory_id, i.insight, i.source_path, v.distance + FROM kb_insight_vec v + JOIN kb_insight i ON i.memory_id = v.memory_id + JOIN memory_meta m ON m.memory_id = i.memory_id + WHERE v.embedding MATCH ? AND k = ? + AND m.status = 'active' + ORDER BY v.distance""", + (sqlite_vec.serialize_float32(embedding), top_k), + ).fetchall() + results = [] + for r in rows: + mid = r["memory_id"] + results.append({ + "memory_id": mid, + "insight": r["insight"], + "source_path": r["source_path"], + "source_paths": get_source_paths(conn, mid), + "score": 1.0 / (1.0 + float(r["distance"])), + }) + return results diff --git a/src/heta/mem/kb_writer.py b/src/heta/mem/kb_writer.py new file mode 100644 index 0000000..f259783 --- /dev/null +++ b/src/heta/mem/kb_writer.py @@ -0,0 +1,140 @@ +"""Store agent-distilled kb_insights into memory, with dedup.""" + +from __future__ import annotations + +import json +import logging +import time +import uuid +from pathlib import Path + +from heta.config.schema import HetaConfig +from heta.mem import meta_store +from heta.mem.client import build_client, build_embedding_client, extra_body +from heta.mem.db import get_connection, init_db +from heta.mem.embedder import embed_text +from heta.mem.kb_store import insert_insight_embedding, insert_kb_insight, search_kb_insights +from heta.mem.models import KBInsight, MemoryMeta +from heta.mem.paths import db_path, ensure_mem_dir +from heta.mem.prompts import INSIGHT_DEDUP_PROMPT +from heta.query.models import QueryInsight, QuerySource + +logger = logging.getLogger(__name__) + +# Below this similarity score the candidate is almost certainly a new insight, +# so we skip the LLM dedup call entirely. +_DEDUP_SIMILARITY_THRESHOLD = 0.7 +_DEDUP_TOP_K = 5 + + +def remember_kb_insights( + question: str, + insights: list[QueryInsight], + sources: list[QuerySource], + config: HetaConfig, + base_dir: Path | None = None, +) -> int: + """Persist agent-distilled insights into memory. Returns count stored after dedup.""" + if not insights: + return 0 + + ensure_mem_dir() + conn = get_connection(db_path(), with_vec=True) + init_db(conn) + + llm_client, llm_model = build_client(config) + emb_client, emb_model = build_embedding_client(config) + now = int(time.time()) + + # Build a path → QuerySource map so insights can inherit wiki_id / heading + # from the primary source. + source_index = {s.path: s for s in sources} + total = 0 + + for qi in insights: + text = qi.text.strip() + if not text: + continue + + emb = embed_text(emb_client, emb_model, text) + similar = search_kb_insights(conn, emb, top_k=_DEDUP_TOP_K) + if similar and similar[0]["score"] >= _DEDUP_SIMILARITY_THRESHOLD: + if _is_duplicate(llm_client, llm_model, text, similar, config): + logger.debug("skip duplicate insight: %.80s", text) + continue + + primary_path = qi.source_paths[0] if qi.source_paths else "" + primary = source_index.get(primary_path) + wiki_id = primary.wiki_id if primary else None + heading_path = primary.heading_path if primary else None + + memory_id = str(uuid.uuid4()) + meta = MemoryMeta( + memory_id=memory_id, + memory_type="kb_insight", + session_id=None, + origin="kb_insight", + kb_uid=str(wiki_id) if wiki_id is not None else None, + created_at=now, + last_access_at=now, + ) + insight = KBInsight( + memory_id=memory_id, + insight=text, + source_paths=list(qi.source_paths), + created_at=now, + question=question, + wiki_id=wiki_id, + heading_path=heading_path, + ) + meta_store.insert_meta(conn, meta) + insert_kb_insight(conn, insight) + insert_insight_embedding(conn, memory_id, emb) + total += 1 + + conn.commit() + conn.close() + return total + + +def _is_duplicate( + client, + model: str, + insight_text: str, + similar: list[dict], + config: HetaConfig, +) -> bool: + """Ask the LLM whether the new insight is fully covered by the similar set.""" + existing_block = "\n".join( + f"[{i + 1}] {s['insight']}" for i, s in enumerate(similar) + ) + user_msg = f"NEW insight:\n{insight_text}\n\nEXISTING similar insights:\n{existing_block}" + kwargs: dict = { + "model": model, + "messages": [ + {"role": "system", "content": INSIGHT_DEDUP_PROMPT}, + {"role": "user", "content": user_msg}, + ], + "temperature": 0.0, + } + body = extra_body(config) + if body: + kwargs["extra_body"] = body + try: + raw = (client.chat.completions.create(**kwargs).choices[0].message.content or "").strip() + data = _parse_json(raw) + return bool(data.get("duplicate", False)) + except Exception: + logger.warning("dedup check failed, defaulting to store: %.80s", insight_text) + return False + + +def _parse_json(raw: str) -> dict: + text = raw.strip() + if text.startswith("```"): + lines = text.splitlines() + text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + try: + return json.loads(text) + except (json.JSONDecodeError, AttributeError): + return {} diff --git a/src/heta/mem/l0_search.py b/src/heta/mem/l0_search.py new file mode 100644 index 0000000..1ad5fed --- /dev/null +++ b/src/heta/mem/l0_search.py @@ -0,0 +1,44 @@ +"""Full-text search on L0 raw turns.""" + +from __future__ import annotations + +import sqlite3 + + +def _build_fts_query(query: str) -> str: + """Build an FTS5 OR query from individual tokens, avoiding syntax errors.""" + import re + tokens = re.findall(r'[\w一-鿿]+', query) + if not tokens: + return '""' + # quote each token individually to handle special chars, then OR them + return " OR ".join('"' + t.replace('"', '""') + '"' for t in tokens) + + +def search_turns(conn: sqlite3.Connection, query: str, top_k: int = 3) -> list[dict]: + """FTS5 search on raw turn text. Returns matching turns with context.""" + fts_query = _build_fts_query(query) + try: + rows = conn.execute( + """ + SELECT session_id, turn_index, text_content, rank + FROM l0_turn_fts + WHERE text_content MATCH ? + ORDER BY rank + LIMIT ? + """, + (fts_query, top_k), + ).fetchall() + except Exception: + rows = [] + + results = [] + for r in rows: + score = 1.0 / (1.0 + abs(float(r["rank"]))) + results.append({ + "session_id": r["session_id"], + "turn_index": r["turn_index"], + "text_content": r["text_content"], + "score": score, + }) + return results diff --git a/src/heta/mem/l0_store.py b/src/heta/mem/l0_store.py new file mode 100644 index 0000000..e66b3d0 --- /dev/null +++ b/src/heta/mem/l0_store.py @@ -0,0 +1,21 @@ +"""Write operations for the l0_turn table.""" + +from __future__ import annotations + +import sqlite3 + +from heta.mem.models import L0Turn + + +def insert_turn(conn: sqlite3.Connection, turn: L0Turn) -> None: + conn.execute( + "INSERT INTO l0_turn (session_id, turn_index, role, modality, text_content, created_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (turn.session_id, turn.turn_index, turn.role, + turn.modality, turn.text_content, turn.created_at), + ) + conn.execute( + "INSERT INTO l0_turn_fts (session_id, turn_index, text_content) VALUES (?, ?, ?)", + (turn.session_id, turn.turn_index, turn.text_content), + ) + conn.commit() diff --git a/src/heta/mem/l1_extractor.py b/src/heta/mem/l1_extractor.py new file mode 100644 index 0000000..75173e3 --- /dev/null +++ b/src/heta/mem/l1_extractor.py @@ -0,0 +1,85 @@ +"""LLM-based episodic memory extraction.""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime, timezone +from typing import Any + +from openai import OpenAI + +from heta.config.schema import HetaConfig +from heta.mem.client import extra_body +from heta.mem.prompts import EPISODE_EXTRACTION_PROMPT + +logger = logging.getLogger(__name__) + + +def extract_episodes( + client: OpenAI, + model: str, + text: str, + config: HetaConfig, + session_ts: int | None = None, +) -> list[dict[str, Any]]: + """Call the LLM and return a list of raw episode dicts.""" + anchor_date = _fmt_date(session_ts) + user_content = f"Anchor date: {anchor_date}\n\nText:\n{text}" + + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": EPISODE_EXTRACTION_PROMPT}, + {"role": "user", "content": user_content}, + ], + temperature=0.2, + **({"extra_body": extra_body(config)} if extra_body(config) else {}), + ) + raw = response.choices[0].message.content or "" + return _parse_episodes(raw) + + +def resolve_when_ts(when_resolved: str | None) -> int | None: + """Parse a variable-precision date string to unix timestamp of period start. + + Accepts: YYYY-MM-DD, YYYY-Www (ISO week), YYYY-MM, YYYY + """ + if not when_resolved: + return None + s = when_resolved.strip() + for fmt in ("%Y-%m-%d", "%Y-%m", "%Y"): + try: + dt = datetime.strptime(s, fmt) + return int(dt.replace(tzinfo=timezone.utc).timestamp()) + except ValueError: + pass + # ISO week: "2026-W21" + try: + dt = datetime.strptime(s + "-1", "%Y-W%W-%w") + return int(dt.replace(tzinfo=timezone.utc).timestamp()) + except ValueError: + pass + return None + + +def _fmt_date(ts: int | None) -> str: + if ts is None: + return datetime.now().strftime("%Y-%m-%d") + return datetime.fromtimestamp(ts).strftime("%Y-%m-%d") + + +def _parse_episodes(raw: str) -> list[dict[str, Any]]: + text = raw.strip() + if text.startswith("```"): + lines = text.splitlines() + text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + try: + data = json.loads(text) + episodes = data.get("episodes", []) + if not isinstance(episodes, list): + return [] + return [e for e in episodes if isinstance(e, dict) and "what" in e] + except (json.JSONDecodeError, AttributeError): + logger.warning("Failed to parse episode extraction response: %s", raw[:200]) + return [] diff --git a/src/heta/mem/l1_search.py b/src/heta/mem/l1_search.py new file mode 100644 index 0000000..c79a5c8 --- /dev/null +++ b/src/heta/mem/l1_search.py @@ -0,0 +1,37 @@ +"""Vector search on L1 episodic memory summaries.""" + +from __future__ import annotations + +import sqlite3 + +import sqlite_vec + + +def search_episodes(conn: sqlite3.Connection, embedding: list[float], top_k: int = 3) -> list[dict]: + """Return active episodes closest to the query embedding.""" + rows = conn.execute( + """ + SELECT e.memory_id, e.who, e.what, e.where_loc, e.when_text, e.why, e.summary, v.distance + FROM l1_episode_vec v + JOIN l1_episodic e ON e.memory_id = v.memory_id + JOIN memory_meta m ON m.memory_id = e.memory_id + WHERE v.embedding MATCH ? AND k = ? + AND m.status = 'active' + ORDER BY v.distance + """, + (sqlite_vec.serialize_float32(embedding), max(1, top_k)), + ).fetchall() + + return [ + { + "memory_id": r["memory_id"], + "who": r["who"], + "what": r["what"], + "where_loc": r["where_loc"], + "when_text": r["when_text"], + "why": r["why"], + "summary": r["summary"], + "score": 1.0 / (1.0 + float(r["distance"])), + } + for r in rows + ] diff --git a/src/heta/mem/l1_store.py b/src/heta/mem/l1_store.py new file mode 100644 index 0000000..168260c --- /dev/null +++ b/src/heta/mem/l1_store.py @@ -0,0 +1,31 @@ +"""Write operations for the l1_episodic table.""" + +from __future__ import annotations + +import sqlite3 + +from heta.mem.models import L1Episodic + + +def insert_episodic(conn: sqlite3.Connection, episode: L1Episodic) -> None: + conn.execute( + """INSERT INTO l1_episodic + (memory_id, who, what, where_loc, + when_ts, when_text, when_resolved, when_precision, why, summary) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (episode.memory_id, episode.who, episode.what, episode.where_loc, + episode.when_ts, episode.when_text, episode.when_resolved, + episode.when_precision, episode.why, episode.summary), + ) + conn.commit() + + +def insert_episode_embedding( + conn: sqlite3.Connection, memory_id: str, embedding: list[float] +) -> None: + import sqlite_vec + conn.execute( + "INSERT INTO l1_episode_vec (memory_id, embedding) VALUES (?, ?)", + (memory_id, sqlite_vec.serialize_float32(embedding)), + ) + conn.commit() diff --git a/src/heta/mem/l2_conflict.py b/src/heta/mem/l2_conflict.py new file mode 100644 index 0000000..c7de836 --- /dev/null +++ b/src/heta/mem/l2_conflict.py @@ -0,0 +1,83 @@ +"""Semantic conflict detection for L2 fact memories.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from openai import OpenAI + +from heta.config.schema import HetaConfig +from heta.mem.client import extra_body +from heta.mem.embedder import embed_text +from heta.mem.l2_store import search_similar_facts +from heta.mem.prompts import CONFLICT_JUDGE_PROMPT + +logger = logging.getLogger(__name__) + + +def detect_conflicts( + conn: Any, + new_fact_text: str, + llm_client: OpenAI, + llm_model: str, + emb_client: OpenAI, + emb_model: str, + config: HetaConfig, + top_k: int = 10, + session_id: str | None = None, +) -> list[str]: + """Return memory_ids of existing facts that the new fact contradicts.""" + embedding = embed_text(emb_client, emb_model, new_fact_text) + candidates = search_similar_facts(conn, embedding, top_k=top_k, exclude_session_id=session_id) + + if not candidates: + return [], embedding + + ids_to_deprecate = _judge(llm_client, llm_model, new_fact_text, candidates, config) + return ids_to_deprecate, embedding + + +def _judge( + client: OpenAI, + model: str, + new_fact_text: str, + candidates: list[dict], + config: HetaConfig, +) -> list[str]: + candidate_lines = "\n".join( + f'- id: "{c["memory_id"]}" fact: "{c["fact_text"]}"' + for c in candidates + ) + user_msg = f'New fact: "{new_fact_text}"\n\nExisting facts:\n{candidate_lines}' + + kwargs: dict = { + "model": model, + "messages": [ + {"role": "system", "content": CONFLICT_JUDGE_PROMPT}, + {"role": "user", "content": user_msg}, + ], + "temperature": 0.0, + } + body = extra_body(config) + if body: + kwargs["extra_body"] = body + + response = client.chat.completions.create(**kwargs) + raw = (response.choices[0].message.content or "").strip() + return _parse_judge_response(raw) + + +def _parse_judge_response(raw: str) -> list[str]: + text = raw + if text.startswith("```"): + lines = text.splitlines() + text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + try: + data = json.loads(text) + result = data.get("deprecate", []) + return result if isinstance(result, list) else [] + except (json.JSONDecodeError, AttributeError): + logger.warning("Failed to parse conflict judge response: %s", raw[:200]) + return [] diff --git a/src/heta/mem/l2_extractor.py b/src/heta/mem/l2_extractor.py new file mode 100644 index 0000000..2e13268 --- /dev/null +++ b/src/heta/mem/l2_extractor.py @@ -0,0 +1,65 @@ +"""LLM-based semantic fact extraction.""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime +from typing import Any + +from openai import OpenAI + +from heta.config.schema import HetaConfig +from heta.mem.client import extra_body +from heta.mem.prompts import FACT_EXTRACTION_PROMPT + +logger = logging.getLogger(__name__) + + +def extract_facts( + client: OpenAI, + model: str, + text: str, + config: HetaConfig, + session_ts: int | None = None, +) -> list[dict[str, Any]]: + """Call the LLM and return a list of raw fact dicts.""" + anchor_date = _fmt_date(session_ts) + user_content = f"Anchor date: {anchor_date}\n\nText:\n{text}" + + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": FACT_EXTRACTION_PROMPT}, + {"role": "user", "content": user_content}, + ], + temperature=0.2, + **({"extra_body": extra_body(config)} if extra_body(config) else {}), + ) + raw = response.choices[0].message.content or "" + return _parse_facts(raw) + + +def _fmt_date(ts: int | None) -> str: + if ts is None: + return datetime.now().strftime("%Y-%m-%d") + return datetime.fromtimestamp(ts).strftime("%Y-%m-%d") + + +def _parse_facts(raw: str) -> list[dict[str, Any]]: + text = raw.strip() + if text.startswith("```"): + lines = text.splitlines() + text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + try: + data = json.loads(text) + facts = data.get("facts", []) + if not isinstance(facts, list): + return [] + return [ + f for f in facts + if isinstance(f, dict) and all(k in f for k in ("subject", "predicate", "object")) + ] + except (json.JSONDecodeError, AttributeError): + logger.warning("Failed to parse fact extraction response: %s", raw[:200]) + return [] diff --git a/src/heta/mem/l2_store.py b/src/heta/mem/l2_store.py new file mode 100644 index 0000000..8237c98 --- /dev/null +++ b/src/heta/mem/l2_store.py @@ -0,0 +1,80 @@ +"""Write operations, vector search, and conflict handling for l2_semantic.""" + +from __future__ import annotations + +import sqlite3 + +import sqlite_vec + +from heta.mem.models import L2Semantic + + +def insert_fact(conn: sqlite3.Connection, fact: L2Semantic) -> None: + conn.execute( + """INSERT INTO l2_semantic + (memory_id, subject, predicate, object, object_type, + fact_text, t_valid_start, t_valid_end, when_text, when_resolved, when_precision) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (fact.memory_id, fact.subject, fact.predicate, fact.object, + fact.object_type, fact.fact_text, fact.t_valid_start, fact.t_valid_end, + fact.when_text, fact.when_resolved, fact.when_precision), + ) + conn.commit() + + +def insert_fact_embedding( + conn: sqlite3.Connection, memory_id: str, embedding: list[float] +) -> None: + conn.execute( + "INSERT INTO l2_fact_vec (memory_id, embedding) VALUES (?, ?)", + (memory_id, sqlite_vec.serialize_float32(embedding)), + ) + conn.commit() + + +def search_similar_facts( + conn: sqlite3.Connection, + embedding: list[float], + top_k: int = 5, + exclude_session_id: str | None = None, +) -> list[dict]: + """Return active facts closest to the given embedding, excluding the current session.""" + if exclude_session_id: + rows = conn.execute( + """ + SELECT s.memory_id, s.fact_text, v.distance + FROM l2_fact_vec v + JOIN l2_semantic s ON s.memory_id = v.memory_id + JOIN memory_meta m ON m.memory_id = s.memory_id + WHERE v.embedding MATCH ? AND k = ? + AND m.status = 'active' + AND s.t_valid_end IS NULL + AND m.session_id != ? + ORDER BY v.distance + """, + (sqlite_vec.serialize_float32(embedding), max(1, top_k), exclude_session_id), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT s.memory_id, s.fact_text, v.distance + FROM l2_fact_vec v + JOIN l2_semantic s ON s.memory_id = v.memory_id + JOIN memory_meta m ON m.memory_id = s.memory_id + WHERE v.embedding MATCH ? AND k = ? + AND m.status = 'active' + AND s.t_valid_end IS NULL + ORDER BY v.distance + """, + (sqlite_vec.serialize_float32(embedding), max(1, top_k)), + ).fetchall() + return [{"memory_id": r["memory_id"], "fact_text": r["fact_text"], "distance": r["distance"]} + for r in rows] + + +def expire_fact(conn: sqlite3.Connection, memory_id: str, t_valid_end: int) -> None: + conn.execute( + "UPDATE l2_semantic SET t_valid_end = ? WHERE memory_id = ?", + (t_valid_end, memory_id), + ) + conn.commit() diff --git a/src/heta/mem/meta_store.py b/src/heta/mem/meta_store.py new file mode 100644 index 0000000..5977064 --- /dev/null +++ b/src/heta/mem/meta_store.py @@ -0,0 +1,30 @@ +"""CRUD operations for the memory_meta table.""" + +from __future__ import annotations + +import sqlite3 + +from heta.mem.models import MemoryMeta + + +def insert_meta(conn: sqlite3.Connection, meta: MemoryMeta) -> None: + conn.execute( + """INSERT INTO memory_meta + (memory_id, memory_type, session_id, origin, kb_uid, status, + deprecated_by, recency_score, access_freq, user_emphasis, + importance, confidence, created_at, last_access_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + (meta.memory_id, meta.memory_type, meta.session_id, meta.origin, + meta.kb_uid, meta.status, meta.deprecated_by, meta.recency_score, + meta.access_freq, meta.user_emphasis, meta.importance, + meta.confidence, meta.created_at, meta.last_access_at), + ) + conn.commit() + + +def deprecate(conn: sqlite3.Connection, memory_id: str, deprecated_by: str) -> None: + conn.execute( + "UPDATE memory_meta SET status = 'deprecated', deprecated_by = ? WHERE memory_id = ?", + (deprecated_by, memory_id), + ) + conn.commit() diff --git a/src/heta/mem/models.py b/src/heta/mem/models.py new file mode 100644 index 0000000..b2ab57f --- /dev/null +++ b/src/heta/mem/models.py @@ -0,0 +1,87 @@ +"""Dataclasses for all memory tables.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class Session: + session_id: str + started_at: int + ended_at: int | None = None + consolidated: int = 0 + consolidated_at: int | None = None + + +@dataclass +class L0Turn: + session_id: str + turn_index: int + role: str # user / assistant / system / tool + modality: str # text / audio / image / mixed + text_content: str + created_at: int + + +@dataclass +class MemoryMeta: + memory_id: str + memory_type: str # L1 / L2 + session_id: str | None + origin: str # extracted / promoted / user_explicit / consolidated + created_at: int + last_access_at: int + kb_uid: str | None = None + status: str = "active" + deprecated_by: str | None = None + recency_score: float = 1.0 + access_freq: int = 0 + user_emphasis: float = 0.0 + importance: float = 0.5 + confidence: float = 0.9 + + +@dataclass +class L1Episodic: + memory_id: str + who: str # JSON array, e.g. '["Alice", "Bob"]' + what: str + where_loc: str | None + when_ts: int | None # unix timestamp of period start + when_text: str | None # original expression ("昨天", "下个月") + when_resolved: str | None # variable-precision: "2026-05-12" / "2026-06" / "2026" + when_precision: str | None # day / week / month / year + why: str | None + summary: str # used for vector embedding + + +@dataclass +class L2Semantic: + memory_id: str + subject: str + predicate: str + object: str + object_type: str # literal / entity_ref + fact_text: str # natural language form, used for embedding + t_valid_start: int + t_valid_end: int | None = None + when_text: str | None = None # original relative expression ("下个月") + when_resolved: str | None = None # variable-precision: "2026-06" / "2026-05-12" + when_precision: str | None = None # day / week / month / year + + +@dataclass +class KBInsight: + memory_id: str + insight: str # distilled knowledge point + source_paths: list[str] # all KB pages this insight derives from + created_at: int + question: str | None = None + wiki_id: int | None = None # primary wiki id (from first source) + heading_path: str | None = None # primary heading (from first source) + + @property + def source_path(self) -> str: + """Primary source path — kept for the legacy column / display.""" + return self.source_paths[0] if self.source_paths else "" diff --git a/src/heta/mem/paths.py b/src/heta/mem/paths.py new file mode 100644 index 0000000..bdefef0 --- /dev/null +++ b/src/heta/mem/paths.py @@ -0,0 +1,19 @@ +"""Filesystem paths for the memory module.""" + +from __future__ import annotations + +from pathlib import Path + + +def mem_dir() -> Path: + return Path.home() / ".heta" / "workspace" / "mem" + + +def db_path() -> Path: + return mem_dir() / "mem.sqlite3" + + +def ensure_mem_dir() -> Path: + path = mem_dir() + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/src/heta/mem/pipeline.py b/src/heta/mem/pipeline.py new file mode 100644 index 0000000..f9249bc --- /dev/null +++ b/src/heta/mem/pipeline.py @@ -0,0 +1,149 @@ +"""Orchestrator for the heta remember pipeline.""" + +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass + +from heta.config.schema import HetaConfig +from heta.mem import l0_store, l1_store, l2_store, meta_store, session_store +from heta.mem.client import build_client, build_embedding_client +from heta.mem.db import get_connection, init_db +from heta.mem.embedder import embed_text, fact_text +from heta.mem.l1_extractor import extract_episodes, resolve_when_ts +from heta.mem.l2_conflict import detect_conflicts +from heta.mem.l2_extractor import extract_facts +from heta.mem.models import L0Turn, L1Episodic, L2Semantic, MemoryMeta, Session +from heta.mem.paths import db_path, ensure_mem_dir + + +@dataclass +class RememberResult: + session_id: str + l1_count: int + l2_count: int + elapsed_s: float + + +def remember(text: str, config: HetaConfig) -> RememberResult: + ensure_mem_dir() + conn = get_connection(db_path(), with_vec=True) + init_db(conn) + + now = int(time.time()) + session_id = str(uuid.uuid4()) + llm_client, llm_model = build_client(config) + emb_client, emb_model = build_embedding_client(config) + + # --- session + L0 --- + session_store.create_session(conn, Session(session_id=session_id, started_at=now)) + l0_store.insert_turn( + conn, + L0Turn( + session_id=session_id, + turn_index=0, + role="user", + modality="text", + text_content=text, + created_at=now, + ), + ) + + # --- extract --- + t0 = time.time() + raw_episodes = extract_episodes(llm_client, llm_model, text, config, session_ts=now) + raw_facts = extract_facts(llm_client, llm_model, text, config, session_ts=now) + + # --- persist L1 --- + l1_count = 0 + for ep in raw_episodes: + memory_id = str(uuid.uuid4()) + meta = MemoryMeta( + memory_id=memory_id, + memory_type="L1", + session_id=session_id, + origin="extracted", + created_at=now, + last_access_at=now, + ) + episode = L1Episodic( + memory_id=memory_id, + who=json.dumps(ep.get("who", ["user"]), ensure_ascii=False), + what=ep.get("what", ""), + where_loc=ep.get("where_loc"), + when_ts=resolve_when_ts(ep.get("when_resolved")), + when_text=ep.get("when_text"), + when_resolved=ep.get("when_resolved"), + when_precision=ep.get("when_precision"), + why=ep.get("why"), + summary=ep.get("summary", ep.get("what", "")), + ) + meta_store.insert_meta(conn, meta) + l1_store.insert_episodic(conn, episode) + l1_emb = embed_text(emb_client, emb_model, episode.summary) + l1_store.insert_episode_embedding(conn, memory_id, l1_emb) + l1_count += 1 + + # --- persist L2 (semantic conflict resolution) --- + l2_count = 0 + for raw_fact in raw_facts: + memory_id = str(uuid.uuid4()) + subject = str(raw_fact.get("subject", "")) + predicate = str(raw_fact.get("predicate", "")) + object_ = str(raw_fact.get("object", "")) + raw_object_type = raw_fact.get("object_type", "literal") + object_type_val = raw_object_type[0] if isinstance(raw_object_type, list) else str(raw_object_type) + ft = fact_text(subject, predicate, object_) + + ids_to_deprecate, embedding = detect_conflicts( + conn=conn, + new_fact_text=ft, + llm_client=llm_client, + llm_model=llm_model, + emb_client=emb_client, + emb_model=emb_model, + config=config, + session_id=session_id, + ) + + meta = MemoryMeta( + memory_id=memory_id, + memory_type="L2", + session_id=session_id, + origin="extracted", + created_at=now, + last_access_at=now, + ) + fact_record = L2Semantic( + memory_id=memory_id, + subject=subject, + predicate=predicate, + object=object_, + object_type=object_type_val, + fact_text=ft, + t_valid_start=now, + when_text=raw_fact.get("when_text"), + when_resolved=raw_fact.get("when_resolved"), + when_precision=raw_fact.get("when_precision"), + ) + + # insert new meta + fact first so FK reference is valid + meta_store.insert_meta(conn, meta) + for old_id in ids_to_deprecate: + l2_store.expire_fact(conn, old_id, now) + meta_store.deprecate(conn, old_id, memory_id) + l2_store.insert_fact(conn, fact_record) + l2_store.insert_fact_embedding(conn, memory_id, embedding) + l2_count += 1 + + session_store.close_session(conn, session_id, int(time.time())) + conn.close() + + return RememberResult( + session_id=session_id, + l1_count=l1_count, + l2_count=l2_count, + elapsed_s=round(time.time() - t0, 2), + ) diff --git a/src/heta/mem/prompts.py b/src/heta/mem/prompts.py new file mode 100644 index 0000000..2b9f8e8 --- /dev/null +++ b/src/heta/mem/prompts.py @@ -0,0 +1,211 @@ +"""LLM prompt templates for memory extraction.""" + +from __future__ import annotations + +EPISODE_EXTRACTION_PROMPT = """\ +You are an episodic memory extraction engine for long-term personal memory. + +LANGUAGE RULE (highest priority): +All text fields you output — what, where_loc, why, summary, and names in who — MUST be +written in the SAME language as the input text. If the input is Chinese, write Chinese. +If the input is English, write English. Never translate or switch languages. + +Task: +Extract significant events and experiences from the input text as discrete episodes. +The input begins with an "Anchor date" line — use it to resolve all relative time expressions. +Return STRICT JSON only. Do not output markdown or extra text. + +Schema: +{"episodes":[{"who":["name"],"what":"event verb or short description","where_loc":"location or null","when_text":"original relative expression or null (e.g. '昨天','下周')","when_resolved":"variable-precision date or null","when_precision":"day|week|month|year or null","why":"reason or null","summary":"<=60 words self-contained description"}]} + +Definition of a GOOD episode: +A coherent, bounded real-world event or experience — something that happened, is happening, +or is concretely planned — that a person would remember and recount as a story. + +What TO extract: +- Past events: trips, meetings, purchases, job changes, medical visits, conflicts, milestones +- Ongoing situations: a project in progress, a health issue, a relationship change +- Concrete plans: confirmed future events with enough specificity (who, what, when) +- Significant outcomes: a decision made, a problem solved, a goal reached or failed + +What NOT to extract: +- General opinions or preferences (those belong in facts, not episodes) +- Abstract discussions or hypotheticals without resolution +- Trivial micro-exchanges with no event content +- Duplicate episodes restating the same event + +Quantity discipline: +- A short paragraph should yield 1 to 3 episodes. Do not force-create episodes from thin content. +- If no meaningful event is present, return {"episodes":[]}. + +Format rules: +- `summary` must be self-contained: a reader with no context should understand what happened. +- `who` is a JSON array of names. If the subject is implicit (e.g. "I"), use "user". +- `where_loc` and `why` are optional; use null if not mentioned. +- `when_text`: copy the original relative expression verbatim ("昨天", "下个月", "last Monday"). +- `when_resolved` + `when_precision`: resolve using the Anchor date with honest precision: + - Day-level: "2026-05-12", precision="day" (e.g. "昨天", "3天前") + - Week-level: "2026-W21", precision="week" (e.g. "下周", "上周") + - Month-level: "2026-06", precision="month" (e.g. "下个月", "上个月") + - Year-level: "2026", precision="year" (e.g. "明年", "去年") + Do NOT pad to YYYY-MM-01 — use only the precision the expression actually conveys. + If unresolvable, both fields are null. +""" + +RECALL_RANKING_PROMPT = """\ +You are a memory-layer relevance ranker. +Given a question and evidence retrieved from multiple memory layers, rank the layers from most to least relevant for answering the question. +Return STRICT JSON only. Do not output anything outside the JSON object. + +Schema: +{"ranking": ["best_layer", "second_layer", ...], "reason": "one sentence explaining which layer is most relevant and why"} + +Available memory layers: +- raw (L0): original input text preserved verbatim. +- episode (L1): bounded episodic memories — events, experiences, plans. +- atomic_fact (L2): compact factual memories — attributes, relationships, outcomes. +- kb_insight: distilled knowledge points extracted from the knowledge base. + +Rules: +- Rank based on relevance to the question only — do not attempt to answer the question here. +- If a layer has no results, rank it last. +- If no layer has any relevant evidence, return {"ranking": [], "reason": "no relevant evidence found"}. +""" + +RECALL_ANSWER_PROMPT = """\ +You are a strictly evidence-grounded answer generator. +Your task: answer the question using ONLY the evidence provided. No outside knowledge allowed. + +Return STRICT JSON only. Do not output anything outside the JSON object. + +Schema (sufficient): {"answer": "", "sufficient": true} +Schema (insufficient): {"answer": "[INSUFFICIENT]", "sufficient": false} + +CRITICAL rules: +- Use ONLY information explicitly stated in the evidence. Do NOT infer, extrapolate, or fill in details from your training knowledge. +- If the evidence does not explicitly contain what is needed to answer the question, output {"answer": "[INSUFFICIENT]", "sufficient": false}. +- "Thematically related" evidence is NOT sufficient. The evidence must directly state the specific information being asked. +- If the question asks for specific details that are not literally present in the evidence, output [INSUFFICIENT]. +- When in doubt, output [INSUFFICIENT]. + +Answer format (when sufficient): +- Write in Markdown with appropriate structure (headers, lists, code blocks). +- Answer in the SAME language as the question. +- Do NOT include a Sources or References section. +""" + +INSIGHT_DEDUP_PROMPT = """\ +You are a memory deduplication judge for a retrieval cache of factual insights. + +Given a NEW insight and a list of EXISTING similar insights already stored in memory, +decide whether the new insight should be skipped as redundant. + +Return STRICT JSON only — no markdown, no commentary. +Schema: {"duplicate": true} OR {"duplicate": false} + +Decision rule: +A new insight is REDUNDANT (duplicate=true) if every factual element it +asserts — every entity, relationship, attribute, time, and place — is +already covered by the COMBINATION of existing insights. The new insight +does not need to be a paraphrase of any single existing one; what matters +is whether any genuinely new fact is being introduced. + +A new insight is WORTH KEEPING (duplicate=false) if it introduces at least +one factual element not expressed by the existing set. + +Examples: +- NEW: "Martha Mattie 是 MJ 的祖母,青年时期生活在 Russell County" + EXISTING: ["Martha Mattie 是 MJ 的祖母", + "Martha Mattie 青年时期生活在 Russell County"] + → {"duplicate": true} (every fact already covered by the combination) + +- NEW: "Martha Mattie 是 MJ 的祖母,她的丈夫名叫 Samuel" + EXISTING: ["Martha Mattie 是 MJ 的祖母"] + → {"duplicate": false} (introduces "Samuel as husband" — a new fact) + +- NEW: "John Doe 是诗人" + EXISTING: ["John Doe 是 20 世纪初居住在 Russell County 的诗人"] + → {"duplicate": true} (the existing insight already covers "John Doe 是诗人") + +When in doubt, return {"duplicate": false} — information loss is harder +to recover than slight redundancy. +""" + +CONFLICT_JUDGE_PROMPT = """\ +You are a memory conflict resolver. Given a new fact and a list of existing facts, +decide which existing facts are directly contradicted by the new fact and should be deprecated. +Return STRICT JSON only. Do not output markdown or extra text. + +Schema: +{"deprecate": ["memory_id_1", "memory_id_2"]} + +Rules: +- Only deprecate facts that are DIRECTLY CONTRADICTED (mutually exclusive with the new fact). +- Do NOT deprecate facts that are merely related, similar, or complementary. +- If nothing is contradicted, return {"deprecate": []}. + +Examples of contradiction: + new: "user lives in Beijing" vs existing: "user lives in Shanghai" → deprecate + new: "user works at Alibaba" vs existing: "user works at ByteDance" → deprecate + +Examples of NO contradiction: + new: "user likes running" vs existing: "user likes swimming" → keep both + new: "user age 26" vs existing: "user had meeting with Bob" → keep both +""" + +FACT_EXTRACTION_PROMPT = """\ +You are a semantic memory extraction engine for long-term personal memory. + +LANGUAGE RULE (highest priority): +All text fields you output — subject, predicate, object — MUST be written in the SAME language +as the input text. If the input is Chinese, write Chinese. If the input is English, write English. +Never translate or switch languages. + Chinese input example: {{"subject":"用户","predicate":"居住在","object":"北京海淀区"}} + English input example: {{"subject":"user","predicate":"lives_in","object":"Haidian, Beijing"}} + +Task: +Extract durable, retrieval-useful facts from the input text as atomic subject-predicate-object triples. +The input begins with an "Anchor date" line — use it to resolve all relative time expressions. +Return STRICT JSON only. Do not output markdown or any extra text. + +Schema: +{{"facts":[{{"subject":"entity name","predicate":"relationship or attribute","object":"value or entity","object_type":"literal","when_text":"original relative time expression or null","when_resolved":"variable-precision date or null","when_precision":"day|week|month|year or null"}}]}} + +object_type is always "literal" unless the object is a known named entity that should be referenced +separately, in which case use "entity_ref". + +Definition of a GOOD fact: +A stable attribute or relationship that would still be useful to know weeks or months later — +who a person is, what they own, believe, or plan, what happened to them. + +What TO extract: +- Personal attributes: occupation, role, education, location, living situation +- Relationships: family, partners, friends, with context +- Preferences and opinions: hobbies, tastes, values — if explicitly stated +- Life events and outcomes: major decisions made, goals achieved, problems resolved +- Possessions, skills, or resources mentioned as notable +- Health, financial, or situational status changes + +What NOT to extract: +- Questions, requests, or intentions never confirmed as outcomes +- Casual small talk or filler without factual content +- Plans or hypotheticals unless explicitly decided or acted upon +- Trivially obvious facts that add no retrieval value +- Restatements of the same fact (avoid duplicates) + +Quantity discipline: +- A short paragraph should yield 2 to 6 facts. Do not pad with minor details. +- If the text contains no durable facts, return {{"facts":[]}}. + +Format rules: +- `subject` is a named entity (person, organisation, place). Use "user" if implicit (in input language). +- `predicate` is a short natural-language phrase in the input language describing the relationship + or attribute (e.g. Chinese: "居住在", "就职于", "喜欢", "月薪"; English: "lives_in", "works_at"). +- `object` is a concise value or name. +- One atomic statement per fact — no conjunctions linking two independent claims. +- `when_text`: copy the original relative time expression verbatim if the fact has a temporal + reference ("下个月", "next week"). Null otherwise. +- `when_resolved` + `when_precision`: resolve with honest precision (same rules as episode extraction): + - "2026-05-12" / "day", "2026-W21" / "week", "2026-06" / "month", "2026" / "year" + Do NOT pad month/week expressions to day level. Null if no temporal reference. +""" diff --git a/src/heta/mem/recall.py b/src/heta/mem/recall.py new file mode 100644 index 0000000..a5c12c6 --- /dev/null +++ b/src/heta/mem/recall.py @@ -0,0 +1,218 @@ +"""Orchestrator for the heta recall pipeline.""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field + +from heta.config.schema import HetaConfig +from heta.mem.client import build_client, build_embedding_client, extra_body +from heta.mem.db import get_connection, init_db +from heta.mem.embedder import embed_text +from heta.mem.kb_store import search_kb_insights +from heta.mem.l0_search import search_turns +from heta.mem.l1_search import search_episodes +from heta.mem.l2_store import search_similar_facts +from heta.mem.paths import db_path, ensure_mem_dir +from heta.mem.prompts import RECALL_ANSWER_PROMPT, RECALL_RANKING_PROMPT + +logger = logging.getLogger(__name__) + + +def _open_conn_and_embed(query: str, config: HetaConfig): + ensure_mem_dir() + conn = get_connection(db_path(), with_vec=True) + init_db(conn) + emb_client, emb_model = build_embedding_client(config) + embedding = embed_text(emb_client, emb_model, query) + return conn, embedding + + +@dataclass +class LayerEvidence: + layer: str # raw / episode / atomic_fact + items: list[dict] = field(default_factory=list) + + +@dataclass +class RecallResult: + query: str + ranking: list[str] + answer: str + reason: str + evidence: list[LayerEvidence] + sufficient: bool = False + + +def retrieve_evidence(query: str, config: HetaConfig, top_k: int = 5) -> list[LayerEvidence]: + """Pure retrieval — no LLM calls. Used by smart_query to inject context into the KB agent.""" + conn, query_embedding = _open_conn_and_embed(query, config) + l0_hits = search_turns(conn, query, top_k=top_k) + l1_hits = search_episodes(conn, query_embedding, top_k=top_k) + l2_hits = search_similar_facts(conn, query_embedding, top_k=top_k) + kb_insight_hits = search_kb_insights(conn, query_embedding, top_k=top_k) + conn.close() + return [ + LayerEvidence(layer="raw", items=l0_hits), + LayerEvidence(layer="episode", items=l1_hits), + LayerEvidence(layer="atomic_fact", items=l2_hits), + LayerEvidence(layer="kb_insight", items=kb_insight_hits), + ] + + +def recall(query: str, config: HetaConfig, top_k: int = 10) -> RecallResult: + conn, query_embedding = _open_conn_and_embed(query, config) + llm_client, llm_model = build_client(config) + + l0_hits = search_turns(conn, query, top_k=top_k) + l1_hits = search_episodes(conn, query_embedding, top_k=top_k) + l2_hits = search_similar_facts(conn, query_embedding, top_k=top_k) + kb_insight_hits = search_kb_insights(conn, query_embedding, top_k=top_k) + conn.close() + + evidence = [ + LayerEvidence(layer="raw", items=l0_hits), + LayerEvidence(layer="episode", items=l1_hits), + LayerEvidence(layer="atomic_fact", items=l2_hits), + LayerEvidence(layer="kb_insight", items=kb_insight_hits), + ] + + ranking, answer, reason, sufficient = _rank( + query=query, + evidence=evidence, + client=llm_client, + model=llm_model, + config=config, + ) + + return RecallResult( + query=query, + ranking=ranking, + answer=answer, + reason=reason, + evidence=evidence, + sufficient=sufficient, + ) + + +def _rank( + query: str, + evidence: list[LayerEvidence], + client, + model: str, + config: HetaConfig, +) -> tuple[list[str], str, str, bool]: + """Two-phase: rank layers first, then generate a strictly grounded answer.""" + evidence_text = format_evidence(evidence) + body = extra_body(config) + + # Phase A: rank layers (no answer generation) + ranking, reason = _rank_layers( + query=query, + evidence_text=evidence_text, + client=client, + model=model, + extra=body, + ) + + # Phase B: generate grounded answer (or [INSUFFICIENT]) + answer, sufficient = _generate_grounded_answer( + query=query, + evidence_text=evidence_text, + client=client, + model=model, + extra=body, + ) + + return ranking, answer, reason, sufficient + + +def _rank_layers( + query: str, + evidence_text: str, + client, + model: str, + extra: dict | None, +) -> tuple[list[str], str]: + kwargs: dict = { + "model": model, + "messages": [ + {"role": "system", "content": RECALL_RANKING_PROMPT}, + {"role": "user", "content": f"Question:\n{query}\n\nEvidence:\n{evidence_text}"}, + ], + "temperature": 0.1, + } + if extra: + kwargs["extra_body"] = extra + try: + raw = (client.chat.completions.create(**kwargs).choices[0].message.content or "").strip() + data = _parse_json(raw) + return data.get("ranking", []), data.get("reason", "") + except Exception: + logger.warning("ranking call failed", exc_info=True) + return [], "" + + +def _generate_grounded_answer( + query: str, + evidence_text: str, + client, + model: str, + extra: dict | None, +) -> tuple[str, bool]: + kwargs: dict = { + "model": model, + "messages": [ + {"role": "system", "content": RECALL_ANSWER_PROMPT}, + {"role": "user", "content": f"Question:\n{query}\n\nEvidence:\n{evidence_text}"}, + ], + "temperature": 0.2, + } + if extra: + kwargs["extra_body"] = extra + try: + raw = (client.chat.completions.create(**kwargs).choices[0].message.content or "").strip() + data = _parse_json(raw) + answer = data.get("answer", "") + sufficient = bool(data.get("sufficient", False)) + if answer == "[INSUFFICIENT]" or not sufficient: + return "", False + return answer, True + except Exception: + logger.warning("answer generation call failed", exc_info=True) + return "", False + + + +def format_evidence(evidence: list[LayerEvidence]) -> str: + parts = [] + for layer_ev in evidence: + parts.append(f"## {layer_ev.layer}") + if not layer_ev.items: + parts.append("(no results)") + else: + for i, item in enumerate(layer_ev.items, 1): + score = item.get("score", 0) + if layer_ev.layer == "raw": + parts.append(f"[{i}; score={score:.4f}] {item['text_content']}") + elif layer_ev.layer == "episode": + parts.append(f"[{i}; score={score:.4f}] {item['summary']}") + elif layer_ev.layer == "kb_insight": + src = item.get("source_path", "") + parts.append(f"[{i}; score={score:.4f}] [{src}] {item.get('insight', '')}") + else: + parts.append(f"[{i}; score={score:.4f}] {item['fact_text']}") + return "\n".join(parts) + + +def _parse_json(raw: str) -> dict: + text = raw.strip() + if text.startswith("```"): + lines = text.splitlines() + text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + try: + return json.loads(text) + except (json.JSONDecodeError, AttributeError): + logger.warning("Failed to parse LLM JSON response: %s", raw[:200]) + return {} diff --git a/src/heta/mem/session_store.py b/src/heta/mem/session_store.py new file mode 100644 index 0000000..c59d7b8 --- /dev/null +++ b/src/heta/mem/session_store.py @@ -0,0 +1,25 @@ +"""CRUD operations for the session table.""" + +from __future__ import annotations + +import sqlite3 + +from heta.mem.models import Session + + +def create_session(conn: sqlite3.Connection, session: Session) -> None: + conn.execute( + "INSERT INTO session (session_id, started_at, ended_at, consolidated, consolidated_at) " + "VALUES (?, ?, ?, ?, ?)", + (session.session_id, session.started_at, session.ended_at, + session.consolidated, session.consolidated_at), + ) + conn.commit() + + +def close_session(conn: sqlite3.Connection, session_id: str, ended_at: int) -> None: + conn.execute( + "UPDATE session SET ended_at = ? WHERE session_id = ?", + (ended_at, session_id), + ) + conn.commit() diff --git a/src/heta/providers/clients.py b/src/heta/providers/clients.py new file mode 100644 index 0000000..b366cf6 --- /dev/null +++ b/src/heta/providers/clients.py @@ -0,0 +1,97 @@ +"""OpenAI-compatible client factories for configured LLM capabilities.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from openai import OpenAI + +from heta.config.schema import HetaConfig + +EMBEDDING_DIM = 1024 + + +@dataclass(frozen=True) +class ModelClient: + client: OpenAI + model: str + + +def build_chat_client(config: HetaConfig, *, timeout: int = 60, max_retries: int | None = None) -> ModelClient: + """Return the text chat client and model for the configured provider.""" + model = _required(config.llm.chat_model, "chat_model") + return ModelClient( + client=_openai_client( + api_key=_required(config.llm.chat_api_key, "chat_api_key"), + base_url=config.llm.chat_base_url, + timeout=timeout, + max_retries=max_retries, + ), + model=model, + ) + + +def build_multimodal_client(config: HetaConfig, *, timeout: int = 300) -> ModelClient: + """Return the multimodal client and model for image/audio-capable calls.""" + model = _required(config.llm.multimodal_model, "multimodal_model") + return ModelClient( + client=_openai_client( + api_key=_required(config.llm.multimodal_api_key, "multimodal_api_key"), + base_url=config.llm.multimodal_base_url, + timeout=timeout, + ), + model=model, + ) + + +def build_embedding_client(config: HetaConfig, *, timeout: int = 120) -> ModelClient: + """Return the embedding client and fixed-dimension embedding model.""" + model = _required(config.llm.embedding_model, "embedding_model") + return ModelClient( + client=_openai_client( + api_key=_required(config.llm.embedding_api_key, "embedding_api_key"), + base_url=config.llm.embedding_base_url, + timeout=timeout, + ), + model=model, + ) + + +def extra_body(config: HetaConfig) -> dict | None: + """Return provider-specific request options for chat completions.""" + if config.llm.chat_extra_body is not None: + return config.llm.chat_extra_body + if config.llm.provider == "qwen": + return {"enable_thinking": False} + return None + + +def _openai_client( + *, + api_key: str, + base_url: str | None, + timeout: int, + max_retries: int | None = None, +) -> OpenAI: + kwargs: dict = {"api_key": api_key, "timeout": timeout} + if base_url: + kwargs["base_url"] = base_url + if max_retries is not None: + kwargs["max_retries"] = max_retries + return OpenAI(**kwargs) + + +def _required(value: str | None, field: str) -> str: + if not value: + raise ValueError(f"Missing LLM {field} in config.") + return value + + +__all__ = [ + "EMBEDDING_DIM", + "ModelClient", + "build_chat_client", + "build_embedding_client", + "build_multimodal_client", + "extra_body", +] diff --git a/src/heta/providers/llm.py b/src/heta/providers/llm.py index b7a0fc5..8de4cde 100644 --- a/src/heta/providers/llm.py +++ b/src/heta/providers/llm.py @@ -9,7 +9,7 @@ VALIDATION_TIMEOUT_SECONDS: Final[float] = 10.0 -def validate_llm(provider: str, api_key: str) -> bool: +def validate_llm(provider: str, api_key: str, base_url: str | None = None) -> bool: """Validate that an LLM provider API key can reach its provider.""" api_key = api_key.strip() if provider == "qwen": @@ -21,6 +21,8 @@ def validate_llm(provider: str, api_key: str) -> bool: return _validate_bearer_models("https://api.openai.com/v1/models", api_key) if provider == "gemini": return _validate_gemini_models(api_key) + if provider == "custom" and base_url: + return _validate_bearer_models(base_url.rstrip("/") + "/models", api_key) return False @@ -46,4 +48,3 @@ def _validate_gemini_models(api_key: str) -> bool: except requests.RequestException: return False return response.status_code == 200 - diff --git a/src/heta/query/agent.py b/src/heta/query/agent.py index df6a80d..e9ef110 100644 --- a/src/heta/query/agent.py +++ b/src/heta/query/agent.py @@ -3,53 +3,90 @@ from __future__ import annotations import json +import re +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any from heta.config.schema import HetaConfig from heta.kb.agent import AgentStats, _chat_completion, _get_client -from heta.query.models import QueryResult, QuerySource, VectorMatch +from heta.query.models import QueryInsight, QueryResult, QuerySource, VectorMatch from heta.query.tools import ( format_vector_matches, read_index, read_page, + read_raw, search_vector, source_from_page_path, ) -QUERY_TOOLS = [ - { - "type": "function", - "function": { - "name": "read_page", - "description": "Read a wiki page. Valid paths: pages/*.md.", - "parameters": { - "type": "object", - "properties": {"path": {"type": "string"}}, - "required": ["path"], - "additionalProperties": False, - }, +READ_PAGE_TOOL = { + "type": "function", + "function": { + "name": "read_page", + "description": "Read a wiki page. Valid paths: pages/*.md.", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + "additionalProperties": False, }, }, - { - "type": "function", - "function": { - "name": "search_vector", - "description": "Search semantic wiki chunks. Returns wiki id, page path, heading path, content, and score.", - "parameters": { - "type": "object", - "properties": { - "query": {"type": "string"}, - "top_k": {"type": "integer", "minimum": 1, "maximum": 10}, - }, - "required": ["query"], - "additionalProperties": False, +} + +READ_RAW_TOOL = { + "type": "function", + "function": { + "name": "read_raw", + "description": "Read an original raw file referenced by a wiki page. Valid paths stay under raw/.", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + "additionalProperties": False, + }, + }, +} + +SEARCH_VECTOR_TOOL = { + "type": "function", + "function": { + "name": "search_vector", + "description": "Search hybrid semantic and lexical wiki chunks. Returns wiki id, page path, heading path, content, and score.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "top_k": {"type": "integer", "minimum": 1, "maximum": 10}, }, + "required": ["query"], + "additionalProperties": False, }, }, +} + +QUERY_TOOLS = [ + READ_PAGE_TOOL, + READ_RAW_TOOL, + SEARCH_VECTOR_TOOL, +] + +QUERY_TOOLS_NO_VECTOR = [ + READ_PAGE_TOOL, + READ_RAW_TOOL, ] +RAW_SNIPPET_MAX_CHARS = 16000 + + +@dataclass(frozen=True) +class FinalAnswer: + answer: str + sources: list[QuerySource] + insights: list[QueryInsight] + valid_json: bool = True + def run_query_agent( *, @@ -66,6 +103,7 @@ def run_query_agent( stats = AgentStats(task_id="query", max_steps=max_steps, max_seconds=max_seconds) index_text = read_index(base_dir) initial_matches = search_vector(question, config, top_k=top_k, base_dir=base_dir) + vector_matches = _vector_match_map(initial_matches) messages: list[dict[str, Any]] = [ { "role": "user", @@ -78,8 +116,7 @@ def run_query_agent( } ] read_paths: set[str] = set() - vector_sources: dict[str, VectorMatch] = {match.path: match for match in initial_matches} - tools = QUERY_TOOLS if config.vector_index.enable else [QUERY_TOOLS[0]] + tools = QUERY_TOOLS if config.vector_index.enable else QUERY_TOOLS_NO_VECTOR while stats.should_continue(): response = _chat_completion( @@ -94,10 +131,36 @@ def run_query_agent( tool_calls = list(message.tool_calls or []) if not tool_calls: + final_answer = _parse_final_answer( + text=message.content or "", + read_paths=read_paths, + vector_matches=vector_matches, + base_dir=base_dir, + ) + if not final_answer.valid_json: + messages.append( + { + "role": "assistant", + "content": message.content or "", + } + ) + messages.append( + { + "role": "user", + "content": ( + "Your previous response was not valid JSON. Return exactly one valid JSON object now, " + "with keys insights, answer, and used_sources. Do not include Markdown fences or text " + "outside JSON." + ), + } + ) + stats.record("retry final JSON", response.usage) + continue stats.record_completion(response.usage) return QueryResult( - answer=message.content or "", - sources=_build_sources(read_paths=read_paths, vector_sources=vector_sources, base_dir=base_dir), + answer=final_answer.answer, + sources=final_answer.sources, + insights=final_answer.insights, usage=stats.finish("completed"), ) @@ -116,7 +179,7 @@ def run_query_agent( for tool_call in tool_calls ] messages.append(assistant_message) - messages.extend(_execute_tools(tool_calls, config, base_dir, top_k, read_paths, vector_sources)) + messages.extend(_execute_tools(tool_calls, config, base_dir, top_k, read_paths, vector_matches)) stats.record(", ".join(tool.function.name for tool in tool_calls), response.usage) messages.append( @@ -134,16 +197,23 @@ def run_query_agent( config=config, ) stats.record_completion(final.usage) + final_answer = _parse_final_answer( + text=final.choices[0].message.content or "", + read_paths=read_paths, + vector_matches=vector_matches, + base_dir=base_dir, + ) return QueryResult( - answer=final.choices[0].message.content or "", - sources=_build_sources(read_paths=read_paths, vector_sources=vector_sources, base_dir=base_dir), + answer=final_answer.answer, + sources=final_answer.sources, + insights=final_answer.insights, usage=stats.finish("stopped at limit"), ) def _system_prompt(vector_enabled: bool) -> str: vector_rule = ( - "- You may call search_vector again with a refined query if the current evidence is insufficient." + "- You may call search_vector again with a refined semantic or keyword query if the current evidence is insufficient." if vector_enabled else "- Vector search is disabled; rely on the index and pages you read." ) @@ -152,17 +222,57 @@ def _system_prompt(vector_enabled: bool) -> str: Answer the user's question using the Little Heta wiki. You can inspect the wiki, but you must not create, edit, delete, rename, or commit anything. -Rules: +Reading rules: - Treat index.md as the global map of pages, ids, paths, and summaries. - Treat semantic matches as starting evidence, not final truth. - If a chunk is relevant but incomplete, call read_page(path) for the full page. +- You may call read_raw(path) only for original raw files referenced by wiki pages. + Raw files help inspect details, but raw files must never appear in used_sources or + insight source_paths. - Follow useful [[Wiki Links]] by reading the linked pages when the index gives their paths. {vector_rule} - Stop reading when the context is enough. -- If the wiki does not contain enough evidence, say what is missing. -- Answer directly in Markdown. -- Do not include a Sources, References, or Citations section in the answer body. - The CLI renders sources separately from tool usage. + +Output protocol — distill, then answer: +1. After gathering evidence, emit a list of `insights`: short, self-contained + factual claims distilled from the pages you read, each tagged with the + wiki page(s) it derives from. Insights are what the answer must rest on. +2. Compose the `answer` using ONLY facts that are present in your `insights`. + If a sentence in the answer needs a fact not yet represented in any + insight, you must add that fact as an insight first. +3. If the wiki does not contain enough evidence, return an empty insights + list and say in the answer what is missing. + +What makes a GOOD insight: +- ONE self-contained factual claim per insight. A claim may be compound + (multiple linked entities in one sentence) as long as it asserts a single + coherent fact — but never bundle two independent claims. +- Embed named entities (people, places, dates, organisations) directly in + the text. No pronouns, no "this", no "the above". +- A reader with no other context must understand exactly what is asserted. +- Only facts explicitly stated in the wiki — no inference, no speculation. + +source_paths rules: +- Each insight's source_paths lists the wiki page(s) that the insight + derives from. Use the exact page paths (e.g. "pages/12-foo.md"). +- A cross-page synthesis may list multiple paths. +- Every path in source_paths must be a page you actually read or that + appeared in semantic matches. + +Output format — exactly one valid JSON object, no Markdown fences: +{{ + "insights": [ + {{"text": "self-contained factual claim", "source_paths": ["pages/example.md"]}} + ], + "answer": "Markdown answer text, derived from the insights above", + "used_sources": [{{"path": "pages/example.md", "heading_path": "Section"}}] +}} + +- insights: emit as many as the question requires; no minimum, no maximum. + Return [] if no evidence supports an answer. +- used_sources: include only pages you actually relied on (same as the union + of insight source_paths). +- Do not include a Sources, References, or Citations section inside answer. """ @@ -177,7 +287,7 @@ def _initial_message( f"Current date: {datetime.now().date().isoformat()}", f"Question:\n{question}", f"Wiki Index:\n{index_text or '(index.md is missing or empty)'}", - f"Semantic Matches:\n{format_vector_matches(vector_matches)}", + f"Hybrid Matches:\n{format_vector_matches(vector_matches)}", ] if extra_context: parts.append(f"Extra Context:\n{extra_context}") @@ -190,7 +300,7 @@ def _execute_tools( base_dir: Path | None, default_top_k: int, read_paths: set[str], - vector_sources: dict[str, VectorMatch], + vector_matches: dict[tuple[str, str], VectorMatch], ) -> list[dict[str, Any]]: results: list[dict[str, Any]] = [] for tool_call in tool_calls: @@ -205,12 +315,14 @@ def _execute_tools( output = read_page(path, base_dir) if not output.startswith("error:"): read_paths.add(path.replace("\\", "/").strip("/")) + elif name == "read_raw": + path = str(arguments.get("path", "")) + output = _trim_raw_output(read_raw(path, base_dir)) elif name == "search_vector": query = str(arguments.get("query", "")) top_k = int(arguments.get("top_k") or default_top_k) matches = search_vector(query, config, top_k=top_k, base_dir=base_dir) - for match in matches: - vector_sources.setdefault(match.path, match) + vector_matches.update(_vector_match_map(matches)) output = format_vector_matches(matches) else: output = f"error: unknown tool {name}" @@ -218,15 +330,120 @@ def _execute_tools( return results -def _build_sources( +def _trim_raw_output(output: str) -> str: + if output.startswith("error:") or len(output) <= RAW_SNIPPET_MAX_CHARS: + return output + return output[:RAW_SNIPPET_MAX_CHARS] + "\n\n[truncated raw output]" + + +def _vector_match_map(matches: list[VectorMatch]) -> dict[tuple[str, str], VectorMatch]: + return {(_normalize_candidate_path(match.path), match.heading_path): match for match in matches} + + +def _parse_final_answer( *, + text: str, read_paths: set[str], - vector_sources: dict[str, VectorMatch], + vector_matches: dict[tuple[str, str], VectorMatch], base_dir: Path | None, -) -> list[QuerySource]: +) -> FinalAnswer: + data = _extract_json_object(text) + if data is None: + return FinalAnswer(answer=text, sources=[], insights=[], valid_json=False) + + answer = data.get("answer") + used_sources = data.get("used_sources") + raw_insights = data.get("insights") + if not isinstance(answer, str): + answer = text + if not isinstance(used_sources, list): + used_sources = [] + if not isinstance(raw_insights, list): + raw_insights = [] + sources: dict[str, QuerySource] = {} - for path, match in vector_sources.items(): - sources[path] = source_from_page_path(path, base_dir, heading_path=match.heading_path) - for path in sorted(read_paths): - sources[path] = source_from_page_path(path, base_dir) - return list(sources.values()) + normalized_read_paths = {_normalize_candidate_path(path) for path in read_paths} + valid_paths = set(normalized_read_paths) + valid_paths.update(path for path, _ in vector_matches.keys()) + + for source in used_sources: + if not isinstance(source, dict): + continue + raw_path = source.get("path") + if not isinstance(raw_path, str): + continue + try: + path = _normalize_candidate_path(raw_path) + except ValueError: + continue + heading = source.get("heading_path") + heading_path = str(heading).strip() if heading else "" + key = (path, heading_path) + + if path in normalized_read_paths: + display_heading = heading_path or None + elif key in vector_matches: + display_heading = heading_path + else: + continue + sources[f"{path}#{display_heading or ''}"] = source_from_page_path(path, base_dir, heading_path=display_heading) + + insights = _validate_insights(raw_insights, valid_paths=valid_paths) + return FinalAnswer(answer=answer, sources=list(sources.values()), insights=insights) + + +def _validate_insights( + raw_insights: list, + *, + valid_paths: set[str], +) -> list[QueryInsight]: + """Validate each insight: keep ones with non-empty text and at least one valid source_path.""" + out: list[QueryInsight] = [] + for item in raw_insights: + if not isinstance(item, dict): + continue + text = item.get("text") + if not isinstance(text, str) or not text.strip(): + continue + raw_paths = item.get("source_paths") + if not isinstance(raw_paths, list): + continue + validated_paths: list[str] = [] + for p in raw_paths: + if not isinstance(p, str): + continue + try: + norm = _normalize_candidate_path(p) + except ValueError: + continue + if norm in valid_paths: + validated_paths.append(norm) + if not validated_paths: + continue + out.append(QueryInsight(text=text.strip(), source_paths=validated_paths)) + return out + + +def _normalize_candidate_path(path: str) -> str: + return path.replace("\\", "/").strip("/") + + +def _extract_json_object(text: str) -> dict[str, Any] | None: + stripped = text.strip() + if not stripped: + return None + if stripped.startswith("```"): + stripped = stripped.strip("`") + if stripped.lower().startswith("json"): + stripped = stripped[4:].strip() + try: + value = json.loads(stripped) + except json.JSONDecodeError: + match = re.search(r"\{.*\}", stripped, flags=re.DOTALL) + if match is None: + return None + try: + value = json.loads(match.group(0)) + except json.JSONDecodeError: + return None + return value if isinstance(value, dict) else None diff --git a/src/heta/query/models.py b/src/heta/query/models.py index b66ef72..e893035 100644 --- a/src/heta/query/models.py +++ b/src/heta/query/models.py @@ -24,9 +24,17 @@ class QuerySource: heading_path: str | None = None +@dataclass(frozen=True) +class QueryInsight: + """A distilled knowledge nugget emitted by the KB agent alongside its answer.""" + text: str + source_paths: list[str] + + @dataclass(frozen=True) class QueryResult: answer: str sources: list[QuerySource] = field(default_factory=list) + insights: list[QueryInsight] = field(default_factory=list) usage: dict | None = None diff --git a/src/heta/query/pipeline.py b/src/heta/query/pipeline.py index 84ddc92..198f8ca 100644 --- a/src/heta/query/pipeline.py +++ b/src/heta/query/pipeline.py @@ -15,7 +15,6 @@ def run_wiki_query( config: HetaConfig, *, top_k: int = 5, - extra_context: str | None = None, base_dir: Path | None = None, ) -> QueryResult: if not question.strip(): @@ -27,6 +26,5 @@ def run_wiki_query( config=config, base_dir=base_dir, top_k=top_k, - extra_context=extra_context, ) diff --git a/src/heta/query/smart_query.py b/src/heta/query/smart_query.py new file mode 100644 index 0000000..ae93737 --- /dev/null +++ b/src/heta/query/smart_query.py @@ -0,0 +1,281 @@ +"""Outer agent loop with two tools: recall_memory and query_kb.""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +from heta.config.schema import HetaConfig +from heta.mem.client import build_client, extra_body +from heta.mem.kb_writer import remember_kb_insights +from heta.mem.recall import LayerEvidence, format_evidence, retrieve_evidence +from heta.query.models import QueryResult + +logger = logging.getLogger(__name__) + +MAX_OUTER_STEPS = 5 + +_NO_INFO_PHRASES = [ + "no relevant", + "not found", + "unable to find", + "cannot find", + "找不到", + "没有相关", + "无法找到", +] + +OUTER_TOOLS = [ + { + "type": "function", + "function": { + "name": "recall_memory", + "description": ( + "Search personal memory layers (past conversation turns, episodic events, " + "atomic facts, and previously cached KB insights). Fast, no LLM calls. " + "Returns formatted evidence grouped by layer; '(no results)' means a layer is empty." + ), + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The query to search memory with."} + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + }, + { + "type": "function", + "function": { + "name": "query_kb", + "description": ( + "Run a deep wiki knowledge-base search via a sub-agent that can read pages " + "and perform semantic search. Slower but authoritative. Use only when memory is " + "insufficient. Returns a synthesized answer string." + ), + "parameters": { + "type": "object", + "properties": { + "question": {"type": "string", "description": "The question to ask the KB sub-agent."} + }, + "required": ["question"], + "additionalProperties": False, + }, + }, + }, +] + +OUTER_SYSTEM_PROMPT = """\ +You are Little Heta, a knowledge management assistant with access to two +information sources via tools. + +Tools: +- recall_memory(query): fast search over personal memory (past conversations, + episodes, facts, and previously cached KB insights). Use first. +- query_kb(question): deep search over the project knowledge base via a + sub-agent that reads pages. Slower but authoritative. Use only when memory + is insufficient. + +Decision strategy: +1. Always call recall_memory FIRST with the user's question (unless the question + is a trivial greeting or meta-message that needs no retrieval). +2. Read the evidence carefully. A section showing "(no results)" means that + layer is empty. A score below ~0.3 usually means weak relevance. +3. If memory contains the specific information the question asks for, answer + directly from memory. Do NOT call query_kb. +4. If memory is empty or only thematically related (mentions the topic but + not the specific answer), call query_kb. +5. Special case: if the question is about a personal experience ("what did I + do yesterday", "我上次去哪了") and memory has no hits, answer that you + don't have that information. Do NOT search the KB for personal events. +6. After getting tool results, produce the final answer as plain Markdown in + the SAME language as the question. Do not mention the tools or your + internal reasoning. Do not include a "Sources" section. +""" + + +@dataclass +class SmartQueryResult: + answer: str + source: Literal["memory", "kb", "both"] + memory_evidence: list[LayerEvidence] = field(default_factory=list) + kb_result: QueryResult | None = None + written_back: int = 0 + agent_steps: list[str] = field(default_factory=list) + usage: dict[str, Any] | None = None + + +@dataclass +class _State: + memory_evidence: list[LayerEvidence] = field(default_factory=list) + kb_result: QueryResult | None = None + written_back: int = 0 + used_memory: bool = False + used_kb: bool = False + agent_steps: list[str] = field(default_factory=list) + outer_tokens: int = 0 + started_at: float = field(default_factory=time.time) + + +def smart_query( + question: str, + config: HetaConfig, + top_k: int = 5, + base_dir: Path | None = None, +) -> SmartQueryResult: + """Outer agent loop: lets an LLM decide when to recall memory vs. query KB.""" + state = _State() + client, model = build_client(config) + messages: list[dict[str, Any]] = [{"role": "user", "content": question}] + + for _ in range(MAX_OUTER_STEPS): + resp = _chat(client, model, messages, tools=OUTER_TOOLS, config=config) + _record_outer_usage(state, resp) + msg = resp.choices[0].message + tool_calls = list(msg.tool_calls or []) + + if not tool_calls: + return _build_result(state, answer=msg.content or "") + + assistant_msg: dict[str, Any] = {"role": "assistant"} + if msg.content: + assistant_msg["content"] = msg.content + assistant_msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.function.name, "arguments": tc.function.arguments}, + } + for tc in tool_calls + ] + messages.append(assistant_msg) + + for tc in tool_calls: + result = _exec_tool(tc, config, top_k, base_dir, state) + messages.append({"role": "tool", "tool_call_id": tc.id, "content": result}) + + # Step limit reached — force a final answer with no tools + messages.append( + {"role": "user", "content": "Step limit reached. Answer with the evidence already gathered, or say you don't know."} + ) + final = _chat(client, model, messages, tools=None, config=config) + _record_outer_usage(state, final) + return _build_result(state, answer=final.choices[0].message.content or "") + + +def _build_result(state: _State, *, answer: str) -> SmartQueryResult: + memory_has_hits = any(layer.items for layer in state.memory_evidence) + if state.used_kb and state.used_memory and memory_has_hits: + source: Literal["memory", "kb", "both"] = "both" + elif state.used_kb: + source = "kb" + else: + source = "memory" + return SmartQueryResult( + answer=answer, + source=source, + memory_evidence=state.memory_evidence, + kb_result=state.kb_result, + written_back=state.written_back, + agent_steps=list(state.agent_steps), + usage={ + "outer_tokens": state.outer_tokens, + "kb_tokens": (state.kb_result.usage or {}).get("tokens", 0) if state.kb_result else 0, + "tokens": state.outer_tokens + ((state.kb_result.usage or {}).get("tokens", 0) if state.kb_result else 0), + "elapsed_s": round(time.time() - state.started_at, 3), + }, + ) + + +def _exec_tool(tool_call: Any, config: HetaConfig, top_k: int, base_dir: Path | None, state: _State) -> str: + name = tool_call.function.name + try: + args = json.loads(tool_call.function.arguments or "{}") + except json.JSONDecodeError as exc: + return f"error: invalid tool arguments: {exc}" + + if name == "recall_memory": + return _exec_recall_memory(str(args.get("query", "")), config, top_k, state) + if name == "query_kb": + return _exec_query_kb(str(args.get("question", "")), config, top_k, base_dir, state) + return f"error: unknown tool {name}" + + +def _exec_recall_memory(query: str, config: HetaConfig, top_k: int, state: _State) -> str: + if not query.strip(): + return "error: empty query" + try: + evidence = retrieve_evidence(query, config, top_k=top_k) + except Exception as exc: + logger.exception("recall_memory failed") + return f"error: {exc}" + state.memory_evidence = evidence + state.used_memory = True + state.agent_steps.append("recall_memory") + return format_evidence(evidence) + + +def _exec_query_kb(question: str, config: HetaConfig, top_k: int, base_dir: Path | None, state: _State) -> str: + if not question.strip(): + return "error: empty question" + from heta.query.agent import run_query_agent + + try: + kb_result = run_query_agent( + question=question, + config=config, + base_dir=base_dir, + top_k=top_k, + ) + except Exception as exc: + logger.exception("query_kb failed") + return f"error: {exc}" + + state.kb_result = kb_result + state.used_kb = True + state.agent_steps.append("query_kb") + + if _kb_has_info(kb_result.answer) and kb_result.insights: + try: + state.written_back = remember_kb_insights( + question=question, + insights=kb_result.insights, + sources=kb_result.sources, + config=config, + base_dir=base_dir, + ) + except Exception: + logger.exception("kb write-back failed") + + return kb_result.answer + + +def _chat(client, model: str, messages: list[dict[str, Any]], *, tools, config: HetaConfig): + kwargs: dict[str, Any] = { + "model": model, + "messages": [{"role": "system", "content": OUTER_SYSTEM_PROMPT}, *messages], + "temperature": 0.2, + } + if tools: + kwargs["tools"] = tools + body = extra_body(config) + if body: + kwargs["extra_body"] = body + return client.chat.completions.create(**kwargs) + + +def _record_outer_usage(state: _State, response: Any) -> None: + usage = getattr(response, "usage", None) + prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(usage, "completion_tokens", 0) or 0 + state.outer_tokens += prompt_tokens + completion_tokens + + +def _kb_has_info(answer: str) -> bool: + lower = answer.lower() + return not any(phrase in lower for phrase in _NO_INFO_PHRASES) diff --git a/src/heta/query/tools.py b/src/heta/query/tools.py index da31cee..218ff5c 100644 --- a/src/heta/query/tools.py +++ b/src/heta/query/tools.py @@ -7,7 +7,7 @@ from heta.config.schema import HetaConfig from heta.kb import paths -from heta.kb.vector_index import search_wiki_vector_index +from heta.kb.vector_index import search_wiki_hybrid_index from heta.query.models import QuerySource, VectorMatch PAGE_ID_RE = re.compile(r"^(?P\d+)-.+\.md$") @@ -31,6 +31,19 @@ def read_page(path: str, base_dir: Path | None = None) -> str: return f"error: {exc}" +def read_raw(path: str, base_dir: Path | None = None) -> str: + try: + normalized = normalize_raw_path(path) + full = _resolve_safe(paths.raw_dir(base_dir), normalized) + if not full.exists(): + return f"error: raw/{normalized} does not exist" + if not full.is_file(): + return f"error: raw/{normalized} is not a file" + return full.read_text(encoding="utf-8", errors="replace") + except Exception as exc: + return f"error: {exc}" + + def search_vector( query: str, config: HetaConfig, @@ -50,7 +63,7 @@ def search_vector( content=match.content, score=match.score, ) - for match in search_wiki_vector_index(query=query, config=config, top_k=top_k, base_dir=base_dir) + for match in search_wiki_hybrid_index(query=query, config=config, top_k=top_k, base_dir=base_dir) ] @@ -74,6 +87,18 @@ def normalize_page_path(path: str) -> str: raise ValueError(f"path must be pages/*.md, got: {path!r}") +def normalize_raw_path(path: str) -> str: + normalized = path.replace("\\", "/").strip() + if "/raw/" in normalized: + normalized = normalized.split("/raw/", 1)[1] + elif normalized.startswith("raw/"): + normalized = normalized[4:] + normalized = normalized.strip("/") + if not normalized or normalized.startswith("../") or "/../" in normalized: + raise ValueError(f"path must stay within raw/, got: {path!r}") + return normalized + + def wiki_id_from_page_name(page_name: str) -> int | None: match = PAGE_ID_RE.match(page_name) if match is None: @@ -116,4 +141,3 @@ def _resolve_safe(root_dir: Path, normalized: str) -> Path: def _frontmatter_value(text: str, key: str) -> str | None: match = re.search(rf"^{re.escape(key)}:\s*(.+)$", text, flags=re.MULTILINE) return match.group(1).strip() if match else None - diff --git a/tests/eval_qa.py b/tests/eval_qa.py new file mode 100644 index 0000000..67fd201 --- /dev/null +++ b/tests/eval_qa.py @@ -0,0 +1,371 @@ +""" +QA evaluation script for heta memory system. + +Usage: + python tests/eval_qa.py # run all questions + python tests/eval_qa.py --out results.json # also save raw JSON + python tests/eval_qa.py -q 1 3 5 # run specific question numbers +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +# ── allow running from repo root without installing ────────────────────────── +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from heta.config.io import load_config +from heta.mem.client import build_client, extra_body +from heta.mem.recall import recall + +# ── QA definitions ─────────────────────────────────────────────────────────── + +QA_CASES: list[dict] = [ + # ── L2 基础事实(冲突消解)────────────────────────────────────────────── + { + "id": 1, + "category": "L2-冲突消解", + "question": "陈浩现在在哪家公司工作?", + "expected": "星图数据(极光科技已被覆盖)", + "keywords": ["星图数据"], + "anti_keywords": ["极光科技"], + }, + { + "id": 2, + "category": "L2-冲突消解", + "question": "陈浩现在住在哪里?", + "expected": "望京(经历过:朝阳区→海淀区→望京)", + "keywords": ["望京"], + "anti_keywords": ["朝阳", "海淀"], + }, + { + "id": 3, + "category": "L2-冲突消解", + "question": "陈浩现在的薪资是多少?", + "expected": "28k(薪资链:18k→22k→25k→28k)", + "keywords": ["28"], + "anti_keywords": ["18k", "22k", "25k"], + }, + { + "id": 4, + "category": "L2-冲突消解", + "question": "陈浩现在的通勤时间是多少?", + "expected": "步行10分钟(经历过:1小时→20分钟→步行10分钟)", + "keywords": ["10分钟", "步行"], + "anti_keywords": ["1小时", "20分钟"], + }, + { + "id": 5, + "category": "L2-累加事实", + "question": "陈浩喜欢什么运动?", + "expected": "爬山、羽毛球、偶尔篮球、游泳(兴趣爱好应累加保留)", + "keywords": ["爬山", "羽毛球"], + "anti_keywords": [], + }, + { + "id": 6, + "category": "L2-累加事实", + "question": "陈浩在学什么新技术?", + "expected": "Rust(用于高性能数据处理);需要加强Go经验", + "keywords": ["Rust"], + "anti_keywords": [], + }, + { + "id": 7, + "category": "L2-基础属性", + "question": "陈浩的绩效评级是什么?", + "expected": "A绩效(上个季度,在极光科技)", + "keywords": ["A"], + "anti_keywords": [], + }, + # ── L1 情景事件 ──────────────────────────────────────────────────────── + { + "id": 8, + "category": "L1-情景事件", + "question": "陈浩参加了什么技术会议?会议的主题是什么?", + "expected": "公司技术分享会,主题是大模型在工程中的落地,约50人参加", + "keywords": ["大模型", "技术分享", "50"], + "anti_keywords": [], + }, + { + "id": 9, + "category": "L1-情景事件", + "question": "陈浩和李薇开会讨论了什么?", + "expected": "用户画像模块改版需求评审,双方有争议,定三周后上线", + "keywords": ["用户画像", "李薇"], + "anti_keywords": [], + }, + { + "id": 10, + "category": "L1-情景事件", + "question": "陈浩在极光科技最后一天发生了什么?", + "expected": "完成了所有交接,和团队一起吃了散伙饭,感觉不舍", + "keywords": ["散伙饭", "交接"], + "anti_keywords": [], + }, + { + "id": 11, + "category": "L1-情景事件", + "question": "陈浩最近去哪里旅游了?和谁一起去的?大概花了多少钱?", + "expected": "青岛,和王强、赵敏、刘洋,三天,每人约1500元,住市南区民宿", + "keywords": ["青岛", "王强", "1500"], + "anti_keywords": [], + }, + { + "id": 12, + "category": "L1-事件序列", + "question": "陈浩妈妈的健康情况如何?", + "expected": "最初血压高,去协和医院检查,服降压药;后来复查血压稳定,已停药", + "keywords": ["血压", "协和"], + "anti_keywords": [], + }, + # ── 时间推理 ─────────────────────────────────────────────────────────── + { + "id": 13, + "category": "时间推理", + "question": "用户画像模块最终是什么时候上线的?经历了哪些波折?", + "expected": "比原计划晚了近两个月;需求评审定三周后→推迟到下下个月(前端资源不足)→最终上线,首日UV2万", + "keywords": ["推迟", "上线"], + "anti_keywords": [], + }, + { + "id": 14, + "category": "时间推理", + "question": "陈浩什么时候离开极光科技的?", + "expected": "拿到offer后下周一入职星图数据;在极光科技的最后一天完成交接", + "keywords": ["极光科技", "最后"], + "anti_keywords": [], + }, + { + "id": 15, + "category": "时间推理", + "question": "陈浩妈妈的降压药大概吃了多久?", + "expected": "大约一个月(医生开了一个月的药,复查后停药)", + "keywords": ["一个月", "停药"], + "anti_keywords": [], + }, + # ── 复合推理 ─────────────────────────────────────────────────────────── + { + "id": 16, + "category": "复合推理", + "question": "陈浩的职业发展轨迹是什么?", + "expected": "极光科技后端工程师(18k,绩效A)→裁员担忧→加入星图数据(25k→28k,Go/Rust技术栈)", + "keywords": ["极光科技", "星图数据"], + "anti_keywords": [], + }, + { + "id": 17, + "category": "复合推理", + "question": "陈浩换工作的原因是什么?", + "expected": "公司宣布裁员10%有担忧;拿到星图数据更高薪资的offer(25k)", + "keywords": ["裁员", "星图数据"], + "anti_keywords": [], + }, + { + "id": 18, + "category": "复合推理", + "question": "陈浩目前的生活状态怎么样?", + "expected": "住望京,在星图数据工作(28k),学Rust和Go,游泳,妈妈健康稳定", + "keywords": ["望京", "星图数据", "28"], + "anti_keywords": [], + }, + { + "id": 19, + "category": "复合推理", + "question": "用户画像模块这个项目经历了哪些波折?", + "expected": "需求评审有争议→定三周后上线→推迟到下下个月(前端资源)→上线首日UV2万", + "keywords": ["推迟", "前端"], + "anti_keywords": [], + }, + # ── 边界 / 负样本 ────────────────────────────────────────────────────── + { + "id": 20, + "category": "边界-无记忆", + "question": "陈浩有没有去过上海?", + "expected": "没有相关记忆", + "keywords": [], + "anti_keywords": ["上海"], + "expect_no_memory": True, + }, + { + "id": 21, + "category": "边界-无记忆", + "question": "陈浩结婚了吗?", + "expected": "没有相关记忆", + "keywords": [], + "anti_keywords": [], + "expect_no_memory": True, + }, + { + "id": 22, + "category": "边界-基础属性", + "question": "陈浩今年多少岁?", + "expected": "28岁", + "keywords": ["28"], + "anti_keywords": [], + }, + { + "id": 23, + "category": "边界-历史状态", + "question": "陈浩之前说要涨薪到22k,这个涨薪最终兑现了吗?", + "expected": "没有明确记录兑现;后来换工作了,该涨薪应已被覆盖", + "keywords": ["22k", "换工作"], + "anti_keywords": [], + }, + { + "id": 24, + "category": "边界-细节检索", + "question": "陈浩带妈妈去哪家医院看的病?", + "expected": "协和医院", + "keywords": ["协和"], + "anti_keywords": [], + }, + { + "id": 25, + "category": "边界-细节检索", + "question": "陈浩在极光科技认识了哪些人?", + "expected": "产品经理李薇;技术分享会上认识了做推理优化的同事(无具体名字)", + "keywords": ["李薇"], + "anti_keywords": [], + }, +] + + +# ── result dataclass ────────────────────────────────────────────────────────── + +@dataclass +class QAResult: + id: int + category: str + question: str + expected: str + actual_answer: str + layer_ranking: list[str] + keyword_hit: bool + anti_keyword_hit: bool + auto_pass: bool # keyword-based heuristic + elapsed_s: float + error: str = "" + + +# ── scoring ─────────────────────────────────────────────────────────────────── + +def _check_keywords(answer: str, keywords: list[str], anti_keywords: list[str]) -> tuple[bool, bool]: + a = answer.lower() + hit = all(kw.lower() in a for kw in keywords) if keywords else True + anti_hit = any(kw.lower() in a for kw in anti_keywords) if anti_keywords else False + return hit, anti_hit + + +def _auto_pass(case: dict, answer: str) -> bool: + """Heuristic pass: all keywords present AND no anti-keywords.""" + if case.get("expect_no_memory"): + no_mem_phrases = ["no relevant memory", "没有相关记忆", "没有记录", "未找到", "无相关"] + return any(p in answer.lower() for p in no_mem_phrases) + hit, anti = _check_keywords(answer, case["keywords"], case["anti_keywords"]) + return hit and not anti + + +# ── main ────────────────────────────────────────────────────────────────────── + +def run_eval(question_ids: list[int] | None = None) -> list[QAResult]: + config = load_config() + if config is None: + print("[ERROR] Heta is not initialised. Run `heta init` first.", file=sys.stderr) + sys.exit(1) + + cases = QA_CASES if not question_ids else [c for c in QA_CASES if c["id"] in question_ids] + + results: list[QAResult] = [] + for case in cases: + print(f" Q{case['id']:02d} [{case['category']}] {case['question']}", end=" ... ", flush=True) + t0 = time.time() + error = "" + answer = "" + ranking: list[str] = [] + try: + result = recall(case["question"], config) + answer = result.answer + ranking = result.ranking + except Exception as exc: + error = str(exc) + answer = "" + elapsed = round(time.time() - t0, 2) + + hit, anti = _check_keywords(answer, case["keywords"], case["anti_keywords"]) + passed = _auto_pass(case, answer) + + status = "PASS" if passed else "FAIL" + print(f"{status} ({elapsed}s)") + + results.append(QAResult( + id=case["id"], + category=case["category"], + question=case["question"], + expected=case["expected"], + actual_answer=answer, + layer_ranking=ranking, + keyword_hit=hit, + anti_keyword_hit=anti, + auto_pass=passed, + elapsed_s=elapsed, + error=error, + )) + return results + + +def print_report(results: list[QAResult]) -> None: + passed = sum(1 for r in results if r.auto_pass) + total = len(results) + print() + print("=" * 70) + print(f" RESULT: {passed}/{total} passed ({100*passed//total}%)") + print("=" * 70) + + # group by category + by_cat: dict[str, list[QAResult]] = {} + for r in results: + by_cat.setdefault(r.category, []).append(r) + + for cat, group in by_cat.items(): + cat_pass = sum(1 for r in group if r.auto_pass) + print(f"\n── {cat} ({cat_pass}/{len(group)}) ──") + for r in group: + icon = "✓" if r.auto_pass else "✗" + print(f" {icon} Q{r.id:02d}: {r.question}") + if not r.auto_pass: + print(f" 期望: {r.expected}") + print(f" 实际: {r.actual_answer[:200]}") + if r.error: + print(f" 错误: {r.error}") + + avg_t = sum(r.elapsed_s for r in results) / len(results) if results else 0 + print(f"\n平均响应时间: {avg_t:.1f}s") + + +def save_results(results: list[QAResult], path: str) -> None: + data = [asdict(r) for r in results] + Path(path).write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"结果已保存至 {path}") + + +# ── entry point ─────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate heta memory QA") + parser.add_argument("-q", "--questions", nargs="*", type=int, metavar="N", + help="Only run these question IDs (e.g. -q 1 3 5)") + parser.add_argument("--out", metavar="FILE", + help="Save raw JSON results to this file") + args = parser.parse_args() + + print(f"Running {len(QA_CASES) if not args.questions else len(args.questions)} QA cases...\n") + results = run_eval(args.questions) + print_report(results) + if args.out: + save_results(results, args.out) diff --git a/tests/memory_qa_test.md b/tests/memory_qa_test.md new file mode 100644 index 0000000..212466c --- /dev/null +++ b/tests/memory_qa_test.md @@ -0,0 +1,229 @@ +# Memory QA Test + +## 测试记忆(按输入顺序,即时间顺序) + +每条记忆单独运行一次 `heta remember`,顺序不可乱。 + +--- + +### 第一阶段:基础信息建立 + +```bash +heta remember "我叫陈浩,今年28岁,住在北京朝阳区,在一家叫极光科技的初创公司做后端工程师,入职快两年了" +``` + +```bash +heta remember "我平时喜欢爬山,周末经常去香山或者云蒙山,还喜欢打羽毛球,偶尔打篮球,最近开始学习游泳" +``` + +```bash +heta remember "上周去参加了公司的技术分享会,主题是大模型在工程中的落地,大概50人参加,认识了几个做推理优化的同事,收获挺大的" +``` + +```bash +heta remember "昨天和产品经理李薇开了需求评审会,讨论用户画像模块的改版方案,双方争议挺大,最后定下来计划三周后上线" +``` + +```bash +heta remember "公司上个季度给我打了A绩效,HR昨天通知我下个月涨薪,从18k涨到22k" +``` + +--- + +### 第二阶段:状态变更(制造冲突) + +```bash +heta remember "我搬家了,从朝阳区搬到了海淀区,现在离公司更近,通勤时间从原来的1小时缩短到大概20分钟" +``` + +```bash +heta remember "今天收到消息,用户画像模块上线时间推迟了,改到下下个月,原因是前端人手不够,李薇也很无奈" +``` + +```bash +heta remember "公司全员会议宣布下个季度要裁员10%,部门主管说后端组暂时安全,但我还是有点担心" +``` + +```bash +heta remember "我决定换工作了,已经拿到星图数据的offer,职位还是后端工程师,薪资直接给到25k,下周一入职" +``` + +```bash +heta remember "今天是我在极光科技的最后一天,完成了所有交接,和团队一起吃了散伙饭,挺不舍的" +``` + +--- + +### 第三阶段:新状态建立 + 更多事件 + +```bash +heta remember "在星图数据入职了,新公司在北京望京,团队规模比极光科技大很多,主要做企业数据分析产品" +``` + +```bash +heta remember "上个月和大学同学王强、赵敏、刘洋一起去青岛玩了三天,吃了很多海鲜,住在市南区的民宿,费用AA制大概每人花了1500块" +``` + +```bash +heta remember "最近开始学习Rust,打算用来做高性能的数据处理模块,买了《Rust程序设计语言》这本书" +``` + +```bash +heta remember "妈妈最近血压有点高,上周带她去协和医院检查了,医生说要控制饮食、减少盐分,开了一个月的降压药" +``` + +```bash +heta remember "星图数据这边技术栈比极光科技更新,主要用Go和Rust,Python只用来做数据脚本,我需要加强Go的经验" +``` + +--- + +### 第四阶段:进一步冲突与更新 + +```bash +heta remember "我开始学游泳了,报了公司附近游泳馆的课程,教练说我基础不错,一个月后应该能游50米了" +``` + +```bash +heta remember "搬家了,从海淀区搬到了望京,就在公司附近,步行10分钟就到,房租贵了一些但省去了通勤时间" +``` + +```bash +heta remember "星图数据给我调薪了,从25k涨到28k,理由是试用期表现优秀,正式转正同时调薪" +``` + +```bash +heta remember "用户画像模块终于上线了,比最初计划晚了将近两个月,上线后首日UV达到2万,李薇发微信说很感谢我之前的配合" +``` + +```bash +heta remember "妈妈复查了,血压已经控制稳定,医生说可以停药了,饮食方面继续保持就行" +``` + +--- + +## QA Test Cases + +--- + +### L2 基础事实查询(冲突消解验证) + +**Q1**: 陈浩现在在哪家公司工作? +- 期望:星图数据 +- 冲突链:极光科技 → 星图数据,应只保留最新 +- 考察:L2 works_at 冲突消解 + +**Q2**: 陈浩现在住在哪里? +- 期望:望京(北京) +- 冲突链:朝阳区 → 海淀区 → 望京,应只保留最新 +- 考察:多次地址更新后的最终状态 + +**Q3**: 陈浩现在的薪资是多少? +- 期望:28k +- 冲突链:18k → 22k(未兑现,被换工作覆盖)→ 25k → 28k +- 考察:多轮薪资更新,只保留最新 active fact + +**Q4**: 陈浩的通勤时间是多少? +- 期望:步行10分钟(搬到望京后) +- 冲突链:1小时 → 20分钟 → 步行10分钟 +- 考察:通勤时长的多次更新 + +**Q5**: 陈浩喜欢什么运动? +- 期望:爬山、羽毛球、偶尔篮球、游泳(最新加入) +- 考察:兴趣爱好是累加而非互斥,不应有冲突消解 + +**Q6**: 陈浩在学什么新技术? +- 期望:Rust(用于高性能数据处理);在星图数据需要加强 Go +- 考察:技能/学习类 fact 的检索 + +**Q7**: 陈浩的绩效如何? +- 期望:极光科技上个季度绩效 A +- 考察:历史绩效 fact 检索(无冲突,只有一条) + +--- + +### L1 情景事件查询 + +**Q8**: 陈浩参加了什么技术会议? +- 期望:公司技术分享会,主题大模型工程落地,约50人,认识了做推理优化的同事 +- 考察:L1 情景记忆中的事件细节检索 + +**Q9**: 陈浩和李薇开会讨论了什么? +- 期望:用户画像模块改版需求评审,双方有争议,最终定下三周后上线 +- 考察:L1 参与者 + 事件内容联合检索 + +**Q10**: 陈浩在极光科技最后一天发生了什么? +- 期望:完成交接,吃了散伙饭,感觉不舍 +- 考察:L1 情景记忆的情感和细节 + +**Q11**: 陈浩最近去哪里旅游了?和谁一起?花了多少钱? +- 期望:青岛,和王强、赵敏、刘洋,三天,每人约1500元,住市南区民宿 +- 考察:L1 多字段联合检索(who、where、when、cost) + +**Q12**: 陈浩妈妈的健康情况怎么样? +- 期望:最初血压高,去协和检查,服药控制;后来复查已稳定,停药了 +- 考察:L1 事件序列检索,包含状态变化 + +--- + +### 时间推理查询 + +**Q13**: 用户画像模块最终什么时候上线的? +- 期望:比原计划晚了将近两个月才上线(先定三周后 → 推迟到下下个月 → 最终上线,首日UV2万) +- 考察:跨多条记忆的时间线推理 + +**Q14**: 陈浩什么时候离开极光科技的? +- 期望:拿到offer后下周一入职,今天是最后一天 +- 考察:when_text 的相对时间表达 + +**Q15**: 陈浩妈妈的降压药吃了多久? +- 期望:大约一个月(开了一个月的药,复查后停药) +- 考察:L1 时间跨度推理 + +--- + +### 复合推理查询 + +**Q16**: 陈浩的职业发展轨迹是什么? +- 期望:在极光科技(后端工程师,18k→22k涨薪计划,绩效A)→ 因裁员担忧换工作 → 加入星图数据(25k,试用期转正后28k),技术栈从Python/后端转向Go/Rust +- 考察:跨多条L1+L2的时间线综合 + +**Q17**: 陈浩换工作的原因是什么? +- 期望:公司宣布裁员10%,有担忧;拿到星图数据offer薪资更高(25k) +- 考察:L1情景记忆的因果推理 + +**Q18**: 陈浩现在的生活状态怎么样? +- 期望:住望京、在星图数据工作(28k)、学Rust和Go、游泳、妈妈健康稳定 +- 考察:多个 active L2 fact 的综合呈现 + +**Q19**: 用户画像模块这个项目经历了哪些波折? +- 期望:需求评审有争议 → 定三周后上线 → 推迟到下下个月(前端资源不足)→ 最终上线,首日UV2万 +- 考察:同一主题跨多条记忆的事件串联 + +--- + +### 边界与负样本测试 + +**Q20**: 陈浩有没有去过上海? +- 期望:没有相关记忆 +- 考察:无关记忆时返回 "No relevant memory found" + +**Q21**: 陈浩结婚了吗? +- 期望:没有相关记忆 +- 考察:未提及的个人状态 + +**Q22**: 陈浩今年多少岁? +- 期望:28岁 +- 考察:基础属性直接查询 + +**Q23**: 陈浩的22k涨薪最终兑现了吗? +- 期望:没有明确记录兑现,后来换工作了(该涨薪应已被新薪资覆盖或处于 deprecated 状态) +- 考察:历史 fact 的状态追踪(deprecated 记录) + +**Q24**: 陈浩在哪家医院给妈妈看的病? +- 期望:协和医院 +- 考察:L1 细节字段的精确检索 + +**Q25**: 陈浩的极光科技同事中认识了谁? +- 期望:技术分享会上认识了做推理优化的同事(无具体名字);产品经理李薇 +- 考察:L1 who 字段的检索,及无具名时的处理 diff --git a/tests/seed_memories.sh b/tests/seed_memories.sh new file mode 100755 index 0000000..6b9f30d --- /dev/null +++ b/tests/seed_memories.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash +# Run this script to seed all test memories defined in memory_qa_test.md. +# Usage: +# bash tests/seed_memories.sh # seed only (keeps existing DB) +# bash tests/seed_memories.sh --clean # delete DB first, then seed + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +DB_PATH="${HETA_DB_PATH:-$HOME/.heta/workspace/mem/mem.sqlite3}" + +# ── optional clean ────────────────────────────────────────────────────────── +if [[ "${1:-}" == "--clean" ]]; then + echo ">> Removing existing DB: $DB_PATH" + rm -f "$DB_PATH" "${DB_PATH}-shm" "${DB_PATH}-wal" +fi + +# ── helpers ───────────────────────────────────────────────────────────────── +TOTAL=0 +FAILED=0 +FAILED_TEXTS=() + +run_remember() { + local text="$1" + local label="$2" + TOTAL=$((TOTAL + 1)) + echo "" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo "[$TOTAL] $label" + echo " $text" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + if ! heta remember "$text"; then + echo "[FAILED] ↑ 上面这条记忆写入失败" + FAILED=$((FAILED + 1)) + FAILED_TEXTS+=("[$TOTAL] $text") + fi +} + +# ── Phase 1: 基础信息建立 ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo " 第一阶段:基础信息建立" +echo "============================================================" + +run_remember \ + "我叫陈浩,今年28岁,住在北京朝阳区,在一家叫极光科技的初创公司做后端工程师,入职快两年了" \ + "个人基础信息" + +run_remember \ + "我平时喜欢爬山,周末经常去香山或者云蒙山,还喜欢打羽毛球,偶尔打篮球,最近开始学习游泳" \ + "爱好与运动" + +run_remember \ + "上周去参加了公司的技术分享会,主题是大模型在工程中的落地,大概50人参加,认识了几个做推理优化的同事,收获挺大的" \ + "技术分享会(事件)" + +run_remember \ + "昨天和产品经理李薇开了需求评审会,讨论用户画像模块的改版方案,双方争议挺大,最后定下来计划三周后上线" \ + "需求评审会(事件)" + +run_remember \ + "公司上个季度给我打了A绩效,HR昨天通知我下个月涨薪,从18k涨到22k" \ + "绩效与涨薪(18k→22k)" + +# ── Phase 2: 状态变更(冲突) ─────────────────────────────────────────────── +echo "" +echo "============================================================" +echo " 第二阶段:状态变更(制造冲突)" +echo "============================================================" + +run_remember \ + "我搬家了,从朝阳区搬到了海淀区,现在离公司更近,通勤时间从原来的1小时缩短到大概20分钟" \ + "搬家:朝阳→海淀(冲突住址)" + +run_remember \ + "今天收到消息,用户画像模块上线时间推迟了,改到下下个月,原因是前端人手不够,李薇也很无奈" \ + "项目推迟(冲突上线时间)" + +run_remember \ + "公司全员会议宣布下个季度要裁员10%,部门主管说后端组暂时安全,但我还是有点担心" \ + "裁员公告(事件)" + +run_remember \ + "我决定换工作了,已经拿到星图数据的offer,职位还是后端工程师,薪资直接给到25k,下周一入职" \ + "换工作决定:极光→星图,25k(冲突薪资/公司)" + +run_remember \ + "今天是我在极光科技的最后一天,完成了所有交接,和团队一起吃了散伙饭,挺不舍的" \ + "离职最后一天(事件)" + +# ── Phase 3: 新状态建立 + 更多事件 ───────────────────────────────────────── +echo "" +echo "============================================================" +echo " 第三阶段:新状态建立 + 更多事件" +echo "============================================================" + +run_remember \ + "在星图数据入职了,新公司在北京望京,团队规模比极光科技大很多,主要做企业数据分析产品" \ + "星图数据入职" + +run_remember \ + "上个月和大学同学王强、赵敏、刘洋一起去青岛玩了三天,吃了很多海鲜,住在市南区的民宿,费用AA制大概每人花了1500块" \ + "青岛旅游(事件)" + +run_remember \ + "最近开始学习Rust,打算用来做高性能的数据处理模块,买了《Rust程序设计语言》这本书" \ + "学习Rust" + +run_remember \ + "妈妈最近血压有点高,上周带她去协和医院检查了,医生说要控制饮食、减少盐分,开了一个月的降压药" \ + "妈妈就医(事件)" + +run_remember \ + "星图数据这边技术栈比极光科技更新,主要用Go和Rust,Python只用来做数据脚本,我需要加强Go的经验" \ + "星图技术栈,需学Go" + +# ── Phase 4: 进一步冲突与更新 ────────────────────────────────────────────── +echo "" +echo "============================================================" +echo " 第四阶段:进一步冲突与更新" +echo "============================================================" + +run_remember \ + "我开始学游泳了,报了公司附近游泳馆的课程,教练说我基础不错,一个月后应该能游50米了" \ + "游泳课(与爱好记忆互补)" + +run_remember \ + "搬家了,从海淀区搬到了望京,就在公司附近,步行10分钟就到,房租贵了一些但省去了通勤时间" \ + "再次搬家:海淀→望京(冲突住址+通勤)" + +run_remember \ + "星图数据给我调薪了,从25k涨到28k,理由是试用期表现优秀,正式转正同时调薪" \ + "调薪:25k→28k(冲突薪资)" + +run_remember \ + "用户画像模块终于上线了,比最初计划晚了将近两个月,上线后首日UV达到2万,李薇发微信说很感谢我之前的配合" \ + "项目上线(事件,首日UV2万)" + +run_remember \ + "妈妈复查了,血压已经控制稳定,医生说可以停药了,饮食方面继续保持就行" \ + "妈妈康复(事件,状态更新)" + +# ── 汇总 ──────────────────────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo " 完成" +echo "============================================================" +echo "总计:$TOTAL 条 成功:$((TOTAL - FAILED)) 条 失败:$FAILED 条" + +if [[ $FAILED -gt 0 ]]; then + echo "" + echo "失败列表:" + for t in "${FAILED_TEXTS[@]}"; do + echo " $t" + done + exit 1 +fi diff --git a/tests/test_assistant_skills.py b/tests/test_assistant_skills.py new file mode 100644 index 0000000..e11c4a6 --- /dev/null +++ b/tests/test_assistant_skills.py @@ -0,0 +1,23 @@ +from pathlib import Path + +from heta import assistants + + +def test_install_assistant_skills_installs_codex_and_claude(monkeypatch, tmp_path: Path) -> None: + heta_dir = tmp_path / "heta" / "skills" / "heta" + codex_dir = tmp_path / "codex" / "skills" / "heta" + claude_dir = tmp_path / "claude" / "skills" / "heta" + monkeypatch.setattr(assistants, "HETA_SKILL_DIR", heta_dir) + monkeypatch.setattr(assistants, "CODEX_SKILL_DIR", codex_dir) + monkeypatch.setattr(assistants, "CLAUDE_SKILL_DIR", claude_dir) + + installed = assistants.install_assistant_skills() + + assert [(item.assistant, item.path) for item in installed] == [ + ("Little Heta", heta_dir), + ("Codex", codex_dir), + ("Claude Code", claude_dir), + ] + for directory in (heta_dir, codex_dir, claude_dir): + assert (directory / "SKILL.md").read_text(encoding="utf-8").startswith("---") + assert (directory / "COMMANDS.md").read_text(encoding="utf-8").startswith("# Little Heta") diff --git a/tests/test_audio_parser.py b/tests/test_audio_parser.py new file mode 100644 index 0000000..bf33e34 --- /dev/null +++ b/tests/test_audio_parser.py @@ -0,0 +1,238 @@ +from pathlib import Path +from types import SimpleNamespace + +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.kb.audio_parser import build_audio_markdown, transcribe_media +from heta.kb.parser import parse_document +from heta.kb.text import extract_title + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def _chatgpt_config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="chatgpt", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def _custom_without_multimodal_config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig( + provider="custom", + api_key="sk-test", + chat_api_key="sk-chat", + chat_model="chat-model", + chat_base_url="http://chat.local/v1", + embedding_api_key="sk-embedding", + embedding_model="embedding-model", + embedding_base_url="http://embedding.local/v1", + ), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def _custom_with_audio_config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig( + provider="custom", + api_key="sk-test", + chat_api_key="sk-chat", + chat_model="chat-model", + chat_base_url="http://chat.local/v1", + embedding_api_key="sk-embedding", + embedding_model="embedding-model", + embedding_base_url="http://embedding.local/v1", + audio_api_key="sk-audio", + audio_model="audio-model", + audio_base_url="http://audio.local/v1", + ), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_build_audio_markdown_uses_compact_retrieval_sections() -> None: + markdown = build_audio_markdown( + title="Audio - Meeting", + source_name="meeting.mp3", + media_path="../../raw/meeting.mp3", + media_kind="Audio", + summary="A meeting recording.", + transcript="Speaker 1: Let's ship the feature.", + key_points_metadata="Decision: ship the feature. Language: English.", + interpretation_keywords="Meeting notes. keywords: feature, release.", + ) + + assert extract_title(markdown, "fallback") == "Audio - Meeting" + assert "[Audio file](<../../raw/meeting.mp3>)" in markdown + assert "### Transcript" in markdown + assert "### Key Points and Metadata" in markdown + assert "### Interpretation and Keywords" in markdown + assert "## Related Pages" in markdown + assert "## Source" in markdown + + +def test_chatgpt_audio_transcribes_then_structures_transcript(monkeypatch, tmp_path: Path) -> None: + audio = tmp_path / "meeting.mp3" + audio.write_bytes(b"mp3 bytes") + seen: dict[str, object] = {} + + class FakeTranscriptions: + @staticmethod + def create(**kwargs): + seen["transcription"] = kwargs + return "Speaker: hello." + + class FakeOpenAIClient: + audio = SimpleNamespace(transcriptions=FakeTranscriptions()) + + class FakeOpenAIFactory: + def __init__(self, **kwargs): + seen["openai_kwargs"] = kwargs + + audio = FakeOpenAIClient.audio + + chat_client = object() + monkeypatch.setattr("heta.kb.audio_parser.OpenAI", FakeOpenAIFactory) + monkeypatch.setattr("heta.kb.audio_parser._get_client", lambda config: (chat_client, "gpt-chat")) + + def fake_chat_completion(**kwargs): + seen.update(kwargs) + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=( + '{"summary":"A meeting.","transcript":"Speaker: hello.",' + '"key_points_metadata":"Language: English.",' + '"interpretation_keywords":"meeting, test"}' + ) + ) + ) + ] + ) + + monkeypatch.setattr("heta.kb.audio_parser._chat_completion", fake_chat_completion) + + result = transcribe_media(source_path=audio, config=_chatgpt_config()) + + assert result["summary"] == "A meeting." + assert seen["openai_kwargs"]["api_key"] == "sk-test" + assert seen["transcription"]["model"] == "gpt-4o-transcribe" + assert seen["transcription"]["response_format"] == "text" + assert seen["client"] is chat_client + assert seen["model"] == "gpt-chat" + assert "Speaker: hello." in seen["messages"][1]["content"] + + +def test_build_audio_markdown_supports_video_link_label() -> None: + markdown = build_audio_markdown( + title="Video - Demo", + source_name="demo.mp4", + media_path="../../raw/demo.mp4", + media_kind="Video", + summary="A product demo.", + transcript="Narrator: This is the dashboard.", + key_points_metadata="Media type: video.", + interpretation_keywords="Product demo, dashboard.", + ) + + assert "[Video file](<../../raw/demo.mp4>)" in markdown + + +def test_parse_document_accepts_audio_branch(monkeypatch, tmp_path: Path) -> None: + source = tmp_path / "meeting.mp3" + archived = tmp_path / "raw_meeting.mp3" + source.write_bytes(b"mp3") + archived.write_bytes(b"mp3") + + monkeypatch.setattr( + "heta.kb.parser.parse_audio_markdown", + lambda source_path, archived_path, config: build_audio_markdown( + title="Audio - Meeting", + source_name=archived_path.name, + media_path="../../raw/raw_meeting.mp3", + media_kind="Audio", + summary="A meeting.", + transcript="Speaker 1: hello.", + key_points_metadata="Language: English.", + interpretation_keywords="meeting, test", + ), + ) + + document = parse_document(source, archived, _config()) + + assert document.title == "Audio - Meeting" + assert document.source_name == "raw_meeting.mp3" + assert document.metadata["extension"] == ".mp3" + assert "### Transcript" in document.markdown_content + + +def test_audio_is_disabled_for_custom_without_audio_adapter(tmp_path: Path) -> None: + source = tmp_path / "meeting.mp3" + source.write_bytes(b"mp3") + + try: + parse_document(source, source, _custom_without_multimodal_config()) + except ValueError as exc: + assert "Audio/video parsing is not enabled for custom providers" in str(exc) + assert "audio APIs vary by vendor" in str(exc) + else: + raise AssertionError("audio parsing should require multimodal config") + + +def test_custom_audio_uses_audio_adapter(monkeypatch, tmp_path: Path) -> None: + audio = tmp_path / "meeting.mp3" + audio.write_bytes(b"mp3 bytes") + seen: dict[str, object] = {} + + class FakeChatCompletions: + @staticmethod + def create(**kwargs): + seen["request"] = kwargs + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=( + '{"summary":"A meeting.","transcript":"Speaker: hello.",' + '"key_points_metadata":"Language: English.",' + '"interpretation_keywords":"meeting, test"}' + ) + ) + ) + ] + ) + + class FakeOpenAI: + def __init__(self, **kwargs): + seen["client_kwargs"] = kwargs + self.chat = SimpleNamespace(completions=FakeChatCompletions()) + + monkeypatch.setattr("heta.kb.audio_parser.OpenAI", FakeOpenAI) + + result = transcribe_media(source_path=audio, config=_custom_with_audio_config()) + + assert result["summary"] == "A meeting." + assert seen["client_kwargs"]["api_key"] == "sk-audio" + assert seen["client_kwargs"]["base_url"] == "http://audio.local/v1" + assert seen["request"]["model"] == "audio-model" + content = seen["request"]["messages"][0]["content"] + assert content[1]["type"] == "input_audio" diff --git a/tests/test_clean_memory.py b/tests/test_clean_memory.py new file mode 100644 index 0000000..d2a882a --- /dev/null +++ b/tests/test_clean_memory.py @@ -0,0 +1,215 @@ +"""Tests for heta.mem.clean.clean_memory.""" + +from __future__ import annotations + +import time +import uuid +from pathlib import Path + +import pytest + +from heta.mem.clean import clean_memory +from heta.mem.db import get_connection, init_db +from heta.mem.l0_store import insert_turn +from heta.mem.l1_store import insert_episodic +from heta.mem.l2_store import expire_fact, insert_fact +from heta.mem.meta_store import deprecate, insert_meta +from heta.mem.models import L0Turn, L1Episodic, L2Semantic, MemoryMeta, Session +from heta.mem.session_store import close_session, create_session + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture() +def conn(tmp_path: Path): + db = tmp_path / "test_mem.sqlite3" + c = get_connection(db, with_vec=True) + init_db(c) + yield c + c.close() + + +def _new_id() -> str: + return str(uuid.uuid4()) + + +def _now() -> int: + return int(time.time()) + + +def _insert_session(conn, session_id: str | None = None) -> str: + sid = session_id or _new_id() + create_session(conn, Session(session_id=sid, started_at=_now())) + close_session(conn, sid, _now()) + return sid + + +def _insert_l0(conn, session_id: str) -> None: + insert_turn(conn, L0Turn( + session_id=session_id, + turn_index=0, + role="user", + modality="text", + text_content="hello world", + created_at=_now(), + )) + + +def _insert_l2(conn, session_id: str) -> str: + mid = _new_id() + insert_meta(conn, MemoryMeta( + memory_id=mid, memory_type="L2", session_id=session_id, + origin="extracted", created_at=_now(), last_access_at=_now(), + )) + insert_fact(conn, L2Semantic( + memory_id=mid, subject="user", predicate="lives_in", object="Beijing", + object_type="literal", fact_text="user lives_in Beijing", + t_valid_start=_now(), + )) + return mid + + +def _insert_l1(conn, session_id: str) -> str: + mid = _new_id() + insert_meta(conn, MemoryMeta( + memory_id=mid, memory_type="L1", session_id=session_id, + origin="extracted", created_at=_now(), last_access_at=_now(), + )) + insert_episodic(conn, L1Episodic( + memory_id=mid, who='["user"]', what="went to the park", + where_loc="park", when_ts=None, when_text=None, + when_resolved=None, when_precision=None, why=None, + summary="user went to the park", + )) + return mid + + +def _row_count(conn, table: str) -> int: + return conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] + + +# ── tests ───────────────────────────────────────────────────────────────────── + +def test_clean_empty_db_is_idempotent(conn) -> None: + result = clean_memory(conn) + assert result.deleted_sessions == 0 + assert result.deleted_l0_turns == 0 + assert result.deleted_l1_episodes == 0 + assert result.deleted_l2_facts == 0 + assert result.deleted_meta == 0 + + +def test_clean_removes_session_and_l0_turns(conn) -> None: + sid = _insert_session(conn) + _insert_l0(conn, sid) + + result = clean_memory(conn) + + assert result.deleted_sessions == 1 + assert result.deleted_l0_turns == 1 + assert _row_count(conn, "session") == 0 + assert _row_count(conn, "l0_turn") == 0 + + +def test_clean_removes_l2_facts_and_meta(conn) -> None: + sid = _insert_session(conn) + _insert_l2(conn, sid) + + result = clean_memory(conn) + + assert result.deleted_l2_facts == 1 + assert result.deleted_meta == 1 + assert _row_count(conn, "l2_semantic") == 0 + assert _row_count(conn, "memory_meta") == 0 + + +def test_clean_removes_l1_episodes_and_meta(conn) -> None: + sid = _insert_session(conn) + _insert_l1(conn, sid) + + result = clean_memory(conn) + + assert result.deleted_l1_episodes == 1 + assert result.deleted_meta == 1 + assert _row_count(conn, "l1_episodic") == 0 + assert _row_count(conn, "memory_meta") == 0 + + +def test_clean_removes_deprecated_facts(conn) -> None: + sid = _insert_session(conn) + old_id = _insert_l2(conn, sid) + new_id = _insert_l2(conn, sid) + expire_fact(conn, old_id, _now()) + deprecate(conn, old_id, new_id) + + assert _row_count(conn, "l2_semantic") == 2 + assert _row_count(conn, "memory_meta") == 2 + + result = clean_memory(conn) + + assert result.deleted_l2_facts == 2 + assert result.deleted_meta == 2 + assert _row_count(conn, "l2_semantic") == 0 + assert _row_count(conn, "memory_meta") == 0 + + +def test_clean_removes_all_layers_together(conn) -> None: + sid = _insert_session(conn) + _insert_l0(conn, sid) + _insert_l1(conn, sid) + _insert_l2(conn, sid) + + result = clean_memory(conn) + + assert result.deleted_sessions == 1 + assert result.deleted_l0_turns == 1 + assert result.deleted_l1_episodes == 1 + assert result.deleted_l2_facts == 1 + assert result.deleted_meta == 2 # one L1 meta + one L2 meta + + +def test_clean_multiple_sessions(conn) -> None: + for _ in range(3): + sid = _insert_session(conn) + _insert_l0(conn, sid) + _insert_l2(conn, sid) + + result = clean_memory(conn) + + assert result.deleted_sessions == 3 + assert result.deleted_l0_turns == 3 + assert result.deleted_l2_facts == 3 + assert _row_count(conn, "session") == 0 + assert _row_count(conn, "l0_turn") == 0 + assert _row_count(conn, "l2_semantic") == 0 + + +def test_clean_preserves_schema(conn) -> None: + """Tables must still exist and accept inserts after a clean.""" + sid = _insert_session(conn) + _insert_l0(conn, sid) + _insert_l1(conn, sid) + _insert_l2(conn, sid) + clean_memory(conn) + + # DB should still be fully usable after clean + sid2 = _insert_session(conn) + _insert_l0(conn, sid2) + _insert_l2(conn, sid2) + + assert _row_count(conn, "session") == 1 + assert _row_count(conn, "l0_turn") == 1 + assert _row_count(conn, "l2_semantic") == 1 + + +def test_clean_is_idempotent(conn) -> None: + sid = _insert_session(conn) + _insert_l0(conn, sid) + _insert_l2(conn, sid) + + first = clean_memory(conn) + second = clean_memory(conn) # called on already-empty DB + + assert first.deleted_sessions == 1 + assert second.deleted_sessions == 0 + assert second.deleted_l2_facts == 0 diff --git a/tests/test_code_parser.py b/tests/test_code_parser.py new file mode 100644 index 0000000..49325ab --- /dev/null +++ b/tests/test_code_parser.py @@ -0,0 +1,84 @@ +from pathlib import Path + +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.kb.code_parser import extract_code_symbols, parse_code_markdown +from heta.kb.parser import parse_document +from heta.kb.text import extract_title + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_parse_code_markdown_keeps_small_code_inline(tmp_path: Path) -> None: + source = tmp_path / "vector_index.py" + archived = tmp_path / "2026-05-15_vector_index.py" + source.write_text( + 'def search_wiki_vector_index(query, config):\n """Search semantic wiki chunks."""\n return []\n', + encoding="utf-8", + ) + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + markdown = parse_code_markdown(source, archived) + + assert extract_title(markdown, "fallback") == "Code - vector_index.py" + assert "[Raw source](<../../raw/2026-05-15_vector_index.py>)" in markdown + assert "### Code" in markdown + assert "def search_wiki_vector_index" in markdown + + +def test_parse_code_markdown_uses_symbol_index_for_large_code(tmp_path: Path) -> None: + source = tmp_path / "service.py" + archived = tmp_path / "2026-05-15_service.py" + body = "\n".join(["class MemoryService:", " \"\"\"Coordinates memory writes.\"\"\"", " pass", *["x = 1"] * 220]) + source.write_text(body, encoding="utf-8") + archived.write_text(body, encoding="utf-8") + + markdown = parse_code_markdown(source, archived) + + assert "### Symbol Index" in markdown + assert "#### MemoryService" in markdown + assert "Lines: 1-" in markdown + assert "Coordinates memory writes." in markdown + assert "### Code" not in markdown + + +def test_extract_code_symbols_handles_config_and_sql() -> None: + yaml_symbols = extract_code_symbols(Path("heta.yaml"), "vector_index:\n enable: true\nllm:\n provider: qwen\n") + sql_symbols = extract_code_symbols(Path("schema.sql"), "CREATE TABLE wiki_chunks (id integer);\nSELECT * FROM wiki_chunks;\n") + + assert [symbol.name for symbol in yaml_symbols] == ["vector_index", "llm"] + assert sql_symbols[0].name == "CREATE TABLE wiki_chunks" + + +def test_extract_code_symbols_keeps_regex_language_signatures() -> None: + cases = [ + (Path("sample.go"), "type QueryService struct{}\n\nfunc SearchWiki(query string) string {\n return query\n}\n", "SearchWiki", "func SearchWiki"), + (Path("sample.rs"), "pub struct QueryService;\n\npub fn search_wiki(query: &str) -> &str {\n query\n}\n", "search_wiki", "pub fn search_wiki"), + (Path("sample.js"), "export function runQuery(query) {\n return query;\n}\n", "runQuery", "export function runQuery"), + (Path("sample.ts"), "export function formatAnswer(result: QueryResult) {\n return result.answer;\n}\n", "formatAnswer", "export function formatAnswer"), + ] + + for path, text, name, signature_prefix in cases: + symbols = extract_code_symbols(path, text) + symbol = next(symbol for symbol in symbols if symbol.name == name) + assert symbol.signature.startswith(signature_prefix) + + +def test_parse_document_accepts_code_branch(tmp_path: Path) -> None: + source = tmp_path / "tool.ts" + archived = tmp_path / "2026-05-15_tool.ts" + source.write_text("export function runTool() { return true; }\n", encoding="utf-8") + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + document = parse_document(source, archived, _config()) + + assert document.title == "Code - tool.ts" + assert document.metadata["extension"] == ".ts" + assert "language: typescript" in document.markdown_content diff --git a/tests/test_config_io.py b/tests/test_config_io.py index 3fc8664..2203cca 100644 --- a/tests/test_config_io.py +++ b/tests/test_config_io.py @@ -1,7 +1,14 @@ from pathlib import Path from heta.config.io import load_config, save_config -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import ( + DynamicInsertConfig, + InsertPlanningConfig, + HetaConfig, + LLMConfig, + MinerUConfig, + VectorIndexConfig, +) def test_save_and_load_config(tmp_path: Path) -> None: @@ -11,6 +18,8 @@ def test_save_and_load_config(tmp_path: Path) -> None: llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=MinerUConfig.disabled(), vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + dynamic_insert=DynamicInsertConfig.disabled(), ) save_config(config, path) @@ -20,11 +29,131 @@ def test_save_and_load_config(tmp_path: Path) -> None: assert path.exists() +def test_load_config_fills_default_llm_profile(tmp_path: Path) -> None: + path = tmp_path / ".heta" / "heta.yaml" + path.parent.mkdir(parents=True) + path.write_text( + """ +version: 1 +llm: + provider: qwen + api_key: sk-test +mineru: + enable: false + provider: + api_key: + endpoint: +vector_index: + enable: true +insert_planning: + enable: true +dynamic_insert: + enable: false +""", + encoding="utf-8", + ) + + loaded = load_config(path) + + assert loaded is not None + assert loaded.llm.chat_model == "qwen3.5-flash" + assert loaded.llm.chat_api_key == "sk-test" + assert loaded.llm.multimodal_model == "qwen3.5-omni-flash" + assert loaded.llm.multimodal_api_key == "sk-test" + assert loaded.llm.embedding_model == "text-embedding-v4" + assert loaded.llm.embedding_api_key == "sk-test" + assert loaded.dynamic_insert.enable is False + + +def test_load_config_accepts_custom_llm_profile(tmp_path: Path) -> None: + path = tmp_path / ".heta" / "heta.yaml" + path.parent.mkdir(parents=True) + path.write_text( + """ +version: 1 +llm: + provider: custom + api_key: sk-test + chat_api_key: sk-chat + chat_model: custom-chat + chat_base_url: http://llm.local/v1 + chat_extra_body: + enable_thinking: false + multimodal_api_key: sk-mm + multimodal_model: custom-mm + multimodal_base_url: http://mm.local/v1 + embedding_api_key: sk-embedding + embedding_model: custom-embedding + embedding_base_url: http://embedding.local/v1 +mineru: + enable: false + provider: + api_key: + endpoint: +vector_index: + enable: true +insert_planning: + enable: true +dynamic_insert: + enable: true +""", + encoding="utf-8", + ) + + loaded = load_config(path) + + assert loaded is not None + assert loaded.llm.provider == "custom" + assert loaded.llm.chat_api_key == "sk-chat" + assert loaded.llm.chat_model == "custom-chat" + assert loaded.llm.chat_base_url == "http://llm.local/v1" + assert loaded.llm.chat_extra_body == {"enable_thinking": False} + assert loaded.llm.multimodal_api_key == "sk-mm" + assert loaded.llm.embedding_api_key == "sk-embedding" + assert loaded.llm.embedding_model == "custom-embedding" + assert loaded.dynamic_insert.enable is True + + +def test_custom_config_requires_embedding_fields(tmp_path: Path) -> None: + path = tmp_path / ".heta" / "heta.yaml" + path.parent.mkdir(parents=True) + path.write_text( + """ +version: 1 +llm: + provider: custom + api_key: sk-test + chat_api_key: sk-chat + chat_model: custom-chat + chat_base_url: http://llm.local/v1 +mineru: + enable: false + provider: + api_key: + endpoint: +vector_index: + enable: true +insert_planning: + enable: true +dynamic_insert: + enable: false +""", + encoding="utf-8", + ) + + try: + load_config(path) + except ValueError as exc: + assert "embedding_model" in str(exc) + else: + raise AssertionError("missing custom embedding fields should fail") + + def test_load_missing_config_returns_none(tmp_path: Path) -> None: assert load_config(tmp_path / "missing.yaml") is None -def test_config_requires_vector_index(tmp_path: Path) -> None: +def test_config_requires_insert_planning(tmp_path: Path) -> None: path = tmp_path / ".heta" / "heta.yaml" path.parent.mkdir(parents=True) path.write_text( @@ -38,6 +167,8 @@ def test_config_requires_vector_index(tmp_path: Path) -> None: provider: api_key: endpoint: +vector_index: + enable: true """, encoding="utf-8", ) @@ -45,6 +176,34 @@ def test_config_requires_vector_index(tmp_path: Path) -> None: try: load_config(path) except ValueError as exc: - assert "vector_index" in str(exc) + assert "insert_planning" in str(exc) else: - raise AssertionError("missing vector_index should fail") + raise AssertionError("missing insert_planning should fail") + + +def test_load_config_defaults_missing_dynamic_insert_to_disabled(tmp_path: Path) -> None: + path = tmp_path / ".heta" / "heta.yaml" + path.parent.mkdir(parents=True) + path.write_text( + """ +version: 1 +llm: + provider: qwen + api_key: sk-test +mineru: + enable: false + provider: + api_key: + endpoint: +vector_index: + enable: true +insert_planning: + enable: true +""", + encoding="utf-8", + ) + + loaded = load_config(path) + + assert loaded is not None + assert loaded.dynamic_insert.enable is False diff --git a/tests/test_html_parser.py b/tests/test_html_parser.py new file mode 100644 index 0000000..dfeaf87 --- /dev/null +++ b/tests/test_html_parser.py @@ -0,0 +1,173 @@ +from pathlib import Path + +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.kb.html_parser import parse_html_markdown +from heta.kb.parser import parse_document +from heta.kb.text import extract_title + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_parse_html_markdown_preserves_structure_and_inline_images(tmp_path: Path) -> None: + image = tmp_path / "arch.png" + image.write_bytes(b"png") + source = tmp_path / "attention.html" + archived = tmp_path / "raw" / "2026-05-15_attention.html" + archived.parent.mkdir() + source.write_text( + """ + + + + Attention Mechanism + + + + + +

Attention Mechanism

+

Source: Bahdanau and Vaswani.

+

Overview

+

The model focuses on relevant input tokens.

+

Types

+
TypeDescription
Self-AttentionTokens attend to tokens.
+

Architecture

+

The Transformer uses multi-head attention.

+ Transformer architecture + + +""", + encoding="utf-8", + ) + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + markdown = parse_html_markdown(source, archived) + + assert extract_title(markdown, "fallback") == "Web Page - Attention Mechanism" + assert "### Metadata" not in markdown + assert "Raw HTML" not in markdown + assert "Navigation noise" not in markdown + assert "### Attention Mechanism" in markdown + assert "#### Overview" in markdown + assert "| Type | Description |" in markdown + assert "![Transformer architecture](<../../raw/assets/2026-05-15_attention/img-001.png>)" in markdown + assert "Image note: Transformer architecture." in markdown + assert (tmp_path / "raw" / "assets" / "2026-05-15_attention" / "img-001.png").exists() + manifest = (tmp_path / "raw" / "assets" / "2026-05-15_attention" / "manifest.json").read_text(encoding="utf-8") + assert '"original_src": "arch.png"' in manifest + assert '"section": "Architecture"' in manifest + + +def test_parse_html_markdown_keeps_remote_images_as_urls(tmp_path: Path) -> None: + source = tmp_path / "remote.htm" + archived = tmp_path / "raw" / "2026-05-15_remote.htm" + archived.parent.mkdir() + source.write_text( + '

Remote Page

Intro.

Remote plot', + encoding="utf-8", + ) + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + markdown = parse_html_markdown(source, archived) + + assert "![Remote plot]()" in markdown + manifest = (tmp_path / "raw" / "assets" / "2026-05-15_remote" / "manifest.json").read_text(encoding="utf-8") + assert '"original_src": "https://example.com/plot.png"' in manifest + assert '"raw_path": null' in manifest + + +def test_parse_html_markdown_prefers_main_content_and_clean_summary(tmp_path: Path) -> None: + source = tmp_path / "wiki.html" + archived = tmp_path / "raw" / "2026-05-15_wiki.html" + archived.parent.mkdir() + source.write_text( + """ + +Knowledge graph + + +
+
Page semi-protected
+
This article has multiple issues.
+
+

+

For other uses, see Knowledge graph (disambiguation).

+

Knowledge graph is a graph-structured knowledge base used to represent entities, facts, and relationships for retrieval and reasoning systems.

+

History

+

The term has been used in several database and semantic web contexts.

+
+
+ + +""", + encoding="utf-8", + ) + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + markdown = parse_html_markdown(source, archived) + + assert "Jump to content" not in markdown + assert "Page semi-protected" not in markdown + assert "This article has multiple issues" not in markdown + assert "### Knowledge graph" in markdown + assert "## Summary\nKnowledge graph is a graph-structured knowledge base" in markdown + assert "#### History" in markdown + + +def test_parse_html_markdown_cleans_document_site_shell(tmp_path: Path) -> None: + source = tmp_path / "docs.html" + archived = tmp_path / "raw" / "2026-05-15_docs.html" + archived.parent.mkdir() + source.write_text( + """ + + + SQL Reference + + + +
LogoSmall. Fast. Reliable.
+
+
+ +

SQL Reference

+

SQL Reference explains statements, expressions, and data manipulation concepts for database users.

+

Statements

+

The reference lists supported statement syntax.

+
+ + +""", + encoding="utf-8", + ) + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + markdown = parse_html_markdown(source, archived) + + assert "Small. Fast. Reliable" not in markdown + assert "IE hack" not in markdown + assert "Previous topic" not in markdown + assert "### SQL Reference" in markdown + assert "#### Statements" in markdown + assert "## Summary\nSQL Reference explains statements" in markdown + + +def test_parse_document_accepts_html_branch(tmp_path: Path) -> None: + source = tmp_path / "page.html" + archived = tmp_path / "2026-05-15_page.html" + source.write_text("

HTML Page

Hello.

", encoding="utf-8") + archived.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") + + document = parse_document(source, archived, _config()) + + assert document.title == "Web Page - HTML Page" + assert document.metadata["extension"] == ".html" + assert "### HTML Page" in document.markdown_content diff --git a/tests/test_image_parser.py b/tests/test_image_parser.py new file mode 100644 index 0000000..c013499 --- /dev/null +++ b/tests/test_image_parser.py @@ -0,0 +1,95 @@ +from pathlib import Path + +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.kb.image_parser import build_image_markdown +from heta.kb.parser import parse_document +from heta.kb.text import extract_title + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def _custom_without_multimodal_config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig( + provider="custom", + api_key="sk-test", + chat_api_key="sk-chat", + chat_model="chat-model", + chat_base_url="http://chat.local/v1", + embedding_api_key="sk-embedding", + embedding_model="embedding-model", + embedding_base_url="http://embedding.local/v1", + ), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_build_image_markdown_uses_compact_retrieval_sections() -> None: + markdown = build_image_markdown( + title="Image - Architecture Diagram", + source_name="diagram.png", + image_path="../../raw/diagram.png", + summary="A system architecture diagram.", + visual_facts="Scene/type: diagram. Main subject: service pipeline.", + visible_text="API Gateway", + interpretation_keywords="Represents a backend data flow. keywords: API, pipeline.", + ) + + assert extract_title(markdown, "fallback") == "Image - Architecture Diagram" + assert "![diagram.png](<../../raw/diagram.png>)" in markdown + assert "### Visual Facts" in markdown + assert "### Visible Text" in markdown + assert "### Interpretation and Keywords" in markdown + assert "## Related Pages" in markdown + assert "## Source" in markdown + + +def test_parse_document_accepts_image_branch(monkeypatch, tmp_path: Path) -> None: + source = tmp_path / "diagram.png" + archived = tmp_path / "raw_diagram.png" + source.write_bytes(b"png") + archived.write_bytes(b"png") + + monkeypatch.setattr( + "heta.kb.parser.parse_image_markdown", + lambda source_path, archived_path, config: build_image_markdown( + title="Image - Diagram", + source_name=archived_path.name, + image_path="../../raw/raw_diagram.png", + summary="A diagram.", + visual_facts="A simple diagram.", + visible_text="None detected.", + interpretation_keywords="diagram, test", + ), + ) + + document = parse_document(source, archived, _config()) + + assert document.title == "Image - Diagram" + assert document.source_name == "raw_diagram.png" + assert document.metadata["extension"] == ".png" + assert "### Visual Facts" in document.markdown_content + + +def test_image_requires_multimodal_when_custom_skips_it(tmp_path: Path) -> None: + source = tmp_path / "diagram.png" + source.write_bytes(b"png") + + try: + parse_document(source, source, _custom_without_multimodal_config()) + except ValueError as exc: + assert "requires a multimodal model" in str(exc) + assert "heta init" in str(exc) + else: + raise AssertionError("image parsing should require multimodal config") diff --git a/tests/test_kb_insert.py b/tests/test_kb_insert.py index bd9a8a0..7231bc6 100644 --- a/tests/test_kb_insert.py +++ b/tests/test_kb_insert.py @@ -2,25 +2,37 @@ import pytest -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import ( + DynamicInsertConfig, + InsertPlanningConfig, + HetaConfig, + LLMConfig, + MinerUConfig, + VectorIndexConfig, +) from heta.kb.discovery import collect_insert_files -from heta.kb.models import FileChange -from heta.kb.insert import insert_paths +from heta.kb.models import FileChange, ParsedDocument +from heta.kb.insert import _ensure_code_raw_links, insert_paths from heta.kb.text import frontmatter_page, slugify, summarize -from heta.kb.wiki import normalize_wiki_pages +from heta.kb.wiki import normalize_wiki_pages, repair_broken_wiki_links -def _config(mineru: MinerUConfig | None = None) -> HetaConfig: +def _config(mineru: MinerUConfig | None = None, *, dynamic_insert: bool = True) -> HetaConfig: return HetaConfig( version=1, llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=mineru or MinerUConfig.disabled(), vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + dynamic_insert=DynamicInsertConfig(enable=dynamic_insert), ) -def _fake_agent(monkeypatch) -> None: +def _fake_agent(monkeypatch, calls: list[list[str]] | None = None) -> None: def run_merge_agent(*, task_id, documents, root_dir, config): + assert len(documents) == 1 + if calls is not None: + calls.append([document.source_name for document in documents]) pages = root_dir / "pages" pages.mkdir(parents=True, exist_ok=True) added = [] @@ -96,6 +108,156 @@ def test_insert_same_title_updates_existing_page(monkeypatch, tmp_path: Path) -> assert "## Imported Update" in page.read_text(encoding="utf-8") +def test_insert_multiple_files_runs_agent_sequentially(monkeypatch, tmp_path: Path) -> None: + calls: list[list[str]] = [] + progress = [] + _fake_agent(monkeypatch, calls) + first = tmp_path / "alpha.md" + second = tmp_path / "beta.md" + first.write_text("# Alpha\n\nFirst details.", encoding="utf-8") + second.write_text("# Beta\n\nSecond details.", encoding="utf-8") + + result = insert_paths( + [first, second], + _config(), + base_dir=tmp_path, + on_progress=progress.append, + ) + + wiki = tmp_path / "workspace" / "kb" / "wiki" + assert calls[0][0].endswith("_alpha.md") + assert calls[1][0].endswith("_beta.md") + assert (wiki / "pages" / "1-alpha.md").exists() + assert (wiki / "pages" / "2-beta.md").exists() + assert [change.path for change in result.added] == ["pages/1-alpha.md", "pages/2-beta.md"] + assert progress[0].percent == 1 + merge_percents = [event.percent for event in progress if event.phase == "merge"] + assert 50 in merge_percents + assert 99 in merge_percents + assert progress[-1].percent == 100 + assert progress[-1].phase == "done" + + +def test_insert_defaults_to_static_pages(monkeypatch, tmp_path: Path) -> None: + def fail_agent(**kwargs): + raise AssertionError("dynamic agent should not run in static insert mode") + + monkeypatch.setattr("heta.kb.insert.run_merge_agent", fail_agent) + monkeypatch.setattr( + "heta.kb.static_insert.generate_summary", + lambda *, document, config: f"Summary for {document.title}.", + ) + source = tmp_path / "manual.md" + source.write_text("# Main Heading\n\n## Sub Heading\n\nBody text.", encoding="utf-8") + + result = insert_paths([source], _config(dynamic_insert=False), base_dir=tmp_path) + + wiki = tmp_path / "workspace" / "kb" / "wiki" + page = wiki / "pages" / "1-main-heading.md" + text = page.read_text(encoding="utf-8") + assert result.added[0].path == "pages/1-main-heading.md" + assert "Summary for Main Heading." in text + assert "### Main Heading" in text + assert "#### Sub Heading" in text + assert "## Related Pages\n\n- None yet" in text + assert "- " + result.raw_files[0].name in text + assert "[[Main Heading]]" in (wiki / "index.md").read_text(encoding="utf-8") + assert "Created static page: Main Heading" in (wiki / "log.md").read_text(encoding="utf-8") + + +def test_insert_reports_vector_sync_error_without_rolling_back(monkeypatch, tmp_path: Path) -> None: + monkeypatch.setattr( + "heta.kb.static_insert.generate_summary", + lambda *, document, config: f"Summary for {document.title}.", + ) + monkeypatch.setattr( + "heta.kb.insert.sync_wiki_vector_index", + lambda **kwargs: (_ for _ in ()).throw(RuntimeError("embedding unavailable")), + ) + source = tmp_path / "manual.md" + source.write_text("# Main Heading\n\nBody text.", encoding="utf-8") + config = HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=True), + insert_planning=InsertPlanningConfig.enabled(), + dynamic_insert=DynamicInsertConfig.disabled(), + ) + + result = insert_paths([source], config, base_dir=tmp_path) + + page = tmp_path / "workspace" / "kb" / "wiki" / "pages" / "1-main-heading.md" + assert result.commit_id + assert page.exists() + assert result.vector_index_error == "embedding unavailable" + + +def test_insert_continues_when_agent_makes_no_wiki_changes(monkeypatch, tmp_path: Path) -> None: + calls: list[str] = [] + + def run_merge_agent(*, task_id, documents, root_dir, config): + document = documents[0] + calls.append(document.source_name) + if "beta" in document.source_name: + return {"added": [], "updated": [], "deleted": []} + + pages = root_dir / "pages" + pages.mkdir(parents=True, exist_ok=True) + page = pages / f"{slugify(document.title)}.md" + page.write_text( + frontmatter_page( + document.title, + document.source_name, + summarize(document.markdown_content), + document.markdown_content, + ), + encoding="utf-8", + ) + return {"added": [FileChange("added", document.title, f"pages/{page.name}")], "updated": [], "deleted": []} + + monkeypatch.setattr("heta.kb.insert.run_merge_agent", run_merge_agent) + first = tmp_path / "alpha.md" + second = tmp_path / "beta.md" + third = tmp_path / "gamma.md" + first.write_text("# Alpha\n\nFirst details.", encoding="utf-8") + second.write_text("# Beta\n\nSecond details.", encoding="utf-8") + third.write_text("# Gamma\n\nThird details.", encoding="utf-8") + + result = insert_paths([first, second, third], _config(), base_dir=tmp_path) + + wiki = tmp_path / "workspace" / "kb" / "wiki" + assert len(calls) == 3 + assert [change.path for change in result.added] == ["pages/1-alpha.md", "pages/2-gamma.md"] + assert result.skipped_documents == [calls[1]] + assert (wiki / "pages" / "1-alpha.md").exists() + assert not (wiki / "pages" / "2-beta.md").exists() + assert (wiki / "pages" / "2-gamma.md").exists() + assert "Skipped no-op merge" not in (wiki / "log.md").read_text(encoding="utf-8") + + +def test_ensure_code_raw_links_restores_agent_dropped_raw_link(tmp_path: Path) -> None: + wiki = tmp_path / "wiki" + page = wiki / "pages" / "1-code-demo.md" + page.parent.mkdir(parents=True) + page.write_text( + frontmatter_page("Code - demo.py", "2026-05-15_demo.py", "Summary.", "### File Overview\n- language: python"), + encoding="utf-8", + ) + document = ParsedDocument( + source_path=tmp_path / "demo.py", + archived_path=tmp_path / "raw" / "2026-05-15_demo.py", + title="Code - demo.py", + markdown_content="", + source_name="2026-05-15_demo.py", + metadata={"extension": ".py"}, + ) + + _ensure_code_raw_links(wiki, document, [FileChange("added", "Code - demo.py", "pages/1-code-demo.md")]) + + assert "[Raw source](<../../raw/2026-05-15_demo.py>)" in page.read_text(encoding="utf-8") + + def test_pdf_requires_mineru_when_disabled(tmp_path: Path) -> None: source = tmp_path / "paper.pdf" source.write_bytes(b"%PDF") @@ -104,6 +266,60 @@ def test_pdf_requires_mineru_when_disabled(tmp_path: Path) -> None: collect_insert_files([source], _config()) +def test_office_requires_mineru_when_disabled(tmp_path: Path) -> None: + source = tmp_path / "deck.pptx" + source.write_bytes(b"pptx") + + with pytest.raises(ValueError, match="requires MinerU"): + collect_insert_files([source], _config()) + + +def test_collect_insert_files_accepts_office_when_mineru_enabled(tmp_path: Path) -> None: + files = [] + for name in ["notes.doc", "notes.docx", "deck.ppt", "deck.pptx", "sheet.xls", "sheet.xlsx"]: + file = tmp_path / name + file.write_bytes(b"office") + files.append(file) + + collected = collect_insert_files( + [tmp_path], + _config(MinerUConfig(enable=True, provider="cloud", api_key="mineru-token", endpoint=None)), + ) + + assert collected == sorted(files) + + +def test_collect_insert_files_accepts_common_images(tmp_path: Path) -> None: + image = tmp_path / "diagram.png" + image.write_bytes(b"png") + + files = collect_insert_files([image], _config()) + + assert files == [image] + + +def test_collect_insert_files_accepts_audio_and_video(tmp_path: Path) -> None: + audio = tmp_path / "meeting.mp3" + video = tmp_path / "demo.mp4" + audio.write_bytes(b"mp3") + video.write_bytes(b"mp4") + + files = collect_insert_files([audio, video], _config()) + + assert files == [audio, video] + + +def test_collect_insert_files_accepts_code_and_html(tmp_path: Path) -> None: + code = tmp_path / "module.py" + html = tmp_path / "index.html" + code.write_text("def run():\n pass\n", encoding="utf-8") + html.write_text("
", encoding="utf-8") + + files = collect_insert_files([tmp_path], _config()) + + assert files == [html, code] + + def test_collect_directory_skips_workspace(tmp_path: Path) -> None: source = tmp_path / "a.md" workspace_file = tmp_path / "workspace" / "kb" / "wiki" / "pages" / "old.md" @@ -142,3 +358,30 @@ def test_normalize_wiki_pages_assigns_max_plus_one_without_reusing_deleted_ids(t assert "- [1] [[Old]] (pages/1-old.md) — Old summary." in index assert "- [3] [[Existing]] (pages/3-existing.md) — Existing summary." in index assert "- [4] [[New Topic]] (pages/4-new-topic.md) — New summary." in index + + +def test_repair_broken_wiki_links_downgrades_missing_targets(tmp_path: Path) -> None: + wiki = tmp_path / "wiki" + pages = wiki / "pages" + pages.mkdir(parents=True) + page = pages / "1-topic.md" + page.write_text( + frontmatter_page( + "Topic", + "source.md", + "Summary.", + "See [[Existing Topic]] and [[Missing Topic]].", + ), + encoding="utf-8", + ) + (pages / "2-existing-topic.md").write_text( + frontmatter_page("Existing Topic", "source.md", "Summary.", "Body."), + encoding="utf-8", + ) + + repair_broken_wiki_links(wiki) + + text = page.read_text(encoding="utf-8") + assert "[[Existing Topic]]" in text + assert "[[Missing Topic]]" not in text + assert "Missing Topic" in text diff --git a/tests/test_kb_invalidate.py b/tests/test_kb_invalidate.py new file mode 100644 index 0000000..cc1e4d1 --- /dev/null +++ b/tests/test_kb_invalidate.py @@ -0,0 +1,140 @@ +"""Tests for heta.mem.kb_invalidate.""" + +from __future__ import annotations + +import time +import uuid +from pathlib import Path + +import pytest + +from heta.mem.db import get_connection, init_db +from heta.mem.kb_invalidate import delete_all_insights, delete_insights_by_paths +from heta.mem.kb_store import insert_insight_embedding, insert_kb_insight +from heta.mem.meta_store import insert_meta +from heta.mem.models import KBInsight, MemoryMeta + + +@pytest.fixture() +def conn(tmp_path: Path): + db = tmp_path / "test_mem.sqlite3" + c = get_connection(db, with_vec=True) + init_db(c) + yield c + c.close() + + +def _now() -> int: + return int(time.time()) + + +def _insert_insight(conn, source_paths, insight_text: str = "fact") -> str: + """source_paths can be a single str (for backward-compat tests) or a list.""" + if isinstance(source_paths, str): + source_paths = [source_paths] + mid = str(uuid.uuid4()) + insert_meta(conn, MemoryMeta( + memory_id=mid, memory_type="kb_insight", session_id=None, + origin="kb_insight", created_at=_now(), last_access_at=_now(), + )) + insert_kb_insight(conn, KBInsight( + memory_id=mid, insight=insight_text, question="q", + source_paths=source_paths, wiki_id=None, heading_path=None, + created_at=_now(), + )) + # 1024-dim float embedding (matches EMBEDDING_DIM) + from heta.mem.client import EMBEDDING_DIM + insert_insight_embedding(conn, mid, [0.0] * EMBEDDING_DIM) + conn.commit() + return mid + + +def _count(conn, table: str) -> int: + return conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] + + +# ── tests ───────────────────────────────────────────────────────────────────── + + +def test_delete_by_paths_removes_matching_rows(conn): + _insert_insight(conn, "pages/1-foo.md") + _insert_insight(conn, "pages/1-foo.md") # 2 insights on same page + _insert_insight(conn, "pages/2-bar.md") + + deleted = delete_insights_by_paths(conn, ["pages/1-foo.md"]) + + assert deleted == 2 + assert _count(conn, "kb_insight") == 1 + assert _count(conn, "kb_insight_vec") == 1 + assert _count(conn, "memory_meta") == 1 + + +def test_delete_by_paths_leaves_other_pages_untouched(conn): + _insert_insight(conn, "pages/1-foo.md") + bar_id = _insert_insight(conn, "pages/2-bar.md") + + delete_insights_by_paths(conn, ["pages/1-foo.md"]) + + rows = conn.execute("SELECT memory_id FROM kb_insight").fetchall() + assert [r[0] for r in rows] == [bar_id] + + +def test_delete_by_paths_empty_input(conn): + _insert_insight(conn, "pages/1-foo.md") + assert delete_insights_by_paths(conn, []) == 0 + assert _count(conn, "kb_insight") == 1 + + +def test_delete_by_paths_no_match(conn): + _insert_insight(conn, "pages/1-foo.md") + assert delete_insights_by_paths(conn, ["pages/does-not-exist.md"]) == 0 + assert _count(conn, "kb_insight") == 1 + + +def test_delete_all_clears_everything(conn): + _insert_insight(conn, "pages/1-foo.md") + _insert_insight(conn, "pages/2-bar.md") + _insert_insight(conn, "pages/3-baz.md") + + deleted = delete_all_insights(conn) + + assert deleted == 3 + assert _count(conn, "kb_insight") == 0 + assert _count(conn, "kb_insight_vec") == 0 + assert _count(conn, "memory_meta") == 0 + + +def test_delete_all_on_empty_db_returns_zero(conn): + assert delete_all_insights(conn) == 0 + + +def test_delete_by_paths_invalidates_multi_source_insight(conn): + """An insight derived from multiple pages dies when ANY of its sources changes.""" + multi = _insert_insight(conn, ["pages/1-foo.md", "pages/2-bar.md"]) + solo = _insert_insight(conn, ["pages/3-baz.md"]) + + deleted = delete_insights_by_paths(conn, ["pages/2-bar.md"]) + + assert deleted == 1 + remaining = [r[0] for r in conn.execute("SELECT memory_id FROM kb_insight").fetchall()] + assert multi not in remaining + assert solo in remaining + # both rows in kb_insight_source for the multi insight should be gone + assert _count(conn, "kb_insight_source") == 1 + + +def test_delete_by_paths_preserves_other_memory_types(conn): + """Deleting kb_insight by path must not touch L1/L2/etc.""" + _insert_insight(conn, "pages/1-foo.md") + # an unrelated memory_meta row (e.g. L2) + other = str(uuid.uuid4()) + insert_meta(conn, MemoryMeta( + memory_id=other, memory_type="L2", session_id=None, + origin="extracted", created_at=_now(), last_access_at=_now(), + )) + + delete_insights_by_paths(conn, ["pages/1-foo.md"]) + + assert _count(conn, "memory_meta") == 1 + remaining = conn.execute("SELECT memory_id FROM memory_meta").fetchone() + assert remaining[0] == other diff --git a/tests/test_memory_ingest.py b/tests/test_memory_ingest.py new file mode 100644 index 0000000..f59d066 --- /dev/null +++ b/tests/test_memory_ingest.py @@ -0,0 +1,384 @@ +"""Tests for the heta remember ingestion pipeline (pipeline.py). + +All LLM and embedding calls are mocked so tests run offline and deterministically. +""" + +from __future__ import annotations + +import json +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from heta.config.schema import InsertPlanningConfig, HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.mem.client import EMBEDDING_DIM +from heta.mem.db import get_connection, init_db +from heta.mem.pipeline import remember + + +# ── constants ───────────────────────────────────────────────────────────────── + +FAKE_EMB = [0.01] * EMBEDDING_DIM # deterministic dummy embedding + +EPISODE_DICT = { + "who": ["user"], + "what": "参加了技术分享会", + "where_loc": "公司会议室", + "when_text": "上周", + "when_resolved": "2026-W19", + "when_precision": "week", + "why": None, + "summary": "用户上周在公司会议室参加了技术分享会", +} + +FACT_DICT = { + "subject": "用户", + "predicate": "居住在", + "object": "北京朝阳区", + "object_type": "literal", + "when_text": None, + "when_resolved": None, + "when_precision": None, +} + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture() +def config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +@pytest.fixture() +def tmp_db(tmp_path: Path) -> Path: + return tmp_path / "mem.sqlite3" + + +# ── patch helper ────────────────────────────────────────────────────────────── + +@contextmanager +def _patch_pipeline( + tmp_db: Path, + episodes: list[dict] | None = None, + facts: list[dict] | None = None, + conflicts: list[str] | None = None, # memory_ids to deprecate +): + """Patch all external I/O in pipeline.remember so tests run offline.""" + if episodes is None: + episodes = [] + if facts is None: + facts = [] + if conflicts is None: + conflicts = [] + + mock_client = MagicMock() + + def _open_conn(path, *, with_vec=False): + c = get_connection(tmp_db, with_vec=True) + init_db(c) + return c + + with ( + patch("heta.mem.pipeline.ensure_mem_dir"), + patch("heta.mem.pipeline.db_path", return_value=tmp_db), + patch("heta.mem.pipeline.get_connection", side_effect=_open_conn), + patch("heta.mem.pipeline.init_db"), + patch("heta.mem.pipeline.build_client", return_value=(mock_client, "mock-llm")), + patch("heta.mem.pipeline.build_embedding_client", return_value=(mock_client, "mock-emb")), + patch("heta.mem.pipeline.extract_episodes", return_value=episodes), + patch("heta.mem.pipeline.extract_facts", return_value=facts), + patch("heta.mem.pipeline.embed_text", return_value=FAKE_EMB), + patch("heta.mem.pipeline.detect_conflicts", return_value=(conflicts, FAKE_EMB)), + ): + yield + + +def _open(tmp_db: Path): + """Open a fresh read connection to the tmp DB after pipeline closes it.""" + return get_connection(tmp_db, with_vec=True) + + +# ── basic ingestion ─────────────────────────────────────────────────────────── + +def test_remember_creates_session_and_l0_turn(config, tmp_db) -> None: + with _patch_pipeline(tmp_db): + result = remember("hello world", config) + + conn = _open(tmp_db) + sessions = conn.execute("SELECT * FROM session WHERE session_id = ?", (result.session_id,)).fetchall() + turns = conn.execute("SELECT * FROM l0_turn WHERE session_id = ?", (result.session_id,)).fetchall() + conn.close() + + assert len(sessions) == 1 + assert sessions[0]["session_id"] == result.session_id + assert sessions[0]["ended_at"] is not None # session was closed + assert len(turns) == 1 + assert turns[0]["text_content"] == "hello world" + assert turns[0]["role"] == "user" + + +def test_remember_l0_turn_indexed_in_fts(config, tmp_db) -> None: + with _patch_pipeline(tmp_db): + remember("爬山打羽毛球", config) + + conn = _open(tmp_db) + rows = conn.execute( + "SELECT text_content FROM l0_turn_fts WHERE text_content MATCH ?", + ('"爬山打羽毛球"',), + ).fetchall() + conn.close() + + assert len(rows) == 1 + + +def test_remember_empty_extraction_succeeds(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, episodes=[], facts=[]): + result = remember("no events here", config) + + assert result.l1_count == 0 + assert result.l2_count == 0 + + +def test_remember_returns_correct_counts(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, episodes=[EPISODE_DICT], facts=[FACT_DICT, FACT_DICT]): + result = remember("some text", config) + + assert result.l1_count == 1 + assert result.l2_count == 2 + assert result.session_id != "" + assert result.elapsed_s >= 0 + + +# ── L1 episode storage ──────────────────────────────────────────────────────── + +def test_remember_l1_episode_stored_correctly(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, episodes=[EPISODE_DICT]): + result = remember("some text", config) + + conn = _open(tmp_db) + row = conn.execute( + "SELECT * FROM l1_episodic e JOIN memory_meta m ON e.memory_id = m.memory_id " + "WHERE m.session_id = ?", + (result.session_id,), + ).fetchone() + conn.close() + + assert row is not None + assert row["what"] == EPISODE_DICT["what"] + assert row["where_loc"] == EPISODE_DICT["where_loc"] + assert row["when_text"] == EPISODE_DICT["when_text"] + assert row["when_resolved"] == EPISODE_DICT["when_resolved"] + assert row["when_precision"] == EPISODE_DICT["when_precision"] + assert row["summary"] == EPISODE_DICT["summary"] + assert json.loads(row["who"]) == EPISODE_DICT["who"] + assert row["memory_type"] == "L1" + assert row["status"] == "active" + + +def test_remember_l1_episode_embedding_inserted(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, episodes=[EPISODE_DICT]): + result = remember("some text", config) + + conn = _open(tmp_db) + meta = conn.execute( + "SELECT memory_id FROM memory_meta WHERE session_id = ? AND memory_type = 'L1'", + (result.session_id,), + ).fetchone() + vec_row = conn.execute( + "SELECT memory_id FROM l1_episode_vec WHERE memory_id = ?", + (meta["memory_id"],), + ).fetchone() + conn.close() + + assert vec_row is not None + + +def test_remember_l1_who_defaults_to_user_when_missing(config, tmp_db) -> None: + ep = {**EPISODE_DICT} + del ep["who"] + with _patch_pipeline(tmp_db, episodes=[ep]): + remember("some text", config) + + conn = _open(tmp_db) + row = conn.execute("SELECT who FROM l1_episodic").fetchone() + conn.close() + + assert json.loads(row["who"]) == ["user"] + + +# ── L2 fact storage ─────────────────────────────────────────────────────────── + +def test_remember_l2_fact_stored_correctly(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, facts=[FACT_DICT]): + result = remember("some text", config) + + conn = _open(tmp_db) + row = conn.execute( + "SELECT * FROM l2_semantic s JOIN memory_meta m ON s.memory_id = m.memory_id " + "WHERE m.session_id = ?", + (result.session_id,), + ).fetchone() + conn.close() + + assert row is not None + assert row["subject"] == FACT_DICT["subject"] + assert row["predicate"] == FACT_DICT["predicate"] + assert row["object"] == FACT_DICT["object"] + assert row["object_type"] == "literal" + assert row["t_valid_end"] is None # active, not yet expired + assert row["status"] == "active" + assert row["memory_type"] == "L2" + + +def test_remember_l2_fact_embedding_inserted(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, facts=[FACT_DICT]): + result = remember("some text", config) + + conn = _open(tmp_db) + meta = conn.execute( + "SELECT memory_id FROM memory_meta WHERE session_id = ? AND memory_type = 'L2'", + (result.session_id,), + ).fetchone() + vec_row = conn.execute( + "SELECT memory_id FROM l2_fact_vec WHERE memory_id = ?", + (meta["memory_id"],), + ).fetchone() + conn.close() + + assert vec_row is not None + + +def test_remember_l2_object_type_list_coerced_to_string(config, tmp_db) -> None: + """LLM occasionally returns object_type as a list; pipeline should normalise it.""" + fact = {**FACT_DICT, "object_type": ["literal", "extra"]} + with _patch_pipeline(tmp_db, facts=[fact]): + remember("some text", config) + + conn = _open(tmp_db) + row = conn.execute("SELECT object_type FROM l2_semantic").fetchone() + conn.close() + + assert isinstance(row["object_type"], str) + assert row["object_type"] == "literal" + + +# ── conflict resolution ─────────────────────────────────────────────────────── + +def test_remember_conflict_deprecates_old_fact(config, tmp_db) -> None: + """When detect_conflicts returns an old memory_id, that fact is expired + deprecated.""" + # session 1: insert a fact directly so we have an old_id to conflict + from heta.mem.l2_store import insert_fact, insert_fact_embedding + from heta.mem.meta_store import insert_meta + from heta.mem.models import L2Semantic, MemoryMeta + from heta.mem.session_store import create_session, close_session + from heta.mem.models import Session + import time, uuid + + old_id = str(uuid.uuid4()) + now = int(time.time()) + + setup_conn = get_connection(tmp_db, with_vec=True) + init_db(setup_conn) + sid0 = str(uuid.uuid4()) + create_session(setup_conn, Session(session_id=sid0, started_at=now)) + close_session(setup_conn, sid0, now) + insert_meta(setup_conn, MemoryMeta( + memory_id=old_id, memory_type="L2", session_id=sid0, + origin="extracted", created_at=now, last_access_at=now, + )) + insert_fact(setup_conn, L2Semantic( + memory_id=old_id, subject="用户", predicate="居住在", object="北京朝阳区", + object_type="literal", fact_text="用户 居住在 北京朝阳区", + t_valid_start=now, + )) + insert_fact_embedding(setup_conn, old_id, FAKE_EMB) + setup_conn.close() + + # session 2: new fact conflicts with old_id + new_fact = {**FACT_DICT, "object": "北京海淀区"} + with _patch_pipeline(tmp_db, facts=[new_fact], conflicts=[old_id]): + remember("搬家了", config) + + conn = _open(tmp_db) + old_row = conn.execute( + "SELECT s.t_valid_end, m.status, m.deprecated_by " + "FROM l2_semantic s JOIN memory_meta m ON s.memory_id = m.memory_id " + "WHERE s.memory_id = ?", + (old_id,), + ).fetchone() + active_rows = conn.execute( + "SELECT object FROM l2_semantic WHERE t_valid_end IS NULL" + ).fetchall() + conn.close() + + assert old_row["t_valid_end"] is not None # expired + assert old_row["status"] == "deprecated" + assert old_row["deprecated_by"] is not None # FK to new fact + assert len(active_rows) == 1 + assert active_rows[0]["object"] == "北京海淀区" + + +def test_remember_no_conflict_keeps_both_facts(config, tmp_db) -> None: + """When detect_conflicts returns [], both old and new facts remain active.""" + fact_a = {**FACT_DICT, "predicate": "喜欢", "object": "爬山"} + fact_b = {**FACT_DICT, "predicate": "喜欢", "object": "羽毛球"} + + with _patch_pipeline(tmp_db, facts=[fact_a], conflicts=[]): + remember("喜欢爬山", config) + with _patch_pipeline(tmp_db, facts=[fact_b], conflicts=[]): + remember("喜欢羽毛球", config) + + conn = _open(tmp_db) + active = conn.execute( + "SELECT object FROM l2_semantic WHERE t_valid_end IS NULL ORDER BY object" + ).fetchall() + conn.close() + + objects = {r["object"] for r in active} + assert objects == {"爬山", "羽毛球"} + + +def test_remember_detect_conflicts_receives_session_id(config, tmp_db) -> None: + """detect_conflicts must be called with the current session_id so same-session + facts are excluded from conflict candidates.""" + captured_kwargs: dict = {} + + def _fake_detect_conflicts(**kwargs): + captured_kwargs.update(kwargs) + return ([], FAKE_EMB) + + with ( + _patch_pipeline(tmp_db, facts=[FACT_DICT]), + patch("heta.mem.pipeline.detect_conflicts", side_effect=_fake_detect_conflicts), + ): + result = remember("some text", config) + + assert "session_id" in captured_kwargs + assert captured_kwargs["session_id"] == result.session_id + + +# ── multiple sessions ───────────────────────────────────────────────────────── + +def test_remember_multiple_sessions_accumulate(config, tmp_db) -> None: + with _patch_pipeline(tmp_db, facts=[FACT_DICT]): + r1 = remember("first", config) + with _patch_pipeline(tmp_db, facts=[FACT_DICT]): + r2 = remember("second", config) + + assert r1.session_id != r2.session_id + + conn = _open(tmp_db) + n_sessions = conn.execute("SELECT COUNT(*) FROM session").fetchone()[0] + n_turns = conn.execute("SELECT COUNT(*) FROM l0_turn").fetchone()[0] + conn.close() + + assert n_sessions == 2 + assert n_turns == 2 diff --git a/tests/test_mineru_parser.py b/tests/test_mineru_parser.py new file mode 100644 index 0000000..697e762 --- /dev/null +++ b/tests/test_mineru_parser.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import zipfile +from io import BytesIO +from pathlib import Path + +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.kb.parser import parse_document + + +class _Response: + def __init__( + self, + *, + status_code: int = 200, + payload: dict | None = None, + content: bytes = b"", + text: str = "", + headers: dict[str, str] | None = None, + ) -> None: + self.status_code = status_code + self._payload = payload or {} + self.content = content + self.text = text + self.headers = headers or {} + + def json(self) -> dict: + return self._payload + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig(enable=True, provider="cloud", api_key="mineru-token", endpoint=None), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_parse_document_accepts_office_via_mineru_cloud(monkeypatch, tmp_path: Path) -> None: + source = tmp_path / "slides.pptx" + archived = tmp_path / "2026-05-15_slides.pptx" + source.write_bytes(b"pptx") + archived.write_bytes(b"pptx") + zip_bytes = _mineru_zip("# Slides\n\nParsed by MinerU.") + requests_seen: list[tuple[str, str]] = [] + + def post(url, **kwargs): + requests_seen.append(("POST", url)) + assert kwargs["headers"]["Authorization"] == "Bearer mineru-token" + assert kwargs["json"]["files"] == [{"name": "2026-05-15_slides.pptx", "data_id": "2026-05-15_slides"}] + assert kwargs["json"]["model_version"] == "vlm" + return _Response(payload={"code": 0, "data": {"batch_id": "batch-1", "file_urls": ["https://upload"]}}) + + def put(url, **kwargs): + requests_seen.append(("PUT", url)) + assert url == "https://upload" + return _Response(status_code=200) + + def get(url, **kwargs): + requests_seen.append(("GET", url)) + if url.endswith("/batch-1"): + assert kwargs["headers"]["Authorization"] == "Bearer mineru-token" + return _Response( + payload={ + "code": 0, + "data": { + "extract_result": [ + { + "file_name": "2026-05-15_slides.pptx", + "state": "done", + "full_zip_url": "https://result.zip", + } + ] + }, + } + ) + assert url == "https://result.zip" + return _Response(content=zip_bytes) + + monkeypatch.setattr("heta.kb.parser.requests.post", post) + monkeypatch.setattr("heta.kb.parser.requests.put", put) + monkeypatch.setattr("heta.kb.parser.requests.get", get) + + document = parse_document(source, archived, _config()) + + assert document.title == "Slides" + assert document.metadata["extension"] == ".pptx" + assert document.markdown_content == "# Slides\n\nParsed by MinerU." + assert [method for method, _ in requests_seen] == ["POST", "PUT", "GET", "GET"] + + +def test_parse_document_uses_local_mineru_zip_artifacts(monkeypatch, tmp_path: Path) -> None: + source = tmp_path / "paper.pdf" + archived = tmp_path / "2026-05-15_paper.pdf" + source.write_bytes(b"%PDF") + archived.write_bytes(b"%PDF") + zip_bytes = _mineru_zip( + "# Paper\n\n![](images/figure.jpg)\n", + content_list=[ + { + "type": "image", + "img_path": "images/figure.jpg", + "bbox": [1, 2, 3, 4], + "page_idx": 2, + } + ], + images={"images/figure.jpg": b"jpg"}, + ) + config = HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig(enable=True, provider="local", api_key=None, endpoint="http://127.0.0.1:8000"), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + def post(url, **kwargs): + assert url == "http://127.0.0.1:8000/file_parse" + assert "files" in kwargs["files"] + assert kwargs["data"]["response_format_zip"] == "true" + assert kwargs["data"]["return_content_list"] == "true" + return _Response(content=zip_bytes, headers={"content-type": "application/zip"}) + + monkeypatch.setattr("heta.kb.parser.requests.post", post) + + document = parse_document( + source, + archived, + config, + original_name="paper.pdf", + page_offset=20, + base_dir=tmp_path, + ) + + assert "../../raw/parsed/2026-05-15_paper/images/figure.jpg" in document.markdown_content + assert "Source: paper.pdf, page 23, bbox [1, 2, 3, 4]" in document.markdown_content + parsed = tmp_path / "workspace" / "kb" / "raw" / "parsed" / "2026-05-15_paper" + assert (parsed / "full.md").exists() + assert (parsed / "content_list.json").exists() + assert (parsed / "images" / "figure.jpg").read_bytes() == b"jpg" + + +def _mineru_zip( + markdown: str, + *, + content_list: list[dict] | None = None, + images: dict[str, bytes] | None = None, +) -> bytes: + buffer = BytesIO() + with zipfile.ZipFile(buffer, "w") as archive: + archive.writestr("full.md", markdown) + if content_list is not None: + import json + + archive.writestr("demo_content_list.json", json.dumps(content_list)) + for path, data in (images or {}).items(): + archive.writestr(path, data) + return buffer.getvalue() diff --git a/tests/test_pdf_plan.py b/tests/test_pdf_plan.py index 1dee6dd..0a76351 100644 --- a/tests/test_pdf_plan.py +++ b/tests/test_pdf_plan.py @@ -3,7 +3,7 @@ from pypdf import PdfWriter -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import InsertPlanningConfig, HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig from heta.kb import paths from heta.kb.pdf_plan import PDF_PAGE_THRESHOLD, SplitUnit, estimate_pdf_pages, plan_insert_files @@ -17,11 +17,11 @@ def test_plan_insert_files_splits_large_pdf_and_keeps_original(tmp_path: Path) - assert len(plans) == 1 assert plans[0].enabled is True assert plans[0].page_count == PDF_PAGE_THRESHOLD + 1 - assert plans[0].parts == 3 - assert len(prepared) == 3 + assert plans[0].parts == 5 + assert len(prepared) == 5 assert (paths.raw_dir(tmp_path) / "originals").exists() assert all(item.archived_path.exists() for item in prepared) - assert [estimate_pdf_pages(item.archived_path) for item in prepared] == [40, 40, 1] + assert [estimate_pdf_pages(item.archived_path) for item in prepared] == [20, 20, 20, 20, 1] assert all(item.original_path is not None for item in prepared) assert all(item.metadata_path is not None and item.metadata_path.exists() for item in prepared) @@ -60,12 +60,12 @@ def test_plan_insert_files_uses_agent_split_plan(monkeypatch, tmp_path: Path) -> assert plans[0].document_type == "textbook" assert plans[0].split_strategy == "chapter" - assert [item.page_start for item in prepared] == [1, 31, 71] - assert [item.page_end for item in prepared] == [30, 70, 81] + assert [item.page_start for item in prepared] == [1, 21, 31, 51, 71] + assert [item.page_end for item in prepared] == [20, 30, 50, 70, 81] metadata = json.loads(prepared[0].metadata_path.read_text(encoding="utf-8")) assert metadata["original"].endswith("large.pdf") assert metadata["start_page"] == 1 - assert metadata["end_page"] == 30 + assert metadata["end_page"] == 20 assert metadata["split_strategy"] == "chapter" @@ -84,7 +84,7 @@ def test_plan_insert_files_falls_back_when_agent_plan_is_invalid(monkeypatch, tm prepared, plans = plan_insert_files([source], config=_config(), base_dir=tmp_path) assert plans[0].split_strategy == "fixed_page_window" - assert [estimate_pdf_pages(item.archived_path) for item in prepared] == [40, 40, 1] + assert [estimate_pdf_pages(item.archived_path) for item in prepared] == [20, 20, 20, 20, 1] def test_plan_insert_files_fills_pages_missing_from_agent_plan(monkeypatch, tmp_path: Path) -> None: @@ -102,7 +102,7 @@ def test_plan_insert_files_fills_pages_missing_from_agent_plan(monkeypatch, tmp_ prepared, plans = plan_insert_files([source], config=_config(), base_dir=tmp_path) assert plans[0].split_strategy == "section" - assert [(item.page_start, item.page_end) for item in prepared] == [(1, 40), (41, 60), (61, 81)] + assert [(item.page_start, item.page_end) for item in prepared] == [(1, 20), (21, 40), (41, 60), (61, 80), (81, 81)] def _write_pdf(path: Path, pages: int) -> None: @@ -119,4 +119,5 @@ def _config() -> HetaConfig: llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=MinerUConfig.disabled(), vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), ) diff --git a/tests/test_provider_clients.py b/tests/test_provider_clients.py new file mode 100644 index 0000000..2ebe73b --- /dev/null +++ b/tests/test_provider_clients.py @@ -0,0 +1,83 @@ +from heta.config.schema import HetaConfig, InsertPlanningConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.providers import clients + + +def _config() -> HetaConfig: + return HetaConfig( + version=1, + llm=LLMConfig( + provider="custom", + api_key="legacy-key", + chat_api_key="chat-key", + chat_model="chat-model", + chat_base_url="http://chat.local/v1", + multimodal_api_key="mm-key", + multimodal_model="mm-model", + multimodal_base_url="http://mm.local/v1", + embedding_api_key="embedding-key", + embedding_model="embedding-model", + embedding_base_url="http://embedding.local/v1", + ), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + +def test_provider_clients_use_capability_specific_api_keys(monkeypatch) -> None: + seen: list[dict] = [] + + class FakeOpenAI: + def __init__(self, **kwargs): + seen.append(kwargs) + + monkeypatch.setattr(clients, "OpenAI", FakeOpenAI) + config = _config() + + chat = clients.build_chat_client(config) + multimodal = clients.build_multimodal_client(config) + embedding = clients.build_embedding_client(config) + + assert chat.model == "chat-model" + assert multimodal.model == "mm-model" + assert embedding.model == "embedding-model" + assert seen == [ + {"api_key": "chat-key", "timeout": 60, "base_url": "http://chat.local/v1"}, + {"api_key": "mm-key", "timeout": 300, "base_url": "http://mm.local/v1"}, + {"api_key": "embedding-key", "timeout": 120, "base_url": "http://embedding.local/v1"}, + ] + + +def test_extra_body_prefers_explicit_config() -> None: + config = _config() + explicit = HetaConfig( + version=config.version, + llm=LLMConfig( + provider="custom", + api_key="legacy-key", + chat_api_key="chat-key", + chat_model="chat-model", + chat_base_url="http://chat.local/v1", + chat_extra_body={"enable_thinking": False}, + embedding_api_key="embedding-key", + embedding_model="embedding-model", + embedding_base_url="http://embedding.local/v1", + ), + mineru=config.mineru, + vector_index=config.vector_index, + insert_planning=config.insert_planning, + ) + + assert clients.extra_body(explicit) == {"enable_thinking": False} + + +def test_extra_body_keeps_qwen_default() -> None: + config = HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig(enable=False), + insert_planning=InsertPlanningConfig.enabled(), + ) + + assert clients.extra_body(config) == {"enable_thinking": False} diff --git a/tests/test_query.py b/tests/test_query.py index a76eb16..51396d9 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,12 +2,13 @@ import pytest -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import InsertPlanningConfig, HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig from heta.kb import paths from heta.kb.text import frontmatter_page +from heta.query.agent import _parse_final_answer, _vector_match_map from heta.query.models import QueryResult, QuerySource, VectorMatch from heta.query.pipeline import run_wiki_query -from heta.query.tools import format_vector_matches, read_page, source_from_page_path +from heta.query.tools import format_vector_matches, read_page, read_raw, source_from_page_path def _config(vector_enabled: bool = False) -> HetaConfig: @@ -16,6 +17,7 @@ def _config(vector_enabled: bool = False) -> HetaConfig: llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=MinerUConfig.disabled(), vector_index=VectorIndexConfig(enable=vector_enabled), + insert_planning=InsertPlanningConfig.enabled(), ) @@ -51,6 +53,16 @@ def test_read_page_is_limited_to_pages(tmp_path: Path) -> None: assert read_page("index.md", tmp_path).startswith("error:") +def test_read_raw_is_limited_to_raw_directory(tmp_path: Path) -> None: + raw = paths.raw_dir(tmp_path) + raw.mkdir(parents=True) + (raw / "module.py").write_text("def run():\n return True\n", encoding="utf-8") + + assert "def run" in read_raw("raw/module.py", tmp_path) + assert "def run" in read_raw("../../raw/module.py", tmp_path) + assert read_raw("../heta.yaml", tmp_path).startswith("error:") + + def test_source_from_page_path_reads_frontmatter_and_wiki_id(tmp_path: Path) -> None: page = paths.pages_dir(tmp_path) / "12-hetagen.md" page.parent.mkdir(parents=True) @@ -61,6 +73,79 @@ def test_source_from_page_path_reads_frontmatter_and_wiki_id(tmp_path: Path) -> assert source == QuerySource(12, "HetaGen", "pages/12-hetagen.md", "Content") +def test_query_sources_include_validated_vector_chunks_only(tmp_path: Path) -> None: + pages = paths.pages_dir(tmp_path) + pages.mkdir(parents=True) + (pages / "8-image.md").write_text(frontmatter_page("Image", "image.png", "Image summary.", "Body."), encoding="utf-8") + (pages / "10-audio.md").write_text( + frontmatter_page("Audio", "audio.mp3", "Audio summary.", "Transcript."), + encoding="utf-8", + ) + vector_matches = _vector_match_map( + [ + VectorMatch(8, "8-image.md", "pages/8-image.md", "8:abc", "Content > Visible Text", "image text", 0.8), + VectorMatch(10, "10-audio.md", "pages/10-audio.md", "10:def", "Content > Transcript", "hello", 0.9), + ] + ) + + final = _parse_final_answer( + text=( + '{"answer": "The audio says hello.", "used_sources": [' + '{"path": "pages/10-audio.md", "heading_path": "Content > Transcript"},' + '{"path": "pages/8-image.md", "heading_path": "Content > Missing"}' + "]}" + ), + read_paths=set(), + vector_matches=vector_matches, + base_dir=tmp_path, + ) + + assert final.answer == "The audio says hello." + assert final.sources == [QuerySource(10, "Audio", "pages/10-audio.md", "Content > Transcript")] + + +def test_query_sources_accept_read_pages_without_vector_heading(tmp_path: Path) -> None: + pages = paths.pages_dir(tmp_path) + pages.mkdir(parents=True) + (pages / "10-audio.md").write_text( + frontmatter_page("Audio", "audio.mp3", "Audio summary.", "Transcript."), + encoding="utf-8", + ) + + final = _parse_final_answer( + text='{"answer": "From the full page.", "used_sources": [{"path": "pages/10-audio.md"}]}', + read_paths={"pages/10-audio.md"}, + vector_matches={}, + base_dir=tmp_path, + ) + + assert final.sources == [QuerySource(10, "Audio", "pages/10-audio.md")] + + +def test_query_sources_reject_raw_used_sources(tmp_path: Path) -> None: + final = _parse_final_answer( + text='{"answer": "Raw helped.", "used_sources": [{"path": "raw/module.py"}]}', + read_paths=set(), + vector_matches={}, + base_dir=tmp_path, + ) + + assert final.sources == [] + + +def test_parse_final_answer_marks_non_json_response_invalid(tmp_path: Path) -> None: + final = _parse_final_answer( + text="This is not JSON.", + read_paths=set(), + vector_matches={}, + base_dir=tmp_path, + ) + + assert final.answer == "This is not JSON." + assert final.sources == [] + assert not final.valid_json + + def test_format_vector_matches_includes_chunk_identity() -> None: text = format_vector_matches( [ diff --git a/tests/test_status.py b/tests/test_status.py index 3d231dd..e342824 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -1,7 +1,7 @@ from pathlib import Path from heta.cli.status import build_status_summary -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import InsertPlanningConfig, HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig def test_status_summary_counts_kb_and_wiki_pages(tmp_path: Path) -> None: @@ -19,12 +19,15 @@ def test_status_summary_counts_kb_and_wiki_pages(tmp_path: Path) -> None: llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=MinerUConfig(enable=True, provider="local", api_key=None, endpoint="http://127.0.0.1:8000"), vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), ) summary = build_status_summary(config, tmp_path) assert summary.llm_provider == "qwen" assert summary.mineru == "local (http://127.0.0.1:8000)" + assert summary.insert_planning == "enabled" + assert summary.dynamic_insert == "disabled" assert summary.kb_files == 2 assert summary.wiki_pages == 1 assert summary.heta_space == tmp_path @@ -37,6 +40,8 @@ def test_status_summary_handles_missing_config_and_workspace(tmp_path: Path) -> assert summary.llm_provider == "not configured" assert summary.mineru == "not configured" + assert summary.insert_planning == "not configured" + assert summary.dynamic_insert == "not configured" assert summary.kb_files == 0 assert summary.wiki_pages == 0 assert summary.heta_used_bytes == 0 diff --git a/tests/test_validators.py b/tests/test_validators.py index a62d58c..507bfae 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -18,6 +18,14 @@ def test_validate_llm_non_200_fails(get: Mock) -> None: assert validate_llm("qwen", "bad-key") is False +@patch("heta.providers.llm.requests.get") +def test_validate_custom_llm_uses_base_url_models(get: Mock) -> None: + get.return_value.status_code = 200 + + assert validate_llm("custom", "sk-test", "http://llm.local/v1") is True + assert get.call_args.args[0] == "http://llm.local/v1/models" + + @patch("heta.providers.mineru.requests.post") def test_validate_mineru_cloud_success(post: Mock) -> None: post.return_value.status_code = 200 diff --git a/tests/test_vector_index.py b/tests/test_vector_index.py index e945d1e..975d2aa 100644 --- a/tests/test_vector_index.py +++ b/tests/test_vector_index.py @@ -5,9 +5,19 @@ import sqlite_vec -from heta.config.schema import HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig +from heta.config.schema import InsertPlanningConfig, HetaConfig, LLMConfig, MinerUConfig, VectorIndexConfig from heta.kb import paths -from heta.kb.vector_index import _ensure_schema, _insert_chunk, chunk_wiki_page, search_wiki_vector_index +from heta.kb.models import FileChange +from heta.kb.vector_index import ( + _ensure_schema, + _fts_terms, + _insert_chunk, + chunk_wiki_page, + search_wiki_fts_index, + search_wiki_hybrid_index, + search_wiki_vector_index, + sync_wiki_vector_index, +) def test_chunk_wiki_page_uses_heading_path_and_page_context(tmp_path: Path) -> None: @@ -77,6 +87,7 @@ def test_search_wiki_vector_index_returns_ranked_chunks(monkeypatch, tmp_path: P llm=LLMConfig(provider="qwen", api_key="sk-test"), mineru=MinerUConfig.disabled(), vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), ) results = search_wiki_vector_index(query="table synthesis", config=config, top_k=3, base_dir=tmp_path) @@ -88,8 +99,183 @@ def test_search_wiki_vector_index_returns_ranked_chunks(monkeypatch, tmp_path: P assert results[0].score == 1.0 +def test_search_wiki_fts_index_matches_mixed_chinese_and_codes(tmp_path: Path) -> None: + db_path = paths.vector_db_path(tmp_path) + db_path.parent.mkdir(parents=True) + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + _ensure_schema(conn) + target = chunk_wiki_page( + _write_page( + tmp_path, + "1-flow-control.md", + "ZXDOC-A-10-20-30 温度控制策略", + "结构图和部件说明。", + "温度控制策略结构图包含文档代码 ZXDOC-A-10-20-30-ALPHA。", + ) + )[0] + other = chunk_wiki_page( + _write_page( + tmp_path, + "2-power.md", + "ZXDOC-B-40-50-60 存储模块", + "存储说明。", + "主存储模块说明。", + ) + )[0] + _insert_chunk(conn, target, [1.0] + [0.0] * 1023) + _insert_chunk(conn, other, [0.0, 1.0] + [0.0] * 1022) + conn.commit() + conn.close() + + results = search_wiki_fts_index(query="温度控制策略 10-20-30", top_k=3, base_dir=tmp_path) + + assert results + assert results[0].wiki_id == 1 + assert results[0].retrieval == "fts" + + +def test_search_wiki_hybrid_index_fuses_vector_and_fts(monkeypatch, tmp_path: Path) -> None: + db_path = paths.vector_db_path(tmp_path) + db_path.parent.mkdir(parents=True) + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + _ensure_schema(conn) + vector_chunk = chunk_wiki_page( + _write_page( + tmp_path, + "1-semantic.md", + "HetaGen", + "Structured content generation.", + "HetaGen supports table synthesis.", + ) + )[0] + keyword_chunk = chunk_wiki_page( + _write_page( + tmp_path, + "2-code.md", + "ZXDOC-A-10-20-30", + "结构图。", + "文档代码 ZXDOC-A-10-20-30-ALPHA。", + ) + )[0] + _insert_chunk(conn, vector_chunk, [1.0] + [0.0] * 1023) + _insert_chunk(conn, keyword_chunk, [0.0, 1.0] + [0.0] * 1022) + conn.commit() + conn.close() + + monkeypatch.setattr("heta.kb.vector_index._embed_texts", lambda texts, config: [[1.0] + [0.0] * 1023]) + config = HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + ) + + results = search_wiki_hybrid_index( + query="table synthesis ZXDOC-A-10-20-30", + config=config, + top_k=2, + base_dir=tmp_path, + ) + + assert {result.wiki_id for result in results} == {1, 2} + + +def test_fts_terms_normalizes_mixed_width_punctuation() -> None: + assert _fts_terms("zxdoc-a-10-20 温度控制,结构图") == [ + "ZXDOC-A-10-20", + "温度控制", + "结构图", + ] + + +def test_search_wiki_fts_index_backfills_existing_chunks(tmp_path: Path) -> None: + db_path = paths.vector_db_path(tmp_path) + db_path.parent.mkdir(parents=True) + conn = sqlite3.connect(db_path) + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + _ensure_schema(conn) + conn.execute( + """ + INSERT INTO wiki_chunks ( + wiki_id, page_name, chunk_id, heading_path, content, content_hash + ) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + 7, + "7-existing.md", + "7:existing", + "Content", + "Page: ZXDOC-A-10-20 结构图\n\n已有 chunk 应该自动回填到 FTS。", + "existing", + ), + ) + conn.commit() + conn.close() + + results = search_wiki_fts_index(query="10-20 结构图", top_k=1, base_dir=tmp_path) + + assert len(results) == 1 + assert results[0].wiki_id == 7 + conn = sqlite3.connect(db_path) + try: + assert conn.execute("SELECT count(*) FROM wiki_chunk_fts").fetchone()[0] == 1 + finally: + conn.close() + + +def test_sync_wiki_vector_index_deduplicates_repeated_page_changes(monkeypatch, tmp_path: Path) -> None: + _write_page( + tmp_path, + "1-hetagen.md", + "HetaGen", + "Structured content generation.", + """ +### Capabilities + +HetaGen supports table synthesis. + +### Query + +HetaGen can answer structured questions. +""", + ) + monkeypatch.setattr("heta.kb.vector_index._embed_texts", lambda texts, config: [[1.0] + [0.0] * 1023 for _ in texts]) + config = HetaConfig( + version=1, + llm=LLMConfig(provider="qwen", api_key="sk-test"), + mineru=MinerUConfig.disabled(), + vector_index=VectorIndexConfig.enabled(), + insert_planning=InsertPlanningConfig.enabled(), + ) + + sync_wiki_vector_index( + changes=[ + FileChange("added", "HetaGen", "pages/1-hetagen.md"), + FileChange("updated", "HetaGen", "pages/1-hetagen.md"), + ], + config=config, + base_dir=tmp_path, + ) + + conn = sqlite3.connect(paths.vector_db_path(tmp_path)) + try: + assert conn.execute("SELECT count(*) FROM wiki_chunks").fetchone()[0] == 2 + finally: + conn.close() + + def _write_page(tmp_path: Path, name: str, title: str, summary: str, content: str) -> Path: page = paths.pages_dir(tmp_path) / name - page.parent.mkdir(parents=True) + page.parent.mkdir(parents=True, exist_ok=True) page.write_text(frontmatter_page(title, "source.md", summary, content), encoding="utf-8") return page