diff --git a/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md b/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md new file mode 100644 index 00000000000..3be9c207bf0 --- /dev/null +++ b/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md @@ -0,0 +1,265 @@ +# Native Hybrid Search with Score Fusion — Design Spec + +**Branch:** `sp/hybrid-search` (off `sp/bm25`) +**Date:** 2026-06-03 +**Status:** Approved (Option C) + +## 1. Problem + +Dgraph can rank text (BM25, on `sp/bm25`) and find nearest vectors (HNSW `similar_to`), +but cannot **combine** them. A consumer wanting hybrid retrieval (the standard RAG +pattern: dense vector + sparse/keyword, fused into one ranked list) must issue +separate DQL queries and fuse the results in application code. + +Concretely, the reference consumer (modelhub, a GraphRAG product running on Dgraph) +issues **three** separate DQL queries per search — (1) `similar_to` vector, (2) +BM25-ish term, (3) entity-anchored term — and fuses them in Python with Reciprocal +Rank Fusion (`score = Σ 1/(k+rank)`, k=60), sometimes with a linear combination +(`α·vec + (1-α)·text`, scores max-normalized first). + +This spec brings that fusion **natively into a single DQL query**. + +## 2. Approach (Option C) + +Two pieces, agreed after consulting GPT-5 and Gemini (both recommended the N-way +primitive as the core; Option C adds a thin convenience wrapper): + +1. **`fuse()`** — an N-way fusion combinator over already-scored DQL **value + variables**. The general primitive. Handles any number of channels and any + scored signal, not just bm25/vector. +2. **`hybrid()`** — convenience sugar for the common 2-channel bm25+vector case, + expanded at query-rewrite time into two channel blocks + a `fuse()`. No new + executor path. + +A prerequisite makes `fuse()` useful with vectors: + +3. **Surface `similar_to` similarity scores as a value variable** (today it returns + only uids). Required so vector results can be a fusion channel. + +### Why this fits Dgraph's architecture + +- DQL already has **value variables** carrying uid→score maps (`var(func: bm25(...))` + binds per-doc scores into a `varValue{Uids, Vals}` — see `query/query.go` + `populateUidValVar`, the `bm25` case). +- `ProcessQuery` already has a **variable-dependency scheduler** (`canExecute` over + `QueryVars.Needs`/`.Defines`): a block runs only once the variables it needs are + populated. A `fuse()` block that *needs* its channel vars and *defines* the fused + var is scheduled automatically after its inputs — no new sequencing logic. +- Fusion is a **coordinator-side** operation over resolved variables (like `math()` + and `uid(a,b)`), which is correct under predicate sharding: each channel may be + computed on a different Raft group; their value variables are already merged to + the coordinator before fusion runs. + +## 3. DQL Surface + +### 3.1 `fuse()` — N-way primitive + +``` +v as var(func: bm25(text, "quick brown fox")) +e as var(func: similar_to(emb, 100, $queryVec)) + +f as var(func: fuse(v, e, method: "rrf", k: 60)) + +{ + result(func: uid(f), orderdesc: val(f), first: 10) { + uid + val(f) + } +} +``` + +- **Positional args**: two or more value-variable references (the channels). These + become the block's `NeedsVar`. +- **Named options** (parsed like `similar_to`'s `ef:`/`distance_threshold:`): + - `method`: `"rrf"` (default) or `"linear"`. + - `k`: RRF rank constant (default `60`). Ignored for linear. + - `weights`: comma-separated per-channel weights for linear, aligned positionally + with the channel args (e.g. `weights: "0.3,0.7"`). Default: `1.0` each. Ignored + for RRF. + - `normalize`: linear-only score normalization, `"max"` (default) or `"none"`. + - `topk`: optional cap on the number of fused results emitted (default: unbounded; + downstream `first/offset` still applies). Bounds coordinator work. +- **Output**: binds `f` as a value variable — both the **union** uid set + (`uid(f)`) and the uid→fusedScore map (`val(f)`, `orderdesc: val(f)`), exactly + like the bm25 ranker var. + +### 3.2 `hybrid()` — 2-channel sugar + +``` +f as var(func: hybrid(text, "quick brown fox", emb, $queryVec, 100, method: "rrf", k: 60)) + +{ result(func: uid(f), orderdesc: val(f), first: 10) { uid val(f) } } +``` + +Positional: `textPredicate, "queryText", vectorPredicate, $queryVec, topk`. Named +options as for `fuse()`. Rewritten at query-build time to: + +``` +__hybrid_0_ as var(func: bm25(text, "quick brown fox")) +__hybrid_1_ as var(func: similar_to(emb, 100, $queryVec)) +f as var(func: fuse(__hybrid_0_, __hybrid_1_, method: "rrf", k: 60)) +``` + +The synthetic var names are unique per hybrid block. After rewrite there is no +distinct `hybrid` execution path — it is purely a parser/builder transformation. + +## 4. Fusion Semantics (the core, `query/fuse.go`) + +Pure function over channels, fully unit-testable, no I/O: + +```go +type fuseChannel struct { + scores map[uint64]float64 // uid -> raw channel score (higher = better) + weight float64 // linear weight (default 1.0) +} + +func fuseRRF(channels []fuseChannel, k float64, topk int) []scoredUid +func fuseLinear(channels []fuseChannel, normalize string, topk int) []scoredUid +``` + +### Outer-join / union semantics (the key correctness point) + +Both models independently flagged this. Fusion is a **set union** of candidate uids +across channels, NOT an intersection. For a uid present in some channels but not +others: + +- **RRF**: only ranked channels contribute. `fused[u] = Σ_{c: u∈c} 1/(k + rank_c(u))`. + A channel that doesn't contain `u` adds nothing (equivalent to rank = ∞). +- **Linear**: missing channel contributes `0`. `fused[u] = Σ_c weight_c · norm_c(score_c(u))`, + where `norm_c(missing) = 0`. + +Standard DQL `math()` across var blocks aligns/intersects on uid and would drop or +NaN-poison such uids — `fuse()` must not. + +### RRF ranks + +Each channel is independently sorted by raw score **descending**, tie-broken by uid +**ascending**, to assign 1-based ranks. (Deterministic; matches the bm25/HNSW +`sorted()` tie-break already in the codebase.) + +### Linear normalization + +`"max"` (default): divide each channel's scores by that channel's max +(`|max|`, guard against 0 → channel contributes 0). Brings heterogeneous score +scales (BM25 ∈ [0,∞), cosine ∈ [-1,1]) onto a comparable range before weighting. +`"none"`: use raw scores (caller asserts they're comparable). + +### Output ordering + +Fused results sorted by fused score descending, tie-broken by uid ascending. If +`topk > 0`, truncate to `topk`. The output is emitted to the value variable as the +union uid set (ascending, per the query pipeline contract) + positionally-aligned +fused scores — identical shape to the bm25 var binding. + +## 5. Vector score surfacing (prerequisite) + +`similar_to` currently emits only `UidMatrix` (worker/task.go ~L417). Change: + +- Extend `index.SearchPathResult` with `Distances []float64` parallel to `Neighbors`, + populated from the search-layer heap (which already holds metric-domain distances, + `n.value`) in `addFinalNeighbors`. +- In the `similarToFn` worker path, when the function is bound to a value variable, + use the path-returning search to obtain distances and emit a `ValueMatrix` of + **similarity** scores (higher = better): + - **cosine**: cosine similarity in [-1,1], as-is. + - **dotproduct**: dot product, as-is. + - **euclidean**: `1/(1 + d)` where `d` is the (non-squared) Euclidean distance, + mapping to (0,1] so higher = better and linear normalization is well-behaved. +- Bind in `populateUidValVar` exactly like the `bm25` case (reuse/generalize that + branch). When `similar_to` is **not** bound to a var, behavior is unchanged + (uids only) — zero overhead for existing queries. + +Orientation contract: **all rankers surface higher-is-better scores**, so RRF and +linear fusion compose without per-channel sign handling. + +## 6. Execution flow + +1. Parse: `fuse`/`hybrid` recognized as functions in `dql/parser.go`; channel args → + `NeedsVar`, named options stored on the function. `hybrid` expands to channel + blocks + `fuse`. The block `Defines` its output var, `Needs` its channel vars. +2. Schedule: `ProcessQuery`'s `canExecute` runs channel blocks first; the `fuse` + block becomes runnable once all channel vars are in `req.Vars`. +3. Compute: the `fuse` block is **coordinator-only** — like the existing + `similar_to`-empty/`IsEmpty` cases, it is **not** dispatched to a worker. Fusion + is computed in `populateVarMap` (new `fuse` case) reading channel `varValue`s + from `doneVars`, producing the fused `varValue`. +4. Consume: downstream `uid(f)` / `orderdesc: val(f)` / `first`/`offset` work via the + existing value-variable machinery. + +## 7. Validation & errors + +- ≥1 channel var required (2+ to be meaningful; 1 is allowed and passes through). +- Unknown `method` → error. `k <= 0` → error. `weights` count must match channel + count when provided → error. Malformed `weights` floats → error. +- Empty channels (no matches) are valid and contribute nothing. +- A channel var that is a **uid variable without scores** (no `Vals`): for RRF, rank + by the var's intrinsic order if any, else treat as unscored → error with a clear + message ("fuse channel %q has no scores; use a ranker like bm25/similar_to"). MVP: + require scored channels. + +## 8. Testing + +**Unit (`query/fuse_test.go`)** — pure fusion core: +- RRF: known ranks → known `Σ 1/(k+rank)`; default k=60; custom k. +- Linear: max-normalize, weights, `normalize:none`. +- Union semantics: uid in 1 of N channels; disjoint channels; full overlap. +- Ties (equal scores → uid-ascending rank); empty channels; single channel. +- `topk` truncation; determinism. + +**Worker (`worker/`)** — vector score surfacing: +- `similar_to` bound to a var emits similarity scores with correct orientation per + metric; unbound `similar_to` unchanged. + +**Integration (systest/DQL)** — end-to-end: +- `fuse()` RRF over bm25 + vector; ordering matches hand-computed RRF. +- `fuse()` linear with weights. +- 3-channel fusion (modelhub's shape). +- Pagination (`first`/`offset`) on the consuming block. +- Missing-uid union correctness. +- `hybrid()` produces results identical to the equivalent explicit `fuse()`. +- Error paths (unknown method, bad weights, unscored channel). + +## 9. Out of scope (future) + +- Filter pushdown into HNSW (pre-filtered ANN) — separate gap. +- Worker-side fusion pushdown / `topk` propagation into channel funcs. +- Additional methods (weighted-RRF, ISR, distribution-based fusion). +- Reranking primitives. + +## 9a. Adversarial review outcomes (GPT-5 + Gemini) + +Both models deep-reviewed the diff. Findings triaged and resolved: + +- **Behavior preservation for `similar_to` (High, both):** the worker no longer + routes plain (no-option) vector queries through the options path. New + `SearchScored` / `SearchWithUidScored` mirror `Search` / `SearchWithUid` exactly + and just also return scores, so existing queries' neighbor selection is unchanged; + the `*Options*` scored variants are used only when ef/distance-threshold is given. +- **NaN/Inf safety (both):** `scoresFromVar` drops non-finite scores and + `channelMaxAbs` ignores them, so a pathological score can never break the sort + comparator's strict-weak-ordering or poison a linear sum. +- **"Ascending Uids destroys ranking" (both, flagged Critical):** not a bug — this + follows the same value-variable contract as bm25 (Uids is the unordered set for + `uid(var)`; ranked order is recovered via `orderdesc: val(var)`; `topk` selection + happens before the ascending sort). Documented in `computeFuse`. +- **Undefined channel var (GPT-5 Critical "stall"):** not a stall — `checkDependency` + rejects it at parse time ("variables used but not defined"). Regression test added. +- **Synthetic var collision (both, Low):** `__hybrid` is now a reserved prefix; + user vars using it are rejected with a clear message. +- **`@filter(similar_to(...))` + ValueMatrix (both, Med):** follows the proven bm25 + precedent (bm25 emits ValueMatrix and is used in filters); to be confirmed by the + vector integration suite in CI. +- **Coordinator GC churn in `channelRanks` (Gemini, perf):** acknowledged; acceptable + at expected channel sizes, noted as future optimization (buffer pooling) if needed. + +## 10. Files touched + +| Area | File(s) | +|---|---| +| Fusion core (new) | `query/fuse.go`, `query/fuse_test.go` | +| Parser | `dql/parser.go` (recognize `fuse`/`hybrid`, args+opts), `dql/state.go` if needed | +| hybrid rewrite | `query/query.go` (ToSubGraph/build) or `dql` transform | +| Var binding | `query/query.go` `populateUidValVar` (fuse case; generalize bm25 case) | +| Scheduler skip-worker | `query/query.go` `ProcessQuery` (fuse → no worker dispatch) | +| Vector scores | `tok/index/search_path.go` (`Distances`), `tok/hnsw/persistent_hnsw.go` (populate), `worker/task.go` (`similar_to` ValueMatrix) | +| Docs | DQL docs for `fuse`/`hybrid` | diff --git a/dql/fuse_parser_test.go b/dql/fuse_parser_test.go new file mode 100644 index 00000000000..a1e223144ca --- /dev/null +++ b/dql/fuse_parser_test.go @@ -0,0 +1,244 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package dql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// findBlock returns the query block whose result variable is varName. +func findVarBlock(t *testing.T, res Result, varName string) *GraphQuery { + for _, q := range res.Query { + if q.Var == varName { + return q + } + } + t.Fatalf("no block defines var %q", varName) + return nil +} + +func TestParseFuse_ChannelsAndOptions(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "quick brown fox")) + e as var(func: similar_to(emb, 100, "[0.1, 0.2]")) + f as var(func: fuse(v, e, method: "rrf", k: 60)) + result(func: uid(f), orderdesc: val(f)) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + + fb := findVarBlock(t, res, "f") + require.NotNil(t, fb.Func) + require.Equal(t, "fuse", fb.Func.Name) + + // Channels are captured as value-variable NeedsVar in order. + require.Len(t, fb.Func.NeedsVar, 2) + require.Equal(t, "v", fb.Func.NeedsVar[0].Name) + require.Equal(t, ValueVar, fb.Func.NeedsVar[0].Typ) + require.Equal(t, "e", fb.Func.NeedsVar[1].Name) + + // Named options are captured as [key, value] arg pairs. + args := map[string]string{} + for i := 0; i+1 < len(fb.Func.Args); i += 2 { + args[fb.Func.Args[i].Value] = fb.Func.Args[i+1].Value + } + require.Equal(t, "rrf", args["method"]) + require.Equal(t, "60", args["k"]) +} + +func TestParseFuse_LinearWeights(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + e as var(func: bm25(title, "fox")) + f as var(func: fuse(v, e, method: "linear", weights: "0.3,0.7", normalize: "max")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + fb := findVarBlock(t, res, "f") + require.Len(t, fb.Func.NeedsVar, 2) + args := map[string]string{} + for i := 0; i+1 < len(fb.Func.Args); i += 2 { + args[fb.Func.Args[i].Value] = fb.Func.Args[i+1].Value + } + require.Equal(t, "linear", args["method"]) + require.Equal(t, "0.3,0.7", args["weights"]) + require.Equal(t, "max", args["normalize"]) +} + +func TestParseFuse_ThreeChannels(t *testing.T) { + query := ` + { + a as var(func: bm25(text, "fox")) + b as var(func: bm25(title, "fox")) + c as var(func: bm25(body, "fox")) + f as var(func: fuse(a, b, c)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + fb := findVarBlock(t, res, "f") + require.Len(t, fb.Func.NeedsVar, 3) +} + +func TestParseFuse_UnknownOption(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, bogus: "x")) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unknown option") +} + +func TestParseFuse_DuplicateOption(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, k: 10, k: 20)) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Duplicate key") +} + +func TestParseFuse_NoChannels(t *testing.T) { + query := ` + { + f as var(func: fuse(method: "rrf")) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "at least one value variable") +} + +func TestParseHybrid_ExpandsToThreeBlocks(t *testing.T) { + query := ` + { + f as var(func: hybrid(description, "quick brown fox", emb, "[0.1, 0.2]", topk: 50, method: "rrf", k: 60)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + + // The hybrid block is replaced by bm25 + similar_to + fuse (plus the result block). + var bm25Block, simBlock, fuseBlock *GraphQuery + for _, q := range res.Query { + if q.Func == nil { + continue + } + switch q.Func.Name { + case "bm25": + bm25Block = q + case "similar_to": + simBlock = q + case "fuse": + fuseBlock = q + case "hybrid": + t.Fatal("hybrid block should have been rewritten away") + } + } + require.NotNil(t, bm25Block, "bm25 channel block must exist") + require.NotNil(t, simBlock, "similar_to channel block must exist") + require.NotNil(t, fuseBlock, "fuse block must exist") + + // bm25 channel: predicate + query text. + require.Equal(t, "description", bm25Block.Func.Attr) + require.Equal(t, "quick brown fox", bm25Block.Func.Args[0].Value) + + // similar_to channel: predicate + topk + vector. + require.Equal(t, "emb", simBlock.Func.Attr) + require.Equal(t, "50", simBlock.Func.Args[0].Value) + + // fuse block keeps the original variable name and the fuse options. + require.Equal(t, "f", fuseBlock.Var) + require.Len(t, fuseBlock.Func.NeedsVar, 2) + args := map[string]string{} + for i := 0; i+1 < len(fuseBlock.Func.Args); i += 2 { + args[fuseBlock.Func.Args[i].Value] = fuseBlock.Func.Args[i+1].Value + } + require.Equal(t, "rrf", args["method"]) + require.Equal(t, "60", args["k"]) + // topk is consumed by similar_to, not forwarded to fuse. + require.NotContains(t, args, "topk") +} + +func TestParseHybrid_BoundsBM25Channel(t *testing.T) { + // The generated bm25 channel must be bounded to topk so a broad text query does + // not score the whole corpus before fusion. + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]", topk: 25)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + for _, q := range res.Query { + if q.Func != nil && q.Func.Name == "bm25" { + require.Equal(t, "25", q.Args["first"], "bm25 channel should be capped at topk") + } + } +} + +func TestParseHybrid_MalformedOptions(t *testing.T) { + // A trailing option key without a value must be rejected, not silently dropped. + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]", method)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) +} + +func TestParseHybrid_DefaultTopK(t *testing.T) { + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + for _, q := range res.Query { + if q.Func != nil && q.Func.Name == "similar_to" { + require.Equal(t, "100", q.Func.Args[0].Value, "default topk should be 100") + } + } +} + +func TestParseFuse_UndefinedChannelVarErrors(t *testing.T) { + // A fuse channel referencing a variable that no block defines must be rejected + // at parse time (not silently stall the scheduler). + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, ghost, method: "rrf")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "not defined") +} + +func TestParseHybrid_RequiresVar(t *testing.T) { + query := ` + { + result(func: hybrid(description, "fox", emb, "[0.1]")) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "must be assigned to a variable") +} diff --git a/dql/hybrid.go b/dql/hybrid.go new file mode 100644 index 00000000000..15db5b742c7 --- /dev/null +++ b/dql/hybrid.go @@ -0,0 +1,173 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package dql + +import ( + "fmt" + "strings" +) + +// hybrid() is convenience sugar for the common two-channel hybrid-search case: +// BM25 text relevance fused with vector similarity. It is rewritten, before +// variable collection and dependency checking, into the equivalent explicit form: +// +// x as var(func: hybrid(textPred, "query", vecPred, $vec, topk: 100, method: "rrf", k: 60)) +// +// becomes +// +// __hybrid0_bm25 as var(func: bm25(textPred, "query"), first: 100) +// __hybrid0_vec as var(func: similar_to(vecPred, 100, $vec)) +// x as var(func: fuse(__hybrid0_bm25, __hybrid0_vec, method: "rrf", k: 60)) +// +// There is therefore no distinct hybrid execution path: it desugars entirely to +// the fuse() primitive. Positional args are textPred, "query", vecPred and the +// query vector ($var or literal); named options are topk (vector neighbor count, +// default 100) plus the fuse options method/k/weights/normalize. +const ( + hybridTopKOption = "topk" + hybridDefaultTopK = "100" + // hybridVarPrefix namespaces the synthetic channel variables a hybrid() block + // expands into. It is reserved: user variables may not start with it. + hybridVarPrefix = "__hybrid" +) + +// rewriteHybridBlocks expands every hybrid() query block in res into its three +// constituent blocks. It runs before fragment expansion, variable substitution, +// and dependency checking so the generated blocks participate normally. +func rewriteHybridBlocks(res *Result) error { + hasHybrid := false + for _, qu := range res.Query { + if qu != nil && qu.Func != nil && qu.Func.Name == hybridFunc { + hasHybrid = true + break + } + } + if !hasHybrid { + return nil + } + + // Guard against the (extremely unlikely) case of a user variable colliding with + // the synthetic channel names we generate, which would otherwise produce a + // confusing "defined multiple times" error the user can't act on. Variables can + // be defined in nested blocks too, so check the whole query tree, not just roots. + for _, qu := range res.Query { + if v, ok := findReservedHybridVar(qu); ok { + return fmt.Errorf("variable %q uses the reserved prefix %q (used internally by hybrid)", + v, hybridVarPrefix) + } + } + + expanded := make([]*GraphQuery, 0, len(res.Query)+2) + hybridIdx := 0 + for _, qu := range res.Query { + if qu == nil || qu.Func == nil || qu.Func.Name != hybridFunc { + expanded = append(expanded, qu) + continue + } + blocks, err := expandHybridBlock(qu, hybridIdx) + if err != nil { + return err + } + hybridIdx++ + expanded = append(expanded, blocks...) + } + res.Query = expanded + return nil +} + +// findReservedHybridVar walks a query block and its children for any variable using +// the reserved hybrid prefix, returning the first one found. +func findReservedHybridVar(qu *GraphQuery) (string, bool) { + if qu == nil { + return "", false + } + if strings.HasPrefix(qu.Var, hybridVarPrefix) { + return qu.Var, true + } + for _, ch := range qu.Children { + if v, ok := findReservedHybridVar(ch); ok { + return v, true + } + } + return "", false +} + +// expandHybridBlock turns a single hybrid() block into [bm25, similar_to, fuse]. +func expandHybridBlock(qu *GraphQuery, idx int) ([]*GraphQuery, error) { + if qu.Var == "" { + return nil, fmt.Errorf("hybrid must be assigned to a variable, e.g. " + + "`x as var(func: hybrid(textPred, \"query\", vecPred, $vec))`") + } + fn := qu.Func + textPred := fn.Attr + if textPred == "" { + return nil, fmt.Errorf("hybrid: missing text predicate (first argument)") + } + + // hybrid has exactly three positional args in Args (the text predicate is the + // function Attr): queryText, vecPred and the query vector. Any further args are + // key/value option pairs appended by the parser. + const numPositional = 3 + if len(fn.Args) < numPositional { + return nil, fmt.Errorf("hybrid requires textPred, \"query text\", vecPred and a "+ + "query vector; got %d positional arguments", len(fn.Args)+1) + } + queryText := fn.Args[0].Value + vecPred := fn.Args[1].Value + vecArg := fn.Args[2] + + // Options follow the positionals as key/value pairs; an odd remainder means a + // malformed option list rather than something to silently drop. + if (len(fn.Args)-numPositional)%2 != 0 { + return nil, fmt.Errorf("hybrid: malformed options (expected key:value pairs)") + } + + // Parse options: topk feeds similar_to's neighbor count; the rest feed fuse. + topk := hybridDefaultTopK + var fuseArgs []Arg + for i := numPositional; i+1 < len(fn.Args); i += 2 { + key := strings.ToLower(fn.Args[i].Value) + val := fn.Args[i+1] + if key == hybridTopKOption { + topk = val.Value + continue + } + fuseArgs = append(fuseArgs, Arg{Value: key}, val) + } + + chanBM25 := fmt.Sprintf("%s%d_bm25", hybridVarPrefix, idx) + chanVec := fmt.Sprintf("%s%d_vec", hybridVarPrefix, idx) + + // Bound the bm25 channel to the same topk candidate budget as the vector channel + // so a broad text query does not score and materialize the entire corpus before + // fusion. bm25 honors `first` with WAND top-k early termination. + bm25Block := &GraphQuery{ + Alias: "var", + Var: chanBM25, + Func: &Function{Name: "bm25", Attr: textPred, Args: []Arg{{Value: queryText}}}, + Args: map[string]string{"first": topk}, + } + simBlock := &GraphQuery{ + Alias: "var", + Var: chanVec, + Func: &Function{Name: similarToFn, Attr: vecPred, Args: []Arg{{Value: topk}, vecArg}}, + Args: map[string]string{}, + } + + channels := []VarContext{ + {Name: chanBM25, Typ: ValueVar}, + {Name: chanVec, Typ: ValueVar}, + } + fuseBlock := &GraphQuery{ + Alias: "var", + Var: qu.Var, + Func: &Function{Name: fuseFunc, NeedsVar: channels, Args: fuseArgs}, + NeedsVar: channels, + Args: map[string]string{}, + } + + return []*GraphQuery{bm25Block, simBlock, fuseBlock}, nil +} diff --git a/dql/parser.go b/dql/parser.go index 0dd6e1db7ac..930ba2d58d5 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -29,8 +29,15 @@ const ( countFunc = "count" uidInFunc = "uid_in" similarToFn = "similar_to" + fuseFunc = "fuse" + hybridFunc = "hybrid" ) +// fuseOptionKeys is the set of named options accepted by the fuse() function. +var fuseOptionKeys = map[string]struct{}{ + "method": {}, "k": {}, "weights": {}, "normalize": {}, "topk": {}, +} + var ( errExpandType = "expand is only compatible with type filters" ) @@ -699,6 +706,12 @@ func ParseWithNeedVars(r Request, needVars []string) (res Result, rerr error) { } } + // Expand any hybrid() sugar blocks into explicit bm25 + similar_to channel + // blocks plus a fuse() block before variable collection and dependency checks. + if err := rewriteHybridBlocks(&res); err != nil { + return res, err + } + if len(res.Query) != 0 { res.QueryVars = make([]*Vars, 0, len(res.Query)) for i := range res.Query { @@ -1701,7 +1714,8 @@ func validFuncName(name string) bool { switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25", + fuseFunc, hybridFunc: return true } return false @@ -1749,6 +1763,10 @@ L: if function.Name == similarToFn { similarToOptSeen = make(map[string]struct{}) } + var fuseOptSeen map[string]struct{} + if function.Name == fuseFunc || function.Name == hybridFunc { + fuseOptSeen = make(map[string]struct{}) + } if _, ok := tryParseItemType(it, itemLeftRound); !ok { return nil, it.Errorf("Expected ( after func name [%s]", function.Name) } @@ -1980,6 +1998,46 @@ L: // Disallow extra positional args after (k, vec). Options must be named. return nil, itemInFunc.Errorf("Expected named parameter in similar_to options (e.g. ef: 64)") } + + // fuse(v1, v2, ..., method: "rrf", k: 60) collects bare value-variable + // names as channels (handled by the fuseFunc case in the NeedsVar switch + // below) and key:value pairs as named options. An itemName followed by a + // colon is an option; otherwise it falls through to bare-name handling. + // hybrid() shares the same named options (its positional args are handled + // by the generic bare-name path). + if itemInFunc.Typ == itemName && + (function.Name == fuseFunc || function.Name == hybridFunc) { + next, ok := it.PeekOne() + if ok && next.Typ == itemColon { + key := strings.ToLower(collectName(it, itemInFunc.Val)) + if _, valid := fuseOptionKeys[key]; !valid { + return nil, itemInFunc.Errorf("Unknown option %q in fuse", key) + } + if _, exists := fuseOptSeen[key]; exists { + return nil, itemInFunc.Errorf("Duplicate key %q in fuse options", key) + } + fuseOptSeen[key] = struct{}{} + if ok := trySkipItemTyp(it, itemColon); !ok { + return nil, it.Errorf("Expected colon(:) after %s", key) + } + if !it.Next() { + return nil, it.Errorf("Expected value for %s", key) + } + valItem := it.Item() + if valItem.Typ != itemName { + return nil, valItem.Errorf("Expected value for %s", key) + } + v := strings.Trim(collectName(it, valItem.Val), " \t") + uq, err := unquoteIfQuoted(v) + if err != nil { + return nil, err + } + function.Args = append(function.Args, Arg{Value: key}, Arg{Value: uq}) + expectArg = false + continue + } + } + if itemInFunc.Typ != itemName { return nil, itemInFunc.Errorf("Expected arg after func [%s], but got item %v", function.Name, itemInFunc) @@ -2029,7 +2087,7 @@ L: // Unlike other functions, uid function has no attribute, everything is args. switch { case len(function.Attr) == 0 && function.Name != uidFunc && - function.Name != typFunc: + function.Name != typFunc && function.Name != fuseFunc: if strings.ContainsRune(itemInFunc.Val, '"') { return nil, itemInFunc.Errorf("Attribute in function"+ @@ -2047,7 +2105,7 @@ L: } function.Lang = val expectLang = false - case function.Name != uidFunc: + case function.Name != uidFunc && function.Name != fuseFunc: // For UID function. we set g.UID function.Args = append(function.Args, Arg{Value: val}) } @@ -2058,6 +2116,12 @@ L: expectArg = false switch function.Name { + case fuseFunc: + // fuse(v1, v2, ...) takes scored value variables as channels. + function.NeedsVar = append(function.NeedsVar, VarContext{ + Name: val, + Typ: ValueVar, + }) case valueFunc: // E.g. @filter(gt(val(a), 10)) function.NeedsVar = append(function.NeedsVar, VarContext{ @@ -2099,10 +2163,15 @@ L: } } - if function.Name != uidFunc && function.Name != typFunc && len(function.Attr) == 0 { + if function.Name != uidFunc && function.Name != typFunc && function.Name != fuseFunc && + len(function.Attr) == 0 { return nil, it.Errorf("Got empty attr for function: [%s]", function.Name) } + if function.Name == fuseFunc && len(function.NeedsVar) == 0 { + return nil, it.Errorf("fuse function requires at least one value variable channel") + } + if function.Name == typFunc && len(function.Args) != 1 { return nil, it.Errorf("type function only supports one argument. Got: %v", function.Args) } diff --git a/posting/bm25.go b/posting/bm25.go new file mode 100644 index 00000000000..9dcedc91aff --- /dev/null +++ b/posting/bm25.go @@ -0,0 +1,229 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "context" + "encoding/binary" + + ostats "go.opencensus.io/stats" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/tok" + "github.com/dgraph-io/dgraph/v25/x" +) + +// BM25Posting is a single materialized entry of a BM25 term posting list: the +// document UID together with the term frequency and document length decoded from +// the posting value. +type BM25Posting struct { + Uid uint64 + TF uint32 + DocLen uint32 +} + +// ReadBM25TermPostings materializes the postings of a BM25 term's standard index +// list at readTs into a UID-ascending slice, decoding (tf, docLen) from each +// posting value. getList reads a posting list for a key (e.g. LocalCache.Get), +// keeping the value encoding encapsulated in this package. +func ReadBM25TermPostings(getList func(key []byte) (*List, error), attr, encodedTerm string, + readTs uint64) ([]BM25Posting, error) { + key := x.BM25IndexKey(attr, encodedTerm) + pl, err := getList(key) + if err != nil { + return nil, err + } + var out []BM25Posting + err = pl.Iterate(readTs, 0, func(p *pb.Posting) error { + tf, docLen, ok := decodeBM25Value(p.Value) + if !ok { + // Corrupt/truncated posting value: skip rather than inject a bogus + // zero-frequency match that would silently distort scoring. + return nil + } + out = append(out, BM25Posting{Uid: p.Uid, TF: tf, DocLen: docLen}) + return nil + }) + if err != nil { + return nil, err + } + return out, nil +} + +// numBM25StatsBuckets is the number of buckets the BM25 corpus statistics (document +// count and total term count) are sharded across, keyed by uid%numBM25StatsBuckets. +// Sharding spreads the read-modify-write contention of stats maintenance across +// independent posting lists so that concurrent mutations on different documents +// rarely conflict, while same-bucket updates still conflict (and retry) — avoiding +// lost updates. A single hot stats key would serialize all writes to the predicate. +const numBM25StatsBuckets = 32 + +// encodeBM25Value packs a posting's term frequency and document length into the +// posting Value as two unsigned varints. Storing the document length alongside the +// term frequency makes scoring read (tf, docLen) in a single posting access — no +// separate document-length list (which would be a write-hot key) and no random +// per-candidate lookup at query time. The document length is duplicated across a +// document's unique terms, but a document's postings are always rewritten together +// on update, so they stay consistent. +func encodeBM25Value(tf, docLen uint32) []byte { + buf := make([]byte, binary.MaxVarintLen32*2) + n := binary.PutUvarint(buf, uint64(tf)) + n += binary.PutUvarint(buf[n:], uint64(docLen)) + return buf[:n] +} + +// decodeBM25Value reverses encodeBM25Value. ok is false when the input does not +// hold two complete varints (e.g. a truncated or corrupt posting value), so the +// caller can skip it rather than silently scoring it as a zero-frequency match. +func decodeBM25Value(b []byte) (tf, docLen uint32, ok bool) { + tf64, n := binary.Uvarint(b) + if n <= 0 { + return 0, 0, false + } + docLen64, m := binary.Uvarint(b[n:]) + if m <= 0 { + return 0, 0, false + } + return uint32(tf64), uint32(docLen64), true +} + +// encodeBM25Stats encodes corpus statistics (document count, total term count) as +// two unsigned varints. +func encodeBM25Stats(docCount, totalTerms uint64) []byte { + buf := make([]byte, binary.MaxVarintLen64*2) + n := binary.PutUvarint(buf, docCount) + n += binary.PutUvarint(buf[n:], totalTerms) + return buf[:n] +} + +// decodeBM25Stats reverses encodeBM25Stats. It returns (0, 0) on malformed input. +func decodeBM25Stats(b []byte) (docCount, totalTerms uint64) { + docCount, n := binary.Uvarint(b) + if n <= 0 { + return 0, 0 + } + totalTerms, m := binary.Uvarint(b[n:]) + if m <= 0 { + return docCount, 0 + } + return docCount, totalTerms +} + +// addBM25TermPosting writes (op=SET) or removes (op=DEL) the posting for the given +// (term, uid) pair in the term's standard index posting list. On SET the posting's +// Value packs (tf, docLen); on DEL only the UID matters. The posting is a REF +// posting (ValueId set) that carries a Value — List.encode retains such postings +// through rollup (see the len(p.Value) > 0 clause there). +func (txn *Txn) addBM25TermPosting(ctx context.Context, attr, term string, uid uint64, + tf, docLen uint32, op pb.DirectedEdge_Op) error { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: op, + } + if op != pb.DirectedEdge_DEL { + edge.Value = encodeBM25Value(tf, docLen) + edge.ValueType = pb.Posting_BINARY + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + ostats.Record(ctx, x.NumEdges.M(1)) + return nil +} + +// updateBM25Stats applies (docCountDelta, totalTermsDelta) to the bucketed corpus +// statistics for attr. The bucket is selected by uid%numBM25StatsBuckets. The +// running totals are stored as a single value posting per bucket; the read at +// txn.StartTs sees this transaction's own earlier writes (read-your-own-writes), +// so multiple documents in the same transaction that land in the same bucket +// accumulate correctly. +func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, uid uint64, + docCountDelta, totalTermsDelta int64) error { + bucket := int(uid % numBM25StatsBuckets) + key := x.BM25StatsKey(attr, bucket) + // Stats are maintained by read-modify-write: we must read the committed total + // from disk (and merge this transaction's own writes), not just the in-memory + // delta. GetFromDelta skips disk and is only safe for write-only index mutations, + // so each transaction would otherwise overwrite the bucket instead of + // accumulating across transactions. Get reads committed state. + plist, err := txn.cache.Get(key) + if err != nil { + return err + } + + var docCount, totalTerms uint64 + val, err := plist.Value(txn.StartTs) + switch { + case err == nil: + if data, ok := val.Value.([]byte); ok { + docCount, totalTerms = decodeBM25Stats(data) + } + case err == ErrNoValue: + // No stats yet for this bucket; start from zero. + default: + return err + } + + docCount = applyBM25Delta(docCount, docCountDelta) + totalTerms = applyBM25Delta(totalTerms, totalTermsDelta) + + edge := &pb.DirectedEdge{ + Attr: attr, + Value: encodeBM25Stats(docCount, totalTerms), + ValueType: pb.Posting_BINARY, + Op: pb.DirectedEdge_SET, + } + return plist.addMutation(ctx, txn, edge) +} + +// applyBM25Delta adds a signed delta to an unsigned counter, clamping at zero. +func applyBM25Delta(v uint64, delta int64) uint64 { + if delta >= 0 { + return v + uint64(delta) + } + dec := uint64(-delta) + if dec > v { + return 0 + } + return v - dec +} + +// ReadBM25Stats sums the bucketed corpus statistics for attr at readTs, returning +// the document count and total term count. avgDL = totalTerms / docCount. The +// getList closure reads a posting list for a key (e.g. LocalCache.Get) so the +// caller controls caching and the read timestamp. +func ReadBM25Stats(getList func(key []byte) (*List, error), attr string, + readTs uint64) (docCount, totalTerms uint64, err error) { + for b := 0; b < numBM25StatsBuckets; b++ { + key := x.BM25StatsKey(attr, b) + pl, perr := getList(key) + if perr != nil { + return 0, 0, perr + } + val, verr := pl.Value(readTs) + if verr == ErrNoValue { + continue + } + if verr != nil { + return 0, 0, verr + } + data, ok := val.Value.([]byte) + if !ok || len(data) == 0 { + continue + } + dc, tt := decodeBM25Stats(data) + docCount += dc + totalTerms += tt + } + return docCount, totalTerms, nil +} diff --git a/posting/bm25_test.go b/posting/bm25_test.go new file mode 100644 index 00000000000..169c3be437e --- /dev/null +++ b/posting/bm25_test.go @@ -0,0 +1,174 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" +) + +func TestBM25ValueCodecRoundTrip(t *testing.T) { + cases := [][2]uint32{{0, 0}, {1, 1}, {3, 12}, {7, 200}, {65535, 1 << 20}, {1 << 24, 1 << 24}} + for _, c := range cases { + tf, dl, ok := decodeBM25Value(encodeBM25Value(c[0], c[1])) + require.True(t, ok) + require.Equal(t, c[0], tf) + require.Equal(t, c[1], dl) + } + // Malformed/truncated input is reported as invalid so callers can skip it. + _, _, ok := decodeBM25Value(nil) + require.False(t, ok) + _, _, ok = decodeBM25Value([]byte{0x80}) // varint continuation byte with no terminator + require.False(t, ok) +} + +// TestBM25ValueSurvivesRollup verifies the linchpin of the BM25 redesign: a REF +// index posting that carries a packed (tf, docLen) value is retained — value and +// all — through rollup, instead of being collapsed to a UID-only Pack entry (the +// default behavior for REF postings before the len(p.Value) > 0 retention clause +// in List.encode). +func TestBM25ValueSurvivesRollup(t *testing.T) { + attr := x.AttrInRootNamespace("bm25rollup") + encodedTerm := string([]byte{0x10}) + "fox" // IdentBM25 || term + key := x.BM25IndexKey(attr, encodedTerm) + + docs := []struct { + uid uint64 + tf uint32 + docLen uint32 + }{ + {uid: 5, tf: 3, docLen: 12}, + {uid: 9, tf: 1, docLen: 40}, + {uid: 100, tf: 7, docLen: 200}, + } + + ts := uint64(1) + for _, d := range docs { + l, err := GetNoStore(key, ts) + require.NoError(t, err) + edge := &pb.DirectedEdge{ + ValueId: d.uid, + Attr: attr, + Value: encodeBM25Value(d.tf, d.docLen), + ValueType: pb.Posting_BINARY, + } + addMutation(t, l, edge, Set, ts, ts+1, false) + ts += 2 + } + + // Force a rollup and decode the resulting posting list directly. + l, err := getNew(key, pstore, math.MaxUint64, false) + require.NoError(t, err) + kvs, err := l.Rollup(nil, math.MaxUint64) + require.NoError(t, err) + require.NotEmpty(t, kvs) + + var plist pb.PostingList + require.NoError(t, proto.Unmarshal(kvs[0].Value, &plist)) + + got := make(map[uint64][2]uint32) + for _, p := range plist.Postings { + tf, docLen, ok := decodeBM25Value(p.Value) + require.True(t, ok) + got[p.Uid] = [2]uint32{tf, docLen} + } + for _, d := range docs { + v, ok := got[d.uid] + require.Truef(t, ok, "uid %d posting missing after rollup (value stripped?)", d.uid) + require.Equal(t, d.tf, v[0], "tf for uid %d", d.uid) + require.Equal(t, d.docLen, v[1], "docLen for uid %d", d.uid) + } + + // Reading the list back materializes the same (uid, tf, docLen) triples. + posts, err := ReadBM25TermPostings(func(k []byte) (*List, error) { + return getNew(k, pstore, math.MaxUint64, false) + }, attr, encodedTerm, math.MaxUint64) + require.NoError(t, err) + require.Len(t, posts, len(docs)) + for _, p := range posts { + require.Equal(t, got[p.Uid][0], p.TF) + require.Equal(t, got[p.Uid][1], p.DocLen) + } +} + +// TestBM25StatsBucketed verifies that bucketed corpus statistics accumulate +// correctly across documents (including two documents that hash to the same +// bucket, exercising in-transaction read-your-own-writes) and that deletes +// subtract correctly. +func TestBM25StatsBucketed(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25stats") + ts := uint64(101) + txn := Oracle().RegisterStartTs(ts) + + // uid 1 and uid 33 both fall in bucket 1 (mod 32), exercising same-bucket + // accumulation within a single transaction. + docs := []struct { + uid uint64 + dl int64 + }{{1, 10}, {2, 20}, {33, 5}, {64, 7}, {100, 8}} + + var wantCount, wantTerms int64 + for _, d := range docs { + require.NoError(t, txn.updateBM25Stats(ctx, attr, d.uid, 1, d.dl)) + wantCount++ + wantTerms += d.dl + } + + get := func(k []byte) (*List, error) { return txn.cache.GetFromDelta(k) } + dc, tt, err := ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount), dc) + require.Equal(t, uint64(wantTerms), tt) + + // Delete uid 2: docCount and totalTerms drop accordingly. + require.NoError(t, txn.updateBM25Stats(ctx, attr, 2, -1, -20)) + dc, tt, err = ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount-1), dc) + require.Equal(t, uint64(wantTerms-20), tt) +} + +// TestBM25StatsAccumulateAcrossTxns verifies that stats accumulate across +// separately-committed transactions (not just within one). This guards against +// the read-modify-write reading only the in-memory delta instead of committed +// disk state, which would make each transaction overwrite its bucket and collapse +// the corpus document count. +func TestBM25StatsAccumulateAcrossTxns(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25statsxtxn") + + // Two documents in the SAME bucket (uid 5 and uid 37 → bucket 5), committed in + // two separate transactions. + commitDoc := func(startTs, commitTs, uid uint64, docLen int64) { + txn := Oracle().RegisterStartTs(startTs) + txn.cache = NewLocalCache(startTs) + require.NoError(t, txn.updateBM25Stats(ctx, attr, uid, 1, docLen)) + txn.Update() + txn.UpdateCachedKeys(commitTs) + writer := NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + } + + commitDoc(201, 202, 5, 10) + commitDoc(203, 204, 37, 6) + + // A fresh reader at a later ts must see BOTH documents (count 2, terms 16), + // not just the most recently committed one. + get := func(k []byte) (*List, error) { return GetNoStore(k, 205) } + dc, tt, err := ReadBM25Stats(get, attr, 205) + require.NoError(t, err) + require.Equal(t, uint64(2), dc, "doc count must accumulate across transactions") + require.Equal(t, uint64(16), tt, "total terms must accumulate across transactions") +} diff --git a/posting/index.go b/posting/index.go index ae6c3352a44..edb997cc4b6 100644 --- a/posting/index.go +++ b/posting/index.go @@ -68,6 +68,10 @@ func indexTokens(ctx context.Context, info *indexMutationInfo) ([]string, error) var tokens []string for _, it := range info.tokenizers { + // BM25 tokenizer is handled separately in addBM25IndexMutations. + if it.Identifier() == tok.IdentBM25 { + continue + } toks, err := tok.BuildTokens(sv.Value, tok.GetTokenizerForLang(it, lang)) if err != nil { return tokens, err @@ -179,6 +183,17 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) } } + // Check if any tokenizer is BM25 and handle separately. + for _, it := range info.tokenizers { + if _, ok := tok.GetTokenizerForLang(it, info.edge.GetLang()).(tok.BM25Tokenizer); ok { + if err := txn.addBM25IndexMutations(ctx, info); err != nil { + return []*pb.DirectedEdge{}, err + } + // Continue to process remaining non-BM25 tokenizers below. + continue + } + } + tokens, err := indexTokens(ctx, info) if err != nil { // This data is not indexable @@ -215,6 +230,58 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return nil } +// addBM25IndexMutations handles index mutations for the BM25 tokenizer. Unlike +// other tokenizers, each BM25 index posting carries a value that packs the term +// frequency together with the document length (see encodeBM25Value). The postings +// are written through the standard delta path (plist.addMutation), so BM25 rides +// Dgraph's normal posting-list machinery — MVCC, deltas, rollup, splits, backup — +// with no separate storage path. Corpus statistics (document count and total term +// count, from which the average document length is derived) are kept in bucketed +// stats posting lists keyed by uid%numBM25StatsBuckets to avoid a single write-hot +// key while preserving conflict detection per bucket. +// +// Updates are driven entirely by the caller (AddMutationWithIndex), which issues a +// DEL for the previous value followed by a SET for the new one. The DEL re-tokenizes +// the old value and removes its postings and stats contribution; the SET adds the new +// ones. We therefore never need to detect updates here. +func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { + attr := info.edge.Attr + uid := info.edge.Entity + lang := info.edge.GetLang() + + schemaType, err := schema.State().TypeOf(attr) + if err != nil || !schemaType.IsScalar() { + return errors.Errorf("Cannot BM25 index attribute %s of type object.", attr) + } + + sv, err := types.Convert(info.val, schemaType) + if err != nil { + return err + } + + bm25Tok := tok.BM25Tokenizer{} + termFreqs, docLen, err := bm25Tok.TokensWithFrequency(sv.Value, lang) + if err != nil { + return err + } + + // Skip documents that tokenize to zero terms (e.g., all stopwords). + if docLen == 0 { + return nil + } + + for term, tf := range termFreqs { + if err := txn.addBM25TermPosting(ctx, attr, term, uid, tf, docLen, info.op); err != nil { + return err + } + } + + if info.op == pb.DirectedEdge_DEL { + return txn.updateBM25Stats(ctx, attr, uid, -1, -int64(docLen)) + } + return txn.updateBM25Stats(ctx, attr, uid, 1, int64(docLen)) +} + // countParams is sent to updateCount function. It is used to update the count index. // It deletes the uid from the key corresponding to and adds it // to . diff --git a/posting/list.go b/posting/list.go index 1c0c7a0fc55..610eaf5b5c2 100644 --- a/posting/list.go +++ b/posting/list.go @@ -1627,7 +1627,14 @@ func (l *List) encode(out *rollupOutput, readTs uint64, split bool) error { } enc.Add(p.Uid) - if p.Facets != nil || p.PostingType != pb.Posting_REF { + // Retain the full posting (not just its UID in the Pack) whenever it + // carries facets, is not a plain UID reference, or carries a value. + // BM25 index postings are REF postings that pack (term-frequency, + // doc-length) into Value; without the len(p.Value) > 0 clause that + // value would be stripped at rollup, silently losing all term + // frequencies. This mirrors how faceted postings already coexist in + // both Pack (UID) and Postings (payload). + if p.Facets != nil || p.PostingType != pb.Posting_REF || len(p.Value) > 0 { plist.Postings = append(plist.Postings, p) } return nil diff --git a/posting/mvcc.go b/posting/mvcc.go index 81c5e375553..108cdfc3b3e 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -318,6 +318,7 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { return err } } + return nil } diff --git a/query/common_test.go b/query/common_test.go index e36211f7a18..64a88bd865e 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -390,6 +390,15 @@ func populateCluster(dc dgraphapi.Cluster) { testSchema += "\ndescription: string @index(ngram) ." } + // BM25 indexing - uses same version gate as ngram for now + if ngramSupport { + testSchema += "\ndescription_bm25: string @index(bm25) ." + // A parallel vector predicate on the same documents enables hybrid-search + // (fuse) tests that combine BM25 with vector similarity. Gated together with + // BM25 so both channels are always present for those tests. + testSchema += "\ndescription_vec: float32vector @index(hnsw(metric:\"euclidean\")) ." + } + setSchema(testSchema) err = addTriplesToCluster(` @@ -1007,4 +1016,31 @@ func populateCluster(dc dgraphapi.Cluster) { <415> "Linguistic analysis helps understand text meaning" . `) x.Panic(err) + + // Add data for BM25 tests - uses separate predicate to avoid conflicts + err = addTriplesToCluster(` + <501> "The quick brown fox jumps over the lazy dog" . + <502> "A quick brown fox leaps over a sleeping dog" . + <503> "fox fox fox" . + <504> "The lazy dog sleeps under the warm sun all day long in the garden" . + <505> "Dogs are loyal companions to humans and families everywhere" . + <506> "Quick movements help foxes catch their prey in the wild" . + <507> "Brown foxes are quick and agile animals in the forest" . + `) + x.Panic(err) + + // Vector embeddings on the same documents (dims = [fox, dog, quick, brown]). + // Chosen so a "pure fox" query vector [3,0,0,0] ranks 503 first, then the + // fox/quick docs (506,507,501,502), and the dog-only docs (504,505) last — + // letting hybrid (fuse) tests observe both channels and union semantics. + err = addTriplesToCluster(` + <501> "[1.0, 1.0, 1.0, 1.0]" . + <502> "[1.0, 1.0, 1.0, 1.0]" . + <503> "[3.0, 0.0, 0.0, 0.0]" . + <504> "[0.0, 2.0, 0.0, 0.0]" . + <505> "[0.0, 2.0, 0.0, 0.0]" . + <506> "[1.0, 0.0, 1.0, 0.0]" . + <507> "[1.0, 0.0, 1.0, 1.0]" . + `) + x.Panic(err) } diff --git a/query/fuse.go b/query/fuse.go new file mode 100644 index 00000000000..3d36c7e4796 --- /dev/null +++ b/query/fuse.go @@ -0,0 +1,373 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package query + +import ( + "math" + "sort" + "strconv" + "strings" + + "github.com/dgraph-io/dgraph/v25/dql" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/types" + "github.com/pkg/errors" +) + +// This file implements the pure, I/O-free core of native hybrid search: combining +// several already-scored result sets ("channels") into a single ranked list. +// +// A channel is a uid->score map produced by an upstream ranker bound to a DQL value +// variable (e.g. bm25(...) or similar_to(...)). Fusion is a coordinator-side +// operation over resolved value variables; everything here is deterministic and +// independent of storage, sharding, and the query pipeline so it can be tested in +// isolation. The query-layer adapter (see populateUidValVar) converts value +// variables into fuseChannels and the result back into a value variable. + +// fusionMethod selects how channel scores are combined. +type fusionMethod int + +const ( + // fusionRRF is Reciprocal Rank Fusion: each channel contributes 1/(k+rank), + // where rank is the 1-based position of the uid within that channel. Robust to + // heterogeneous score scales because it uses ranks, not raw scores. + fusionRRF fusionMethod = iota + // fusionLinear is a weighted sum of (optionally normalized) raw scores. + fusionLinear +) + +// linearNormalize selects score normalization for fusionLinear. +type linearNormalize int + +const ( + // normalizeMax divides each channel's scores by that channel's maximum absolute + // score, bringing heterogeneous scales onto a comparable [-1,1]-ish range. + normalizeMax linearNormalize = iota + // normalizeNone uses raw scores as-is (the caller asserts comparability). + normalizeNone +) + +// defaultRRFK is the conventional RRF rank constant. Larger k flattens the +// contribution of top ranks; 60 is the widely used default. +const defaultRRFK = 60.0 + +// fuseChannel is one scored input to fusion. scores maps uid -> raw score with the +// convention that higher is always better (all Dgraph rankers surface +// higher-is-better scores). weight applies only to fusionLinear. +type fuseChannel struct { + scores map[uint64]float64 + weight float64 +} + +// fuseOpts configures a fusion run. +type fuseOpts struct { + method fusionMethod + k float64 // RRF rank constant; <=0 falls back to defaultRRFK. + normalize linearNormalize // linear only. + topk int // if >0, truncate output to the top topk results. +} + +// scoredUid is a uid paired with its fused score. +type scoredUid struct { + uid uint64 + score float64 +} + +// fuseChannels combines channels into a single ranked list, sorted by fused score +// descending and tie-broken by uid ascending. The candidate set is the UNION of all +// channels' uids (outer join): a uid missing from a channel simply receives no +// contribution from it (RRF: as if rank = infinity; linear: as if score = 0). It is +// never dropped and never produces NaN. +func fuseChannels(channels []fuseChannel, opts fuseOpts) []scoredUid { + var fused map[uint64]float64 + switch opts.method { + case fusionLinear: + fused = fuseLinear(channels, opts.normalize) + default: + fused = fuseRRF(channels, opts.k) + } + + out := make([]scoredUid, 0, len(fused)) + for uid, s := range fused { + out = append(out, scoredUid{uid: uid, score: s}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].score != out[j].score { + return out[i].score > out[j].score + } + return out[i].uid < out[j].uid + }) + if opts.topk > 0 && len(out) > opts.topk { + out = out[:opts.topk] + } + return out +} + +// channelRanks returns the 1-based rank of every uid in a channel, computed by +// sorting on score descending and tie-breaking by uid ascending. The tie-break +// matches the deterministic ordering used elsewhere in the codebase (bm25/HNSW +// sorted()), so equal-scored uids rank stably by uid. +func channelRanks(c fuseChannel) map[uint64]int { + uids := make([]uint64, 0, len(c.scores)) + for uid := range c.scores { + uids = append(uids, uid) + } + sort.Slice(uids, func(i, j int) bool { + si, sj := c.scores[uids[i]], c.scores[uids[j]] + if si != sj { + return si > sj + } + return uids[i] < uids[j] + }) + ranks := make(map[uint64]int, len(uids)) + for i, uid := range uids { + ranks[uid] = i + 1 + } + return ranks +} + +// fuseRRF computes (weighted) Reciprocal Rank Fusion over the channels. Each +// channel contributes weight * 1/(k+rank); with the default weight of 1.0 this is +// standard RRF, and per-channel weights let callers bias channels under either +// fusion method rather than silently ignoring weights for rrf. +func fuseRRF(channels []fuseChannel, k float64) map[uint64]float64 { + if k <= 0 || math.IsNaN(k) || math.IsInf(k, 0) { + k = defaultRRFK + } + fused := make(map[uint64]float64) + for _, c := range channels { + for uid, rank := range channelRanks(c) { + fused[uid] += c.weight * (1.0 / (k + float64(rank))) + } + } + return fused +} + +// channelMaxAbs returns the maximum finite absolute score in a channel, used to +// max-normalize heterogeneous score scales. Non-finite scores are ignored; +// returns 0 for an empty channel (callers treat a 0 denominator as "contribute 0"). +func channelMaxAbs(c fuseChannel) float64 { + var maxAbs float64 + for _, s := range c.scores { + if math.IsNaN(s) || math.IsInf(s, 0) { + continue + } + if a := math.Abs(s); a > maxAbs { + maxAbs = a + } + } + return maxAbs +} + +// fuseLinear computes a weighted sum of (optionally max-normalized) raw scores. A +// uid missing from a channel contributes 0 from that channel. A channel whose +// maximum absolute score is 0 (all zeros / empty) contributes 0 rather than +// dividing by zero. +func fuseLinear(channels []fuseChannel, normalize linearNormalize) map[uint64]float64 { + denoms := make([]float64, len(channels)) + for i, c := range channels { + if normalize == normalizeMax { + denoms[i] = channelMaxAbs(c) + } else { + denoms[i] = 1.0 + } + } + + fused := make(map[uint64]float64) + for i, c := range channels { + denom := denoms[i] + for uid, s := range c.scores { + norm := s + if normalize == normalizeMax { + if denom == 0 { + norm = 0 + } else { + norm = s / denom + } + } + fused[uid] += c.weight * norm + } + } + // Ensure uids that appear only in zero-contribution channels are still present + // in the union (e.g. an all-zero max-normalized channel). The loop above already + // inserts them with a running sum (possibly 0), so the union is complete. + return fused +} + +// --- Query-layer adapter ----------------------------------------------------- +// +// The functions below bridge the pure fusion core to DQL value variables. They +// parse the fuse() options, read each channel's scores from the already-resolved +// variable map, run fusion, and return a varValue carrying both the union uid set +// and the fused uid->score map (the same shape the bm25 ranker binds). + +// parseFuseOpts extracts fuse() options from the function's key/value arg pairs and +// returns the resolved fuseOpts plus the optional per-channel linear weights +// (nil when unspecified). numChannels is used to validate the weights count. +func parseFuseOpts(args []dql.Arg, numChannels int) (fuseOpts, []float64, error) { + opts := fuseOpts{method: fusionRRF, k: defaultRRFK, normalize: normalizeMax} + var weights []float64 + + for i := 0; i+1 < len(args); i += 2 { + key := strings.ToLower(args[i].Value) + val := args[i+1].Value + switch key { + case "method": + switch strings.ToLower(val) { + case "rrf": + opts.method = fusionRRF + case "linear": + opts.method = fusionLinear + default: + return opts, nil, errors.Errorf("fuse: unknown method %q (want rrf or linear)", val) + } + case "k": + k, err := strconv.ParseFloat(val, 64) + if err != nil || k <= 0 || math.IsNaN(k) || math.IsInf(k, 0) { + return opts, nil, errors.Errorf("fuse: k must be a positive finite number, got %q", val) + } + opts.k = k + case "normalize": + switch strings.ToLower(val) { + case "max": + opts.normalize = normalizeMax + case "none": + opts.normalize = normalizeNone + default: + return opts, nil, errors.Errorf("fuse: unknown normalize %q (want max or none)", val) + } + case "topk": + tk, err := strconv.Atoi(val) + if err != nil || tk < 0 { + return opts, nil, errors.Errorf("fuse: topk must be a non-negative integer, got %q", val) + } + opts.topk = tk + case "weights": + parts := strings.Split(val, ",") + weights = make([]float64, 0, len(parts)) + for _, p := range parts { + w, err := strconv.ParseFloat(strings.TrimSpace(p), 64) + if err != nil || math.IsNaN(w) || math.IsInf(w, 0) { + return opts, nil, errors.Errorf("fuse: invalid weight %q", p) + } + weights = append(weights, w) + } + default: + return opts, nil, errors.Errorf("fuse: unknown option %q", key) + } + } + + if weights != nil && len(weights) != numChannels { + return opts, nil, errors.Errorf("fuse: weights count (%d) must match channel count (%d)", + len(weights), numChannels) + } + return opts, weights, nil +} + +// computeFuse reads the channel value variables named in the fuse function's +// NeedsVar from doneVars, runs fusion, and returns a varValue with the union uid +// set and the fused uid->score map. +func computeFuse(args []dql.Arg, needsVar []dql.VarContext, + doneVars map[string]varValue, sgPath []*SubGraph) (varValue, error) { + + if len(needsVar) == 0 { + return varValue{}, errors.Errorf("fuse: requires at least one value variable channel") + } + + opts, weights, err := parseFuseOpts(args, len(needsVar)) + if err != nil { + return varValue{}, err + } + + channels := make([]fuseChannel, len(needsVar)) + for i, nv := range needsVar { + v, ok := doneVars[nv.Name] + switch { + case !ok: + // The dependency scheduler guarantees every channel block has run and + // populated doneVars before this fuse block. A genuinely absent channel + // therefore signals an internal invariant violation rather than an empty + // result — surface it instead of silently degrading the fusion. + return varValue{}, errors.Errorf("fuse: channel %q was not produced", nv.Name) + case v.Vals == nil || v.Vals.IsEmpty(): + // A channel that ran but matched nothing is a valid empty channel: it + // contributes nothing but must not drop the other channels' results. + channels[i] = fuseChannel{scores: map[uint64]float64{}, weight: 1.0} + default: + scores, err := scoresFromVar(v, nv.Name) + if err != nil { + return varValue{}, err + } + channels[i] = fuseChannel{scores: scores, weight: 1.0} + } + if weights != nil { + channels[i].weight = weights[i] + } + } + + fused := fuseChannels(channels, opts) + + out := varValue{Vals: types.NewShardedMap(), path: sgPath} + uids := make([]uint64, len(fused)) + for i, r := range fused { + uids[i] = r.uid + out.Vals.Set(r.uid, types.Val{Tid: types.FloatID, Value: r.score}) + } + // Emit the uid set in ascending order — the Dgraph value-variable contract (the + // same one bm25 follows): Uids is the unordered candidate set used by uid(var), + // and the fused score lives in Vals. Callers recover ranked order with + // `orderdesc: val(var)`. When `topk` is set, fuseChannels has already selected + // the top-k by fused score before this ascending sort, so top-k + orderdesc is + // correct. (Sorting here by score would break uid(var) set semantics.) + sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) + out.Uids = &pb.List{Uids: uids} + return out, nil +} + +// scoresFromVar extracts a uid->float64 score map from a value variable. The +// variable must carry numeric scores (as bm25/similar_to bind); a uid variable +// without scores cannot be a fusion channel. +func scoresFromVar(v varValue, name string) (map[uint64]float64, error) { + scores := make(map[uint64]float64, v.Vals.Len()) + var convErr error + err := v.Vals.Iterate(func(uid uint64, val types.Val) error { + var s float64 + // bm25 and similar_to bind FloatID scores directly; take that fast path so a + // non-finite value (which types.Convert rejects) is dropped rather than + // mis-reported as "non-numeric". Other numeric types go through Convert. + if val.Tid == types.FloatID { + f, ok := val.Value.(float64) + if !ok { + convErr = errors.Errorf("fuse: channel %q has a malformed score value", name) + return convErr + } + s = f + } else { + f, err := types.Convert(val, types.FloatID) + if err != nil { + convErr = errors.Errorf("fuse: channel %q has non-numeric scores; use a "+ + "ranker such as bm25 or similar_to", name) + return convErr + } + s = f.Value.(float64) + } + // Drop non-finite scores so they can never poison fusion: a NaN/Inf would + // break the sort comparator's strict-weak-ordering and propagate through + // linear sums. A uid dropped here simply doesn't participate via this channel. + if math.IsNaN(s) || math.IsInf(s, 0) { + return nil + } + scores[uid] = s + return nil + }) + if convErr != nil { + return nil, convErr + } + if err != nil { + return nil, err + } + return scores, nil +} diff --git a/query/fuse_test.go b/query/fuse_test.go new file mode 100644 index 00000000000..af1450f66a9 --- /dev/null +++ b/query/fuse_test.go @@ -0,0 +1,229 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package query + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/types" +) + +// ch is a small helper to build a fusion channel from a uid->score map with a +// default weight of 1.0. +func ch(scores map[uint64]float64) fuseChannel { + return fuseChannel{scores: scores, weight: 1.0} +} + +// asMap collapses a fused result slice into a uid->score map for assertions that +// don't care about ordering. +func asMap(res []scoredUid) map[uint64]float64 { + m := make(map[uint64]float64, len(res)) + for _, r := range res { + m[r.uid] = r.score + } + return m +} + +func TestFuseRRF_BasicRanks(t *testing.T) { + // Channel A order: 10, 20, 30 (ranks 1,2,3) + // Channel B order: 30, 10, 40 (ranks 1,2,3) + a := ch(map[uint64]float64{10: 9.0, 20: 5.0, 30: 1.0}) + b := ch(map[uint64]float64{30: 0.9, 10: 0.5, 40: 0.1}) + + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + + const k = 60.0 + // uid 10: rank1 in A, rank2 in B + require.InDelta(t, 1/(k+1)+1/(k+2), got[10], 1e-9) + // uid 20: rank2 in A only + require.InDelta(t, 1/(k+2), got[20], 1e-9) + // uid 30: rank3 in A, rank1 in B + require.InDelta(t, 1/(k+3)+1/(k+1), got[30], 1e-9) + // uid 40: rank3 in B only + require.InDelta(t, 1/(k+3), got[40], 1e-9) +} + +func TestFuseRRF_OrderingAndUnion(t *testing.T) { + a := ch(map[uint64]float64{10: 9.0, 20: 5.0, 30: 1.0}) + b := ch(map[uint64]float64{30: 0.9, 10: 0.5, 40: 0.1}) + + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + + // Union of all uids is present (outer join, not intersection). + require.Len(t, res, 4) + // Sorted by fused score descending. uid 10 and 30 both appear in both channels + // near the top; 10 is rank1+rank2, 30 is rank3+rank1 -> 10 slightly higher. + require.Equal(t, uint64(10), res[0].uid) + require.Equal(t, uint64(30), res[1].uid) + // Scores must be monotonically non-increasing. + for i := 1; i < len(res); i++ { + require.LessOrEqual(t, res[i].score, res[i-1].score) + } +} + +func TestFuseRRF_DefaultK(t *testing.T) { + a := ch(map[uint64]float64{1: 1.0}) + // k<=0 should fall back to the default of 60. + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 0}) + require.InDelta(t, 1/(60.0+1), res[0].score, 1e-9) +} + +func TestFuseRRF_TieBreakByUidAscending(t *testing.T) { + // Equal scores within a channel -> lower uid gets the better (smaller) rank. + a := ch(map[uint64]float64{2: 5.0, 1: 5.0, 3: 5.0}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + // uid 1 rank1, uid 2 rank2, uid 3 rank3. + require.InDelta(t, 1/(60.0+1), got[1], 1e-9) + require.InDelta(t, 1/(60.0+2), got[2], 1e-9) + require.InDelta(t, 1/(60.0+3), got[3], 1e-9) + // Final output tie-broken by uid ascending when fused scores are equal. + require.Equal(t, uint64(1), res[0].uid) +} + +func TestFuseRRF_DisjointChannels(t *testing.T) { + a := ch(map[uint64]float64{1: 9.0, 2: 8.0}) + b := ch(map[uint64]float64{3: 9.0, 4: 8.0}) + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + require.Len(t, res, 4) + got := asMap(res) + // Each uid scored only by its single channel rank. + require.InDelta(t, 1/(60.0+1), got[1], 1e-9) + require.InDelta(t, 1/(60.0+1), got[3], 1e-9) + require.InDelta(t, 1/(60.0+2), got[2], 1e-9) + require.InDelta(t, 1/(60.0+2), got[4], 1e-9) +} + +func TestFuseRRF_AppliesWeights(t *testing.T) { + // Weights must affect RRF (not only linear): a uid ranked #1 in a 2x-weighted + // channel should beat a uid ranked #1 in a unit-weighted channel. + heavy := fuseChannel{scores: map[uint64]float64{1: 9.0}, weight: 2.0} + light := fuseChannel{scores: map[uint64]float64{2: 9.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{heavy, light}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + require.InDelta(t, 2.0*(1/(60.0+1)), got[1], 1e-9) + require.InDelta(t, 1.0*(1/(60.0+1)), got[2], 1e-9) + require.Equal(t, uint64(1), res[0].uid, "heavier-weighted channel's top doc wins") +} + +func TestFuseRRF_DefaultWeightIsStandardRRF(t *testing.T) { + // With the default weight of 1.0, weighted RRF reduces to standard RRF. + a := ch(map[uint64]float64{10: 9.0, 20: 5.0}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + require.InDelta(t, 1/(60.0+1), got[10], 1e-9) + require.InDelta(t, 1/(60.0+2), got[20], 1e-9) +} + +func TestFuseLinear_MaxNormalizeAndWeights(t *testing.T) { + // BM25-ish scale vs cosine-ish scale. + text := fuseChannel{scores: map[uint64]float64{1: 10.0, 2: 5.0}, weight: 0.3} + vec := fuseChannel{scores: map[uint64]float64{1: 0.8, 2: 0.4, 3: 0.2}, weight: 0.7} + + res := fuseChannels([]fuseChannel{text, vec}, + fuseOpts{method: fusionLinear, normalize: normalizeMax}) + got := asMap(res) + + // max-normalize: text/10, vec/0.8. + // uid1: 0.3*(10/10) + 0.7*(0.8/0.8) = 0.3 + 0.7 = 1.0 + require.InDelta(t, 1.0, got[1], 1e-9) + // uid2: 0.3*(5/10) + 0.7*(0.4/0.8) = 0.15 + 0.35 = 0.5 + require.InDelta(t, 0.5, got[2], 1e-9) + // uid3: only vec: 0.7*(0.2/0.8) = 0.175 (text contributes 0, not NaN) + require.InDelta(t, 0.175, got[3], 1e-9) + + require.Equal(t, uint64(1), res[0].uid) +} + +func TestFuseLinear_NoNormalize(t *testing.T) { + a := fuseChannel{scores: map[uint64]float64{1: 2.0, 2: 1.0}, weight: 1.0} + b := fuseChannel{scores: map[uint64]float64{1: 3.0}, weight: 2.0} + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionLinear, normalize: normalizeNone}) + got := asMap(res) + // uid1: 1*2 + 2*3 = 8 ; uid2: 1*1 = 1 + require.InDelta(t, 8.0, got[1], 1e-9) + require.InDelta(t, 1.0, got[2], 1e-9) +} + +func TestFuseLinear_ZeroMaxChannelContributesZero(t *testing.T) { + // A channel whose scores are all zero must not divide-by-zero / NaN. + a := fuseChannel{scores: map[uint64]float64{1: 0.0, 2: 0.0}, weight: 1.0} + b := fuseChannel{scores: map[uint64]float64{1: 4.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionLinear, normalize: normalizeMax}) + got := asMap(res) + require.False(t, math.IsNaN(got[1])) + require.False(t, math.IsNaN(got[2])) + // uid1: a contributes 0, b contributes 4/4=1 -> 1.0 + require.InDelta(t, 1.0, got[1], 1e-9) + // uid2: only in a (all-zero) -> 0.0 + require.InDelta(t, 0.0, got[2], 1e-9) +} + +func TestFuse_TopKTruncation(t *testing.T) { + a := ch(map[uint64]float64{1: 9, 2: 8, 3: 7, 4: 6, 5: 5}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60, topk: 3}) + require.Len(t, res, 3) + require.Equal(t, uint64(1), res[0].uid) + require.Equal(t, uint64(3), res[2].uid) +} + +func TestFuse_SingleChannelPassthroughOrder(t *testing.T) { + a := ch(map[uint64]float64{1: 1, 2: 9, 3: 5}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + // Order should reflect channel ranking: 2 (rank1), 3 (rank2), 1 (rank3). + require.Equal(t, []uint64{2, 3, 1}, []uint64{res[0].uid, res[1].uid, res[2].uid}) +} + +func TestFuse_EmptyChannels(t *testing.T) { + res := fuseChannels([]fuseChannel{ch(nil), ch(map[uint64]float64{})}, + fuseOpts{method: fusionRRF, k: 60}) + require.Empty(t, res) +} + +func TestScoresFromVar_DropsNonFinite(t *testing.T) { + // Non-finite scores from a channel must be dropped so they can't break the sort + // comparator or poison linear sums. + m := types.NewShardedMap() + m.Set(1, types.Val{Tid: types.FloatID, Value: 0.5}) + m.Set(2, types.Val{Tid: types.FloatID, Value: math.NaN()}) + m.Set(3, types.Val{Tid: types.FloatID, Value: math.Inf(1)}) + m.Set(4, types.Val{Tid: types.FloatID, Value: math.Inf(-1)}) + m.Set(5, types.Val{Tid: types.FloatID, Value: 2.0}) + + scores, err := scoresFromVar(varValue{Vals: m}, "ch") + require.NoError(t, err) + require.Len(t, scores, 2, "only finite scores should survive") + require.Contains(t, scores, uint64(1)) + require.Contains(t, scores, uint64(5)) + require.NotContains(t, scores, uint64(2)) + require.NotContains(t, scores, uint64(3)) + require.NotContains(t, scores, uint64(4)) +} + +func TestFuseLinear_NonFiniteChannelDoesNotPoison(t *testing.T) { + // Even if a NaN slips into a channel passed directly to the core, max-normalize + // must not produce a NaN denominator that propagates. + bad := fuseChannel{scores: map[uint64]float64{1: math.NaN(), 2: math.NaN()}, weight: 1.0} + good := fuseChannel{scores: map[uint64]float64{1: 4.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{bad, good}, fuseOpts{method: fusionLinear, normalize: normalizeMax}) + for _, r := range res { + require.False(t, math.IsNaN(r.score), "uid %d score must not be NaN", r.uid) + } +} + +func TestFuse_Determinism(t *testing.T) { + a := ch(map[uint64]float64{1: 5, 2: 5, 3: 5}) + b := ch(map[uint64]float64{3: 1, 2: 1, 1: 1}) + first := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + for i := 0; i < 20; i++ { + again := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + require.Equal(t, first, again) + } +} diff --git a/query/query.go b/query/query.go index 6926e2ac6ed..c48cf2d11fe 100644 --- a/query/query.go +++ b/query/query.go @@ -7,6 +7,7 @@ package query import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -268,6 +269,15 @@ type SubGraph struct { // In graph terms, a list is a slice of outgoing edges from a node. uidMatrix []*pb.List + // rankerScores maps a matched document UID to its ranker score (BM25 relevance + // or vector similarity). It is snapshotted from the (uid-aligned) worker result + // the moment it arrives, before filters or pagination can shrink/reorder + // uidMatrix out of step with valueMatrix. populateUidValVar binds the score + // variable from this map keyed by UID, so the score stays correct even when the + // ranker block carries an @filter. nil unless the source function is a ranker + // (bm25 / similar_to). + rankerScores map[uint64]float64 + // facetsMatrix contains the facet values. There would a list corresponding to each uid in // uidMatrix. facetsMatrix []*pb.FacetsList @@ -373,6 +383,19 @@ func getValue(tv *pb.TaskValue) (types.Val, error) { return val, nil } +func valToTaskValue(v types.Val) *pb.TaskValue { + data := types.ValueForType(types.BinaryID) + res := &pb.TaskValue{ValType: v.Tid.Enum(), Val: x.Nilbyte} + if v.Value == nil { + return res + } + if err := types.Marshal(v, &data); err != nil { + return res + } + res.Val = data.Value.([]byte) + return res +} + var ( // ErrEmptyVal is returned when a value is empty. ErrEmptyVal = errors.New("Query: harmless error, e.g. task.Val is nil") @@ -1556,6 +1579,17 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su var ok bool switch { + case sg.SrcFunc != nil && sg.SrcFunc.Name == "fuse": + // Native hybrid search: fuse() combines several already-scored value + // variables (its NeedsVar channels) into one ranked value variable. Fusion is + // a coordinator-side operation over resolved variables — the channel blocks + // have already populated doneVars by the time this block is scheduled. We bind + // both the union uid set (uid(var)) and the uid->fused-score map (val(var)). + fv, err := computeFuse(sg.SrcFunc.Args, sg.Params.NeedsVar, doneVars, sgPath) + if err != nil { + return err + } + doneVars[sg.Params.Var] = fv case len(sg.counts) > 0: // 1. When count of a predicate is assigned a variable, we store the mapping of uid => // count(predicate). @@ -1591,6 +1625,27 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su Value: int64(len(sg.SrcUIDs.Uids)), } doneVars[sg.Params.Var].Vals.Set(math.MaxUint64, val) + case sg.SrcFunc != nil && (sg.SrcFunc.Name == "bm25" || sg.SrcFunc.Name == "similar_to") && + sg.rankerScores != nil: + // A query-side ranker (BM25 relevance or vector similarity) binds its + // per-document score as a value variable. We populate BOTH the matched uid set + // and the uid->score map so the variable works with uid(var), val(var) and + // orderdesc: val(var) — surfacing and ordering by score without a + // pseudo-predicate or a ParentVars channel. Scores are looked up from the + // uid-keyed snapshot taken at result time (sg.rankerScores), so they remain + // correct even after an @filter on the ranker block shrinks DestUIDs. For + // similar_to the score is a higher-is-better similarity; this also lets vector + // results feed fuse(). + if v, ok = doneVars[sg.Params.Var]; !ok { + v = varValue{Vals: types.NewShardedMap(), path: sgPath, strList: sg.valueMatrix} + } + v.Uids = sg.DestUIDs + for _, uid := range sg.DestUIDs.GetUids() { + if score, has := sg.rankerScores[uid]; has { + v.Vals.Set(uid, types.Val{Tid: types.FloatID, Value: score}) + } + } + doneVars[sg.Params.Var] = v case len(sg.DestUIDs.Uids) != 0 || (sg.Attr == "uid" && sg.SrcUIDs != nil): // 3. A uid variable. The variable could be defined in one of two places. // a) Either on the actual predicate. @@ -2173,6 +2228,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { rch <- nil return } + var err error switch { case parent == nil && sg.SrcFunc != nil && sg.SrcFunc.Name == "uid": @@ -2275,6 +2331,27 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics + // bm25 and similar_to return their per-document scores in valueMatrix + // positionally aligned with uidMatrix[0]. Snapshot them into a uid-keyed + // map now, while the two are still aligned — later filters/pagination + // shrink uidMatrix without touching valueMatrix, which would otherwise + // misbind scores to UIDs (and feed wrong scores into fuse() channels). + if sg.SrcFunc != nil && (sg.SrcFunc.Name == "bm25" || sg.SrcFunc.Name == "similar_to") && + len(result.UidMatrix) > 0 { + uids := result.UidMatrix[0].GetUids() + sg.rankerScores = make(map[uint64]float64, len(uids)) + for idx, uid := range uids { + if idx >= len(result.ValueMatrix) || len(result.ValueMatrix[idx].Values) == 0 { + continue + } + tv := result.ValueMatrix[idx].Values[0] + if len(tv.Val) != 8 { + continue + } + sg.rankerScores[uid] = math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + } + } + if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2373,9 +2450,12 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } if len(sg.Params.Order) == 0 && len(sg.Params.FacetsOrder) == 0 { - // for `has` function when there is no filtering and ordering, we fetch - // correct paginated results so no need to apply pagination here. - if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && sg.SrcFunc.Name == "has") { + // For `has` and `bm25`, the worker already returns correctly paginated + // results (bm25 paginates over score order, which the uid-sorted query-layer + // pagination cannot reproduce), so applying pagination again here would + // double-apply first/offset. Skip it when there is no filtering/ordering. + if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && + (sg.SrcFunc.Name == "has" || sg.SrcFunc.Name == "bm25")) { // There is no ordering. Just apply pagination and return. if err = sg.applyPagination(ctx); err != nil { rch <- err @@ -2452,6 +2532,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } child.SrcUIDs = sg.DestUIDs // Make the connection. + if child.IsInternal() { // We dont have to execute these nodes. continue @@ -2751,7 +2832,8 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25", + "fuse": return true } return isInequalityFn(f) || types.IsGeoFunc(f) @@ -2959,6 +3041,15 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { continue } + // fuse() is a coordinator-side fusion over already-resolved value + // variables; it is never dispatched to a worker. The fused variable is + // produced from doneVars in populateVarMap (the fuse case in + // populateUidValVar) once its channel variables are populated. + if sg.SrcFunc != nil && sg.SrcFunc.Name == "fuse" { + errChan <- nil + continue + } + switch { case sg.Params.Alias == "shortest": // We allow only one shortest path block per query. diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go new file mode 100644 index 00000000000..8a16c31f1a3 --- /dev/null +++ b/query/query_bm25_test.go @@ -0,0 +1,1029 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// uidHex queries Dgraph for the hex UID string of a given decimal UID. +// This avoids hardcoding hex values that depend on UID assignment order. +func uidHex(t *testing.T, decimalUID int) string { + t.Helper() + js := processQueryNoErr(t, fmt.Sprintf(`{ me(func: uid(%d)) { uid } }`, decimalUID)) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me, "UID %d should exist", decimalUID) + return resp.Data.Me[0].UID +} + +func TestBM25Basic(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick brown fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return documents containing "quick", "brown", or "fox" + require.Contains(t, js, "quick brown fox jumps") + require.Contains(t, js, "quick brown fox leaps") +} + +func TestBM25Ordering(t *testing.T) { + // BM25 returns all matching documents. Use first:1 to verify the highest-scored + // document is "fox fox fox" (tf=3, short doc). + query := ` + { + me(func: bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should contain all fox-mentioning documents. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + + // first:1 should return the top-ranked document. + topQuery := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + topJs := processQueryNoErr(t, topQuery) + require.Contains(t, topJs, "fox fox fox", + "top-1 BM25 result for 'fox' should be 'fox fox fox' (highest tf, shortest doc)") +} + +func TestBM25WithParams(t *testing.T) { + // Custom k and b parameters + query := ` + { + me(func: bm25(description_bm25, "fox", "1.5", "0.5")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") +} + +func TestBM25InvalidParams(t *testing.T) { + // Negative k should be rejected. + query := ` + { + me(func: bm25(description_bm25, "fox", "-1.0", "0.75")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: k must be a positive finite number") + + // b > 1 should be rejected. + query2 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "1.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query2) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") + + // b < 0 should be rejected. + query3 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "-0.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query3) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") +} + +func TestBM25AsFilter(t *testing.T) { + query := ` + { + me(func: has(description_bm25)) @filter(bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") + // Should not contain documents without "fox" + require.NotContains(t, js, "Dogs are loyal") +} + +func TestBM25NoResults(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "xyznonexistent")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25SingleTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "dog")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "dog") +} + +func TestBM25MultiTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick lazy")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should find docs with "quick" or "lazy" (scores accumulate). + // Doc 501 has both "quick" and "lazy", so it should rank high. + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25AllStopwords(t *testing.T) { + // A query consisting entirely of stopwords should return no results. + query := ` + { + me(func: bm25(description_bm25, "the a an")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25EmptyPredicate(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "")) { + uid + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25WithCount(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + count(uid) + } + } + ` + js := processQueryNoErr(t, query) + // Should have at least 2 results (docs with "fox") + require.Contains(t, js, "count") +} + +func TestBM25Pagination(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // With first:1, should return exactly one result (the highest-scoring). + // Doc 503 "fox fox fox" should be the top result. + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrdering(t *testing.T) { + // Bind the bm25 score to a value variable and order results by it via val(). + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // "fox fox fox" (doc 503) has the highest BM25 score (tf=3, shortest doc). + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrderingMultiTerm(t *testing.T) { + // Multi-term query with score ordering: "quick lazy" should rank doc 501 highest + // since it contains both terms. + query := ` + { + score as var(func: bm25(description_bm25, "quick lazy")) + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25ScoreOrderingAllResults(t *testing.T) { + // Verify all results are returned in score-descending order via val(score). + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // All fox-containing docs should appear. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + // Score values should be present. + require.Contains(t, js, "val(score)") +} + +func TestBM25ScoreWithPagination(t *testing.T) { + // Use offset with score ordering. + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score), first: 1, offset: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return the second-highest scored document (not "fox fox fox"). + require.NotContains(t, js, "fox fox fox") + require.Contains(t, js, "fox") +} + +// parseScoresFromJSON extracts uid → score from JSON responses containing val(score). +func parseScoresFromJSON(t *testing.T, js string) map[string]float64 { + t.Helper() + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + scores := make(map[string]float64) + for _, item := range resp.Data.Me { + scores[item.UID] = item.Score + } + return scores +} + +func TestBM25IncrementalAddBatch(t *testing.T) { + batch1 := ` + <600> "alpha bravo charlie" . + <601> "delta echo foxtrot" . + ` + batch2 := ` + <602> "golf hotel india" . + <603> "juliet kilo lima" . + <604> "mike november oscar" . + ` + batch3 := ` + <605> "papa quebec romeo" . + <606> "sierra tango uniform" . + <607> "victor whiskey xray" . + ` + cleanup := func() { + deleteTriplesInCluster(` + <600> * . + <601> * . + <602> * . + <603> * . + <604> * . + <605> * . + <606> * . + <607> * . + `) + } + t.Cleanup(cleanup) + + countQuery := ` + { + me(func: bm25(description_bm25, "alpha bravo delta echo golf juliet mike papa sierra victor")) { + count(uid) + } + } + ` + + // Batch 1: add 2 docs. + require.NoError(t, addTriplesToCluster(batch1)) + js := processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":2`) + + // Batch 2: add 3 more docs → total 5. + require.NoError(t, addTriplesToCluster(batch2)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":5`) + + // Batch 3: add 3 more docs → total 8. + require.NoError(t, addTriplesToCluster(batch3)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":8`) + + // Verify specific new terms are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid description_bm25 } }`) + require.Contains(t, js, "whiskey") +} + +func TestBM25CorpusStatsAffectIDF(t *testing.T) { + // Capture baseline score for "fox" query. + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + jsBefore := processQueryNoErr(t, scoreQuery) + scoresBefore := parseScoresFromJSON(t, jsBefore) + require.NotEmpty(t, scoresBefore, "baseline should have fox results") + + // Add 10 non-fox docs → N grows, df("fox") stays same → IDF should increase. + var triples string + for i := 610; i < 620; i++ { + triples += fmt.Sprintf(`<%d> "completely unrelated document about cats and dogs number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 610; i < 620; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + jsAfter := processQueryNoErr(t, scoreQuery) + scoresAfter := parseScoresFromJSON(t, jsAfter) + + // Compare score for UID 503 ("fox fox fox") — should increase. + uid503 := uidHex(t, 503) + before, ok1 := scoresBefore[uid503] + after, ok2 := scoresAfter[uid503] + require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") + require.Greater(t, after, before, + "IDF should increase when corpus grows with non-matching docs (before=%f, after=%f)", before, after) +} + +func TestBM25DocumentUpdate(t *testing.T) { + // Add a doc with lots of "fox". + require.NoError(t, addTriplesToCluster(`<620> "fox fox fox fox" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<620> * .`) + }) + + uid620 := uidHex(t, 620) + + // Should rank top for "fox". + js := processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + } + }`) + require.Contains(t, js, `"`+uid620+`"`) + + // Update to remove "fox", add "cat". + deleteTriplesInCluster(`<620> "fox fox fox fox" .`) + require.NoError(t, addTriplesToCluster(`<620> "the cat sat on the mat" .`)) + + // Should no longer appear in "fox" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox")) { + uid + } + }`) + require.NotContains(t, js, `"`+uid620+`"`) + + // Should appear in "cat" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "cat")) { + uid + } + }`) + require.Contains(t, js, `"`+uid620+`"`) +} + +func TestBM25DocumentDeletion(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<625> "unique elephant term" .`)) + t.Cleanup(func() { + // Cleanup in case test fails before explicit delete. + deleteTriplesInCluster(`<625> * .`) + }) + + uid625 := uidHex(t, 625) + + // Should find the elephant doc. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"`+uid625+`"`) + + // Delete it. + deleteTriplesInCluster(`<625> "unique elephant term" .`) + + // Should return empty for "elephant". + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.JSONEq(t, `{"data": {"me":[]}}`, js) + + // Baseline "fox" results should be unaffected. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid description_bm25 } }`) + require.Contains(t, js, "fox") +} + +func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + uid503 := uidHex(t, 503) + + // Phase 1: baseline score. + js1 := processQueryNoErr(t, scoreQuery) + scores1 := parseScoresFromJSON(t, js1) + score1, ok := scores1[uid503] + require.True(t, ok, "UID 503 must appear in baseline") + + // Phase 2: add 5 fox docs → IDF decreases. + var foxTriples string + for i := 630; i < 635; i++ { + foxTriples += fmt.Sprintf(`<%d> "the fox runs quickly across the field number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(foxTriples)) + t.Cleanup(func() { + var del string + for i := 630; i < 640; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + js2 := processQueryNoErr(t, scoreQuery) + scores2 := parseScoresFromJSON(t, js2) + score2, ok := scores2[uid503] + require.True(t, ok, "UID 503 must appear after adding fox docs") + require.Greater(t, score1, score2, + "Adding fox docs should decrease IDF and thus score (phase1=%f, phase2=%f)", score1, score2) + + // Phase 3: add 5 non-fox docs → IDF increases relative to phase 2. + var nonFoxTriples string + for i := 635; i < 640; i++ { + nonFoxTriples += fmt.Sprintf(`<%d> "unrelated content about birds and fish number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(nonFoxTriples)) + + js3 := processQueryNoErr(t, scoreQuery) + scores3 := parseScoresFromJSON(t, js3) + score3, ok := scores3[uid503] + require.True(t, ok, "UID 503 must appear after adding non-fox docs") + require.Greater(t, score3, score2, + "Adding non-fox docs should increase IDF relative to phase2 (phase2=%f, phase3=%f)", score2, score3) +} + +func TestBM25LargeCorpus(t *testing.T) { + // Add 100 docs: 50 with "alpha", 50 with "beta". + var triples string + for i := 700; i < 750; i++ { + triples += fmt.Sprintf(`<%d> "alpha document content number %d with some padding words" . +`, i, i) + } + for i := 750; i < 800; i++ { + triples += fmt.Sprintf(`<%d> "beta document content number %d with some padding words" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 700; i < 800; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + // Count alpha docs. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Count beta docs. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "beta")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Union count: "alpha beta" should match all 100. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha beta")) { count(uid) } }`) + require.Contains(t, js, `"count":100`) + + // Pagination: first:10, offset:40 for alpha should return 10 results. + js = processQueryNoErr(t, ` + { + score as var(func: bm25(description_bm25, "alpha")) + me(func: uid(score), orderdesc: val(score), first: 10, offset: 40) { + uid + } + }`) + var resp struct { + Data struct { + Me []struct{ UID string } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.Len(t, resp.Data.Me, 10, "pagination first:10 offset:40 should return exactly 10 results") +} + +func TestBM25EdgeCaseSingleCharTerm(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<640> "x y z" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<640> * .`) + }) + + // Single-char terms may or may not be indexed depending on tokenizer. + // Just verify no panic/error. + _, err := processQuery(context.Background(), t, ` + { + me(func: bm25(description_bm25, "x")) { + uid + } + }`) + require.NoError(t, err) +} + +func TestBM25EdgeCaseLongDocument(t *testing.T) { + // Build a ~500-word document with "fox" appearing once. + words := make([]string, 500) + for i := range words { + words[i] = "padding" + } + words[250] = "fox" + longDoc := strings.Join(words, " ") + + require.NoError(t, addTriplesToCluster(fmt.Sprintf(`<645> %q .`, longDoc))) + t.Cleanup(func() { + deleteTriplesInCluster(`<645> * .`) + }) + + // Get scores for "fox" query. + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + uid503 := uidHex(t, 503) // "fox fox fox" (doclen=3) + uid645 := uidHex(t, 645) // long doc (doclen~500) + s503, ok1 := scores[uid503] + s645, ok2 := scores[uid645] + require.True(t, ok1, "UID 503 must appear in fox results") + require.True(t, ok2, "UID 645 must appear in fox results") + require.Greater(t, s503, s645, + "Short doc with high tf should score higher than long doc with low tf (503=%f, 645=%f)", s503, s645) +} + +func TestBM25EdgeCaseUnicode(t *testing.T) { + triples := ` + <650> "der schnelle braune Fuchs springt" . + <651> "le renard brun rapide saute" . + <652> "el zorro marrón rápido salta" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <650> * . + <651> * . + <652> * . + `) + }) + + uid650 := uidHex(t, 650) + uid651 := uidHex(t, 651) + uid652 := uidHex(t, 652) + + // Query German term. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`) + require.Contains(t, js, `"`+uid650+`"`) + + // Query French term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`) + require.Contains(t, js, `"`+uid651+`"`) + + // Query Spanish term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`) + require.Contains(t, js, `"`+uid652+`"`) +} + +func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<655> "the a an is are was were" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<655> * .`) + }) + + uid655 := uidHex(t, 655) + + // Query "the" — should return empty since "the" is a stopword. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`) + require.NotContains(t, js, `"`+uid655+`"`) // 655 should not appear + + // But the doc should exist via has(). + js = processQueryNoErr(t, ` + { + me(func: has(description_bm25)) @filter(uid(655)) { + uid + } + }`) + require.Contains(t, js, `"`+uid655+`"`) +} + +func TestBM25WithUidFilter(t *testing.T) { + // BM25 root with uid filter to restrict results. + query := ` + { + me(func: bm25(description_bm25, "fox")) @filter(uid(501, 503)) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + uid501 := uidHex(t, 501) + uid502 := uidHex(t, 502) + uid503 := uidHex(t, 503) + uid506 := uidHex(t, 506) + // Should contain only UIDs 501 and 503. + require.Contains(t, js, `"`+uid501+`"`) + require.Contains(t, js, `"`+uid503+`"`) + // Should NOT contain other fox docs like 502, 506. + require.NotContains(t, js, `"`+uid502+`"`) + require.NotContains(t, js, `"`+uid506+`"`) +} + +func TestBM25ScoreValuesAreValidFloats(t *testing.T) { + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + require.NotEmpty(t, scores, "should have at least one result") + + var prevScore float64 + first := true + // Iterate over results in order (they're orderdesc by score). + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + + for _, item := range resp.Data.Me { + score := item.Score + require.False(t, math.IsNaN(score), "score should not be NaN for uid %s", item.UID) + require.False(t, math.IsInf(score, 0), "score should not be Inf for uid %s", item.UID) + require.Greater(t, score, 0.0, "score should be positive for uid %s", item.UID) + + if !first { + require.GreaterOrEqual(t, prevScore, score, + "scores should be in descending order: %f >= %f", prevScore, score) + } + prevScore = score + first = false + } +} + +func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) { + t.Cleanup(func() { + deleteTriplesInCluster(`<670> * .`) + }) + + // Phase 1: add with "elephant". + require.NoError(t, addTriplesToCluster(`<670> "elephant roams the savanna" .`)) + uid670 := uidHex(t, 670) + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"`+uid670+`"`) + + // Phase 2: delete. + deleteTriplesInCluster(`<670> "elephant roams the savanna" .`) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"`+uid670+`"`) + + // Phase 3: re-add with different content. + require.NoError(t, addTriplesToCluster(`<670> "penguin waddles on the ice" .`)) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`) + require.Contains(t, js, `"`+uid670+`"`) + + // "elephant" should still not match 670. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"`+uid670+`"`) +} + +func TestBM25NonIndexedPredicateError(t *testing.T) { + // "name" predicate does not have @index(bm25). + query := ` + { + me(func: bm25(name, "alice")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25") +} + +func TestBM25ConcurrentBatchAdd(t *testing.T) { + // Add 5 batches of 4 docs each (UIDs 680-699) back-to-back. + t.Cleanup(func() { + var del string + for i := 680; i < 700; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + for batch := 0; batch < 5; batch++ { + var triples string + for j := 0; j < 4; j++ { + uid := 680 + batch*4 + j + triples += fmt.Sprintf(`<%d> "searchterm batch%d doc%d content here" . +`, uid, batch, j) + } + require.NoError(t, addTriplesToCluster(triples)) + } + + // All 20 docs should be findable. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "searchterm")) { count(uid) } }`) + require.Contains(t, js, `"count":20`) + + // Spot-check a doc from each batch. + for batch := 0; batch < 5; batch++ { + decUID := 680 + batch*4 + hexUID := uidHex(t, decUID) + term := fmt.Sprintf("batch%d", batch) + js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term)) + require.Contains(t, js, `"`+hexUID+`"`, "doc %d from batch %d should be searchable", decUID, batch) + } +} + +// parseCorpusCount returns the total number of documents with the description_bm25 predicate. +func parseCorpusCount(t *testing.T) float64 { + t.Helper() + js := processQueryNoErr(t, `{ me(func: has(description_bm25)) { count(uid) } }`) + var resp struct { + Data struct { + Me []struct { + Count int `json:"count"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me) + n := float64(resp.Data.Me[0].Count) + require.Greater(t, n, 0.0, "corpus must have documents") + return n +} + +func TestBM25ExactScoreValues(t *testing.T) { + // Exact score verification using b=0 (BM15 variant) to eliminate avgDL dependency. + // With b=0: score = idf * (k+1) * tf / (k + tf) + // This validates the core BM25 formula computes correct numerical values. + triples := ` + <850> "quasar quasar quasar" . + <851> "quasar nebula pulsar" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <850> * . + <851> * . + `) + }) + + N := parseCorpusCount(t) + + // Query "quasar" with b=0 so score depends only on tf, k, and IDF (not avgDL). + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "quasar", "1.2", "0")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + k := 1.2 + df := 2.0 // both 850 and 851 contain "quasar" + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + + // Doc 850 "quasar quasar quasar": tf=3, b=0 → score = idf * 2.2 * 3 / 4.2 + expected850 := idf * (k + 1) * 3.0 / (k + 3.0) + // Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf + expected851 := idf * (k + 1) * 1.0 / (k + 1.0) + + uid850 := uidHex(t, 850) + uid851 := uidHex(t, 851) + actual850, ok := scores[uid850] + require.True(t, ok, "UID 850 (%s) must be in results", uid850) + actual851, ok := scores[uid851] + require.True(t, ok, "UID 851 (%s) must be in results", uid851) + + require.InEpsilon(t, expected850, actual850, 1e-6, + "Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected850, actual850, N, df, idf) + require.InEpsilon(t, expected851, actual851, 1e-6, + "Doc 851 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected851, actual851, N, df, idf) + + // Verify ordering: higher tf should yield higher score. + require.Greater(t, actual850, actual851) +} + +func TestBM25BM15NoLengthNormalization(t *testing.T) { + // With b=0 (BM15 variant), document length should NOT affect the score. + // Two docs with the same term frequency but different lengths must score identically. + triples := ` + <860> "vortex" . + <861> "vortex alpha bravo charlie delta echo foxtrot golf hotel india" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <860> * . + <861> * . + `) + }) + + // Query with b=0: length normalization disabled. + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "vortex", "1.2", "0")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + uid860 := uidHex(t, 860) + uid861 := uidHex(t, 861) + score860, ok1 := scores[uid860] + score861, ok2 := scores[uid861] + require.True(t, ok1, "UID 860 must be in results") + require.True(t, ok2, "UID 861 must be in results") + + // With b=0 and same tf=1, scores must be equal regardless of document length. + require.InDelta(t, score860, score861, 1e-9, + "b=0 should disable length normalization: short doc score=%f, long doc score=%f", + score860, score861) + + // Now verify that with default b=0.75, the shorter doc scores higher. + scoreQueryDefault := ` + { + score as var(func: bm25(description_bm25, "vortex")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js = processQueryNoErr(t, scoreQueryDefault) + scoresDefault := parseScoresFromJSON(t, js) + + defScore860, ok1 := scoresDefault[uid860] + defScore861, ok2 := scoresDefault[uid861] + require.True(t, ok1, "UID 860 must be in default results") + require.True(t, ok2, "UID 861 must be in default results") + require.Greater(t, defScore860, defScore861, + "With b=0.75, shorter doc (doclen=1) should score higher than longer doc (doclen=10)") +} + +func TestBM25SingleMatchingDocument(t *testing.T) { + // Edge case: a single document matching the query term (df=1). + // IDF should be high since the term is very rare. + triples := `<865> "aardvark" .` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(`<865> * .`) + }) + + N := parseCorpusCount(t) + + // Query with b=0 for exact verification. + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "aardvark", "1.2", "0")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + require.Len(t, scores, 1, "exactly one document should match 'aardvark'") + + uid865 := uidHex(t, 865) + actual, ok := scores[uid865] + require.True(t, ok, "UID 865 (%s) must be in results", uid865) + + // With df=1, tf=1, b=0, k=1.2: + // idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5) + // score = idf * 2.2 * 1 / (1.2 + 1) = idf * 2.2 / 2.2 = idf + k := 1.2 + df := 1.0 + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + expected := idf * (k + 1) * 1.0 / (k + 1.0) // simplifies to idf + + require.InEpsilon(t, expected, actual, 1e-6, + "Single-doc score mismatch: expected %f, got %f (N=%f, idf=%f)", + expected, actual, N, idf) + require.Greater(t, actual, 0.0, "score must be positive") + require.False(t, math.IsInf(actual, 0), "score must be finite") +} diff --git a/query/query_hybrid_test.go b/query/query_hybrid_test.go new file mode 100644 index 00000000000..c75786677be --- /dev/null +++ b/query/query_hybrid_test.go @@ -0,0 +1,276 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +// hybridResult unmarshals a {uid, val(f)} result list ordered by fused score. +type hybridRow struct { + UID string `json:"uid"` + Score float64 `json:"val(f)"` +} + +func fuseRows(t *testing.T, js string) []hybridRow { + t.Helper() + var resp struct { + Data struct { + Me []hybridRow `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + return resp.Data.Me +} + +func fuseUIDSet(rows []hybridRow) map[string]bool { + set := make(map[string]bool, len(rows)) + for _, r := range rows { + set[r.UID] = true + } + return set +} + +// --- similar_to score surfacing (prerequisite for vector fusion) ------------- + +func TestSimilarToScoreVariable(t *testing.T) { + // similar_to bound to a value variable surfaces a higher-is-better similarity + // score. The query vector equals doc 503's embedding, so 503 is the closest + // (euclidean distance 0 -> score 1.0) and must rank first. + query := ` + { + s as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + me(func: uid(s), orderdesc: val(s)) { + uid + val(s) + } + }` + js := processQueryNoErr(t, query) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(s)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me) + + uid503 := uidHex(t, 503) + require.Equal(t, uid503, resp.Data.Me[0].UID, "closest vector (503) should rank first") + require.InDelta(t, 1.0, resp.Data.Me[0].Score, 1e-6, "exact match should score 1/(1+0)=1.0") + + // Scores must be in descending order. + for i := 1; i < len(resp.Data.Me); i++ { + require.GreaterOrEqual(t, resp.Data.Me[i-1].Score, resp.Data.Me[i].Score) + } +} + +// --- fuse() over BM25 channels ---------------------------------------------- + +func TestFuseRRFTwoBM25Channels(t *testing.T) { + // Fuse two BM25 channels: "fox" (matches 501,502,503,506,507) and "dog" + // (matches 501,502,504,505). The union must include dog-only docs (504,505) + // that the fox channel never returns — proving outer-join (not intersection). + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + set := fuseUIDSet(rows) + + // Union contains both fox-only (503) and dog-only (504, 505) documents. + require.True(t, set[uidHex(t, 503)], "fox-only doc 503 must be present") + require.True(t, set[uidHex(t, 504)], "dog-only doc 504 must be present (union, not intersection)") + require.True(t, set[uidHex(t, 505)], "dog-only doc 505 must be present") + + // Doc 501/502 appear in BOTH channels, so their fused RRF score should exceed a + // document that appears in only one channel at the same rank. + scores := make(map[string]float64) + for _, r := range rows { + scores[r.UID] = r.Score + } + require.Greater(t, scores[uidHex(t, 501)], 0.0) + + // Scores descending. + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +func TestFuseLinearWeights(t *testing.T) { + // Linear fusion with weights. Heavily weight the "dog" channel; a dog-only doc + // should then be present and outscore where appropriate. + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "linear", weights: "0.1,0.9", normalize: "max")) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + set := fuseUIDSet(rows) + require.True(t, set[uidHex(t, 504)], "dog-only doc must appear with linear fusion") + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +func TestFuseSingleChannelPassthrough(t *testing.T) { + // A single-channel fuse should preserve that channel's ranking. "fox fox fox" + // (503) is the top BM25 result for "fox". + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, method: "rrf")) + me(func: uid(f), orderdesc: val(f), first: 1) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.Len(t, rows, 1) + require.Equal(t, uidHex(t, 503), rows[0].UID) +} + +func TestFusePagination(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "rrf")) + me(func: uid(f), orderdesc: val(f), first: 2, offset: 1) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.Len(t, rows, 2, "first:2 offset:1 should return exactly 2 rows") +} + +// --- fuse() hybrid: BM25 + vector ------------------------------------------- + +func TestFuseHybridBM25AndVector(t *testing.T) { + // The headline use case: fuse BM25 text relevance with vector similarity. + // Doc 503 ("fox fox fox", embedding [3,0,0,0]) is rank-1 in BOTH the "fox" BM25 + // channel and the [3,0,0,0] vector channel, so it must be the top fused result. + // The vector channel (k=7) returns all docs, so dog-only docs (504,505) enter + // the union even though the BM25 "fox" channel never returns them. + query := ` + { + txt as var(func: bm25(description_bm25, "fox")) + vec as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + f as var(func: fuse(txt, vec, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + require.Equal(t, uidHex(t, 503), rows[0].UID, "503 is rank-1 in both channels -> top fused") + + set := fuseUIDSet(rows) + require.True(t, set[uidHex(t, 504)], "vector-only doc 504 must enter the union") + require.True(t, set[uidHex(t, 505)], "vector-only doc 505 must enter the union") + + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +// --- hybrid() sugar ---------------------------------------------------------- + +func TestHybridSugarEquivalentToFuse(t *testing.T) { + // hybrid() must produce the same top result as the explicit fuse() form. + hybridQ := ` + { + f as var(func: hybrid(description_bm25, "fox", description_vec, "[3.0, 0.0, 0.0, 0.0]", topk: 7, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + explicitQ := ` + { + txt as var(func: bm25(description_bm25, "fox"), first: 7) + vec as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + f as var(func: fuse(txt, vec, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + hybridRows := fuseRows(t, processQueryNoErr(t, hybridQ)) + explicitRows := fuseRows(t, processQueryNoErr(t, explicitQ)) + + require.NotEmpty(t, hybridRows) + require.Equal(t, len(explicitRows), len(hybridRows), "hybrid and fuse must return the same set") + require.Equal(t, explicitRows[0].UID, hybridRows[0].UID, "same top result") + // Fused scores should match position-for-position. + for i := range explicitRows { + require.Equal(t, explicitRows[i].UID, hybridRows[i].UID, "row %d uid mismatch", i) + require.InDelta(t, explicitRows[i].Score, hybridRows[i].Score, 1e-9, "row %d score mismatch", i) + } +} + +// --- error handling ---------------------------------------------------------- + +func TestFuseUnknownMethod(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, method: "bogus")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "method") +} + +func TestFuseWeightsCountMismatch(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "linear", weights: "0.5")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "weights") +} + +func TestFuseBadK(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, k: "-5")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "k must be") +} diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 39d72d8f5e7..890ca170480 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -214,6 +214,24 @@ type SimilarityType[T c.Float] struct { isSimilarityMetric bool } +// similarityScore converts a heap element's metric-domain value into a +// higher-is-better similarity score suitable for ranking and score fusion. +// +// - Cosine / dot product (isSimilarityMetric): the value is already a similarity +// where higher is better, so it is returned as-is. +// - Euclidean: the value is a squared L2 distance where lower is better. It is +// mapped to 1/(1+d) in (0,1], which is monotonically decreasing in distance, so +// higher is better and the result is well-behaved under linear normalization. +// +// Keeping this orientation in one place lets every caller (and hybrid-search +// fusion) treat vector scores with the same higher-is-better convention as BM25. +func (s SimilarityType[T]) similarityScore(value T) float64 { + if s.isSimilarityMetric { + return float64(value) + } + return 1.0 / (1.0 + float64(value)) +} + func GetSimType[T c.Float](indexType string, floatBits int) SimilarityType[T] { switch { case indexType == Euclidean: diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 5658800e579..e7b8561e24e 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -266,6 +266,20 @@ func (ph *persistentHNSW[T]) SearchWithOptions( maxResults int, opts index.VectorIndexOptions[T], ) ([]uint64, error) { + uids, _, err := ph.SearchWithOptionsScored(ctx, c, query, maxResults, opts) + return uids, err +} + +// SearchWithOptionsScored is SearchWithOptions that also returns a higher-is-better +// similarity score for each returned uid (positionally aligned). See +// index.ScoredSearchOptions. +func (ph *persistentHNSW[T]) SearchWithOptionsScored( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, []float64, error) { if opts.Filter == nil { opts.Filter = index.AcceptAll[T] } @@ -279,7 +293,7 @@ func (ph *persistentHNSW[T]) SearchWithOptions( var startVec []T entry, err := ph.PickStartNode(ctx, c, &startVec) if err != nil { - return nil, err + return nil, nil, err } // Upper layers use efUpper (override if provided) @@ -296,13 +310,13 @@ func (ph *persistentHNSW[T]) SearchWithOptions( layerResult, err := ph.searchPersistentLayer( c, level, entry, startVec, query, filterOut, efUpper, opts.Filter) if err != nil { - return nil, err + return nil, nil, err } layerResult.updateFinalMetrics(r) entry = layerResult.bestNeighbor().index layerResult.updateFinalPath(r) if err = ph.getVecFromUid(entry, c, &startVec); err != nil { - return nil, err + return nil, nil, err } } @@ -315,13 +329,14 @@ func (ph *persistentHNSW[T]) SearchWithOptions( layerResult, err := ph.searchPersistentLayer( c, ph.maxLevels-1, entry, startVec, query, filterOut, candidateK, opts.Filter) if err != nil { - return nil, err + return nil, nil, err } layerResult.updateFinalMetrics(r) layerResult.updateFinalPath(r) // Build final neighbor list with optional threshold, limited to maxResults. res := make([]uint64, 0, maxResults) + scores := make([]float64, 0, maxResults) for _, n := range layerResult.neighbors { if maxResults == 0 { break @@ -347,23 +362,38 @@ func (ph *persistentHNSW[T]) SearchWithOptions( } } res = append(res, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) if len(res) >= maxResults { break } } r.Metrics[searchTime] = uint64(time.Now().UnixMilli() - start) - return res, nil + return res, scores, nil } // SearchWithUidAndOptions is analogous to SearchWithUid but applies per‑call options. func (ph *persistentHNSW[T]) SearchWithUidAndOptions( - _ context.Context, + ctx context.Context, c index.CacheType, queryUid uint64, maxResults int, opts index.VectorIndexOptions[T], ) ([]uint64, error) { + uids, _, err := ph.SearchWithUidAndOptionsScored(ctx, c, queryUid, maxResults, opts) + return uids, err +} + +// SearchWithUidAndOptionsScored is SearchWithUidAndOptions that also returns a +// higher-is-better similarity score for each returned uid (positionally aligned). +// See index.ScoredSearchOptions. +func (ph *persistentHNSW[T]) SearchWithUidAndOptionsScored( + _ context.Context, + c index.CacheType, + queryUid uint64, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, []float64, error) { if opts.Filter == nil { opts.Filter = index.AcceptAll[T] } @@ -373,12 +403,12 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( var queryVec []T if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { if errors.Is(err, errFetchingPostingList) { - return []uint64{}, nil + return []uint64{}, []float64{}, nil } - return []uint64{}, err + return []uint64{}, []float64{}, err } if len(queryVec) == 0 { - return []uint64{}, nil + return []uint64{}, []float64{}, nil } filterOut := !opts.Filter(queryVec, queryVec, queryUid) candidateK := maxResults @@ -388,9 +418,10 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( lr, err := ph.searchPersistentLayer( c, ph.maxLevels-1, queryUid, queryVec, queryVec, filterOut, candidateK, opts.Filter) if err != nil { - return []uint64{}, err + return []uint64{}, []float64{}, err } res := make([]uint64, 0, maxResults) + scores := make([]float64, 0, maxResults) for _, n := range lr.neighbors { if maxResults == 0 { break @@ -413,11 +444,12 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( } } res = append(res, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) if len(res) >= maxResults { break } } - return res, nil + return res, scores, nil } // SearchWithUid searches the HNSW graph for the nearest neighbors of the query UID @@ -548,13 +580,56 @@ func (ph *persistentHNSW[T]) SearchWithPath( } layerResult.updateFinalMetrics(r) layerResult.updateFinalPath(r) - layerResult.addFinalNeighbors(r) + layerResult.addFinalNeighbors(r, ph.simType) t := time.Now().UnixMilli() elapsed := t - start r.Metrics[searchTime] = uint64(elapsed) return r, nil } +// SearchScored is Search that also returns a higher-is-better similarity score for +// each returned uid (positionally aligned with the neighbor uids). It preserves the +// exact candidate-exploration behavior of Search (unlike the options-based path), +// so scoring an otherwise plain query does not change which neighbors are returned. +func (ph *persistentHNSW[T]) SearchScored(ctx context.Context, c index.CacheType, query []T, + maxResults int, filter index.SearchFilter[T]) ([]uint64, []float64, error) { + r, err := ph.SearchWithPath(ctx, c, query, maxResults, filter) + if err != nil { + return nil, nil, err + } + return r.Neighbors, r.Distances, nil +} + +// SearchWithUidScored is SearchWithUid that also returns a higher-is-better +// similarity score for each returned uid (positionally aligned), preserving +// SearchWithUid's exact neighbor selection. +func (ph *persistentHNSW[T]) SearchWithUidScored(_ context.Context, c index.CacheType, + queryUid uint64, maxResults int, filter index.SearchFilter[T]) ([]uint64, []float64, error) { + var queryVec []T + if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { + if errors.Is(err, errFetchingPostingList) { + return []uint64{}, []float64{}, nil + } + return []uint64{}, []float64{}, err + } + if len(queryVec) == 0 { + return []uint64{}, []float64{}, nil + } + shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) + r, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, shouldFilterOutQueryVec, maxResults, filter) + if err != nil { + return []uint64{}, []float64{}, err + } + uids := make([]uint64, 0, len(r.neighbors)) + scores := make([]float64, 0, len(r.neighbors)) + for _, n := range r.neighbors { + uids = append(uids, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) + } + return uids, scores, nil +} + // InsertToPersistentStorage inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, diff --git a/tok/hnsw/search_layer.go b/tok/hnsw/search_layer.go index 55140e7319a..81c529bd29b 100644 --- a/tok/hnsw/search_layer.go +++ b/tok/hnsw/search_layer.go @@ -113,10 +113,13 @@ func (slr *searchLayerResult[T]) updateFinalPath(r *index.SearchPathResult) { r.Path = append(r.Path, slr.path...) } -func (slr *searchLayerResult[T]) addFinalNeighbors(r *index.SearchPathResult) { +func (slr *searchLayerResult[T]) addFinalNeighbors(r *index.SearchPathResult, simType SimilarityType[T]) { for _, n := range slr.neighbors { if !n.filteredOut { r.Neighbors = append(r.Neighbors, n.index) + // Distances carries the higher-is-better similarity for each neighbor, + // positionally aligned with Neighbors, so scored searches can surface it. + r.Distances = append(r.Distances, simType.similarityScore(n.value)) } } } diff --git a/tok/hnsw/similarity_score_test.go b/tok/hnsw/similarity_score_test.go new file mode 100644 index 00000000000..481c97371b7 --- /dev/null +++ b/tok/hnsw/similarity_score_test.go @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSimilarityScoreOrientation verifies that every metric surfaces a +// higher-is-better similarity score, which native hybrid-search fusion relies on. +func TestSimilarityScoreOrientation(t *testing.T) { + cosine := GetSimType[float32](Cosine, 32) + dot := GetSimType[float32](DotProd, 32) + euclid := GetSimType[float32](Euclidean, 32) + + // Cosine / dot product: returned as-is (already higher-is-better). + require.InDelta(t, 0.9, cosine.similarityScore(0.9), 1e-6) + require.InDelta(t, -0.2, cosine.similarityScore(-0.2), 1e-6) + require.InDelta(t, 12.5, dot.similarityScore(12.5), 1e-6) + + // Euclidean: squared distance mapped to 1/(1+d), so a smaller distance yields a + // larger score (closer = better). + near := euclid.similarityScore(0.0) // distance 0 -> perfect match + mid := euclid.similarityScore(1.0) + far := euclid.similarityScore(9.0) + require.InDelta(t, 1.0, near, 1e-6) + require.InDelta(t, 0.5, mid, 1e-6) + require.InDelta(t, 0.1, far, 1e-6) + require.Greater(t, near, mid) + require.Greater(t, mid, far) +} diff --git a/tok/index/index.go b/tok/index/index.go index 1e981ef189e..050ec85e97e 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -147,6 +147,29 @@ type OptionalSearchOptions[T c.Float] interface { maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) } +// ScoredSearchOptions extends search to also return a higher-is-better similarity +// score for each returned uid (positionally aligned). These power native hybrid +// search: the score is bound to a DQL value variable so vector results can be a +// fusion channel alongside BM25. Scores carry the same higher-is-better convention +// as BM25 (cosine/dot as-is; euclidean as 1/(1+dist)). +// +// SearchScored / SearchWithUidScored preserve the exact neighbor selection of +// Search / SearchWithUid (so scoring a plain query does not change its results), +// while the *Options* variants apply per-call ef/distance-threshold controls. +type ScoredSearchOptions[T c.Float] interface { + SearchScored(ctx context.Context, c CacheType, query []T, + maxResults int, filter SearchFilter[T]) ([]uint64, []float64, error) + + SearchWithUidScored(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, filter SearchFilter[T]) ([]uint64, []float64, error) + + SearchWithOptionsScored(ctx context.Context, c CacheType, query []T, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, []float64, error) + + SearchWithUidAndOptionsScored(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, []float64, error) +} + // A Txn is an interface representation of a persistent storage transaction, // where multiple operations are performed on a database type Txn interface { diff --git a/tok/index/search_path.go b/tok/index/search_path.go index 7e24b7d068c..5387b71ee7f 100644 --- a/tok/index/search_path.go +++ b/tok/index/search_path.go @@ -12,6 +12,10 @@ type SearchPathResult struct { // The collection of nearest neighbors in sorted order after filtering // out neighbors that fail any Filter criteria. Neighbors []uint64 + // Distances holds the higher-is-better similarity score for each entry in + // Neighbors (positionally aligned). It is populated by scored searches and may + // be empty for callers that only need the neighbor uids. + Distances []float64 // The path from the start of search to the closest neighbor vector. Path []uint64 // A collection of captured named counters that occurred for the @@ -24,6 +28,7 @@ type SearchPathResult struct { func NewSearchPathResult() *SearchPathResult { return &SearchPathResult{ Neighbors: []uint64{}, + Distances: []float64{}, Path: []uint64{}, Metrics: make(map[string]uint64), } diff --git a/tok/tok.go b/tok/tok.go index c1da3e991d7..cb50b0a369e 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -50,6 +50,7 @@ const ( IdentBigFloat = 0xD IdentVFloat = 0xE IdentNGram = 0xF + IdentBM25 = 0x10 IdentCustom = 0x80 IdentDelimiter = 0x1f // ASCII 31 - Unit separator ) @@ -101,6 +102,7 @@ func init() { registerTokenizer(TermTokenizer{}) registerTokenizer(FullTextTokenizer{}) registerTokenizer(NGramTokenizer{}) + registerTokenizer(BM25Tokenizer{}) registerTokenizer(Sha256Tokenizer{}) setupBleve() } @@ -576,6 +578,47 @@ func (t FullTextTokenizer) Identifier() byte { return IdentFullText } func (t FullTextTokenizer) IsSortable() bool { return false } func (t FullTextTokenizer) IsLossy() bool { return true } +// BM25Tokenizer generates tokens for BM25 ranked text search. +// It uses the same pipeline as FullTextTokenizer (normalize, stopwords, stem) +// but preserves duplicates for term frequency counting. +type BM25Tokenizer struct{ lang string } + +func (t BM25Tokenizer) Name() string { return "bm25" } +func (t BM25Tokenizer) Type() string { return "string" } +func (t BM25Tokenizer) Tokens(v interface{}) ([]string, error) { + str, ok := v.(string) + if !ok || str == "" { + return []string{}, nil + } + lang := LangBase(t.lang) + tokens := fulltextAnalyzer.Analyze([]byte(str)) + tokens = filterStopwords(lang, tokens) + tokens = filterStemmers(lang, tokens) + // Return all tokens with duplicates preserved (for TF counting). + result := make([]string, 0, len(tokens)) + for _, t := range tokens { + result = append(result, string(t.Term)) + } + return result, nil +} +func (t BM25Tokenizer) Identifier() byte { return IdentBM25 } +func (t BM25Tokenizer) IsSortable() bool { return false } +func (t BM25Tokenizer) IsLossy() bool { return true } + +// TokensWithFrequency tokenizes the input and returns term frequencies and doc length. +func (t BM25Tokenizer) TokensWithFrequency(v interface{}, lang string) (map[string]uint32, uint32, error) { + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(v) + if err != nil { + return nil, 0, err + } + termFreqs := make(map[string]uint32, len(allTokens)) + for _, t := range allTokens { + termFreqs[t]++ + } + return termFreqs, uint32(len(allTokens)), nil +} + // Sha256Tokenizer generates tokens for the sha256 hash part from string data. type Sha256Tokenizer struct{ _ string } diff --git a/tok/tok_test.go b/tok/tok_test.go index 4c95094e577..b9fbc4dd1a5 100644 --- a/tok/tok_test.go +++ b/tok/tok_test.go @@ -652,6 +652,146 @@ func TestNGramTokenizerNonStringInput(t *testing.T) { require.Equal(t, 0, len(tokens2), "Expected empty tokens for nil input") } +func TestBM25Tokenizer(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + require.NotNil(t, tokenizer) + require.Equal(t, "bm25", tokenizer.Name()) + require.Equal(t, "string", tokenizer.Type()) + require.Equal(t, byte(IdentBM25), tokenizer.Identifier()) + require.True(t, tokenizer.IsLossy()) + require.False(t, tokenizer.IsSortable()) +} + +func TestBM25TokensPreservesDuplicates(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("fox fox fox dog") + require.NoError(t, err) + // "fox" should appear 3 times (duplicates preserved), "dog" once + foxCount := 0 + dogCount := 0 + for _, token := range tokens { + if token == "fox" { + foxCount++ + } + if token == "dog" { + dogCount++ + } + } + require.Equal(t, 3, foxCount, "Expected 3 occurrences of 'fox'") + require.Equal(t, 1, dogCount, "Expected 1 occurrence of 'dog'") +} + +func TestBM25TokensWithFrequency(t *testing.T) { + tok := BM25Tokenizer{} + termFreqs, docLen, err := tok.TokensWithFrequency("the quick brown fox fox fox", "en") + require.NoError(t, err) + // "the" is a stopword and should be removed + _, hasThe := termFreqs["the"] + require.False(t, hasThe, "'the' should be removed as stopword") + // "fox" should have tf=3 + require.Equal(t, uint32(3), termFreqs["fox"]) + // "quick" -> "quick" (stemmed) + require.Contains(t, termFreqs, "quick") + require.Equal(t, uint32(1), termFreqs["quick"]) + // "brown" -> "brown" (stemmed) + require.Contains(t, termFreqs, "brown") + require.Equal(t, uint32(1), termFreqs["brown"]) + // docLen should be total tokens after stopword removal + require.Equal(t, uint32(5), docLen) +} + +func TestBM25TokensEmpty(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestBM25TokensSingleWord(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("hello") + require.NoError(t, err) + require.Equal(t, 1, len(tokens)) + require.Equal(t, "hello", tokens[0]) +} + +func TestBM25TokensStemming(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("running jumping swimming") + require.NoError(t, err) + require.Equal(t, 3, len(tokens)) + require.Contains(t, tokens, "run") + require.Contains(t, tokens, "jump") + require.Contains(t, tokens, "swim") +} + +func TestGetBM25QueryTokens(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"quick brown fox fox"}, "en") + require.NoError(t, err) + // Query tokens should be deduplicated + require.Equal(t, 3, len(tokens)) + // Each token should be encoded with the BM25 identifier prefix + for _, token := range tokens { + require.Equal(t, byte(IdentBM25), token[0], "Token should start with BM25 identifier") + } +} + +func TestGetBM25QueryTokensEmpty(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{""}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestBM25TokenizerForLang(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + langTok := GetTokenizerForLang(tokenizer, "de") + bm25Tok, ok := langTok.(BM25Tokenizer) + require.True(t, ok) + // German: "Katzen" -> "katz" (stemmed) + tokens, err := bm25Tok.Tokens("Katzen und Katzen") + require.NoError(t, err) + // "und" is a German stopword + katzCount := 0 + for _, token := range tokens { + if token == "katz" { + katzCount++ + } + } + require.Equal(t, 2, katzCount, "Expected 2 occurrences of stemmed 'katz'") +} + +func TestBM25AllStopwords(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("the a an is") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("the a an is", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestGetBM25QueryTokensAllStopwords(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"the a an"}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestGetBM25QueryTokensWrongArgCount(t *testing.T) { + _, err := GetBM25QueryTokens([]string{}, "en") + require.Error(t, err) + _, err = GetBM25QueryTokens([]string{"a", "b"}, "en") + require.Error(t, err) +} + func BenchmarkTermTokenizer(b *testing.B) { b.Skip() // tmp } diff --git a/tok/tokens.go b/tok/tokens.go index bda9a04e743..f089a3f4344 100644 --- a/tok/tokens.go +++ b/tok/tokens.go @@ -25,6 +25,8 @@ func GetTokenizerForLang(t Tokenizer, lang string) Tokenizer { // We must return a new instance because another goroutine might be calling this // with a different lang. return FullTextTokenizer{lang: lang} + case BM25Tokenizer: + return BM25Tokenizer{lang: lang} case TermTokenizer: return TermTokenizer{lang: lang} case ExactTokenizer: @@ -67,6 +69,29 @@ func GetNGramQueryTokens(funcArgs []string, lang string) ([]string, error) { return BuildNGramQueryTokens(funcArgs[0], NGramTokenizer{lang: lang}) } +// GetBM25QueryTokens tokenizes the query text using the fulltext pipeline, +// deduplicates, and encodes with the BM25 identifier prefix. +func GetBM25QueryTokens(funcArgs []string, lang string) ([]string, error) { + if l := len(funcArgs); l != 1 { + return nil, errors.Errorf("Function requires 1 arguments, but got %d", l) + } + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(funcArgs[0]) + if err != nil { + return nil, err + } + // Deduplicate for query + seen := make(map[string]struct{}, len(allTokens)) + var unique []string + for _, t := range allTokens { + if _, ok := seen[t]; !ok { + seen[t] = struct{}{} + unique = append(unique, encodeToken(t, tok.Identifier())) + } + } + return unique, nil +} + // GetFullTextTokens returns the full-text tokens for the given value. func GetFullTextTokens(funcArgs []string, lang string) ([]string, error) { if l := len(funcArgs); l != 1 { diff --git a/worker/bm25wand.go b/worker/bm25wand.go new file mode 100644 index 00000000000..c950ffbe3e8 --- /dev/null +++ b/worker/bm25wand.go @@ -0,0 +1,406 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "sort" + + "github.com/dgraph-io/dgraph/v25/posting" +) + +// wandBlockSize is the number of postings grouped into one logical block for +// Block-Max WAND upper bounds. The postings come from a standard Dgraph posting +// list (already resident in memory once loaded); these blocks exist only to give +// WAND per-block score bounds for pruning — they are not a storage format. +const wandBlockSize = 128 + +// termCursor is an in-memory cursor over one query term's posting list, +// materialized from the standard posting List as UID-ascending (uid, tf, docLen) +// entries. Document length travels with each posting, so scoring needs no separate +// lookup. Per-block max bounds drive Block-Max WAND pruning. +type termCursor struct { + postings []posting.BM25Posting + idf float64 + pos int + + // blockUBPre[i] is the pre-IDF BM25 upper bound for block i (max term + // frequency, min document length in the block). suffixUBPre[i] = max over + // j >= i of blockUBPre[j], for the remaining-list upper bound. + blockUBPre []float64 + suffixUBPre []float64 +} + +// ubPre computes the pre-IDF BM25 contribution upper bound for a block, using the +// block's maximum term frequency and minimum document length (the score is +// increasing in tf and decreasing in dl, so this is a safe upper bound). +func ubPre(maxTF, minDL uint32, k, b, avgDL float64) float64 { + if avgDL <= 0 { + avgDL = 1 + } + tf := float64(maxTF) + dl := float64(minDL) + denom := k*(1-b+b*dl/avgDL) + tf + if denom <= 0 { + return 0 + } + return (k + 1) * tf / denom +} + +// newTermCursor builds a cursor and precomputes its per-block upper bounds. +func newTermCursor(postings []posting.BM25Posting, idf, k, b, avgDL float64) *termCursor { + c := &termCursor{postings: postings, idf: idf} + numBlocks := (len(postings) + wandBlockSize - 1) / wandBlockSize + c.blockUBPre = make([]float64, numBlocks) + for blk := 0; blk < numBlocks; blk++ { + start := blk * wandBlockSize + end := start + wandBlockSize + if end > len(postings) { + end = len(postings) + } + var maxTF uint32 + minDL := uint32(math.MaxUint32) + for i := start; i < end; i++ { + if postings[i].TF > maxTF { + maxTF = postings[i].TF + } + dl := postings[i].DocLen + if dl == 0 { + dl = 1 + } + if dl < minDL { + minDL = dl + } + } + c.blockUBPre[blk] = ubPre(maxTF, minDL, k, b, avgDL) + } + c.suffixUBPre = make([]float64, numBlocks) + var running float64 + for blk := numBlocks - 1; blk >= 0; blk-- { + if c.blockUBPre[blk] > running { + running = c.blockUBPre[blk] + } + c.suffixUBPre[blk] = running + } + return c +} + +func (c *termCursor) exhausted() bool { return c.pos >= len(c.postings) } + +func (c *termCursor) currentDoc() uint64 { + if c.exhausted() { + return math.MaxUint64 + } + return c.postings[c.pos].Uid +} + +func (c *termCursor) currentTF() uint32 { + if c.exhausted() { + return 0 + } + return c.postings[c.pos].TF +} + +func (c *termCursor) currentDocLen() uint32 { + if c.exhausted() { + return 0 + } + return c.postings[c.pos].DocLen +} + +// remainingUB returns the IDF-weighted upper-bound score over the remainder of the +// list from the current position. +func (c *termCursor) remainingUB() float64 { + if c.exhausted() || len(c.suffixUBPre) == 0 { + return 0 + } + blk := c.pos / wandBlockSize + if blk >= len(c.suffixUBPre) { + return 0 + } + return c.idf * c.suffixUBPre[blk] +} + +// next advances by one posting. +func (c *termCursor) next() bool { + c.pos++ + return !c.exhausted() +} + +// skipTo advances to the first posting with UID >= target. +func (c *termCursor) skipTo(target uint64) bool { + if c.exhausted() { + return false + } + if c.postings[c.pos].Uid >= target { + return true + } + rel := sort.Search(len(c.postings)-c.pos, func(i int) bool { + return c.postings[c.pos+i].Uid >= target + }) + c.pos += rel + return !c.exhausted() +} + +// skipToWithBMW is skipTo with Block-Max WAND pruning: blocks whose upper bound +// combined with otherUB cannot beat theta are skipped wholesale. +func (c *termCursor) skipToWithBMW(target uint64, theta, otherUB float64) bool { + if !c.skipTo(target) { + return false + } + for !c.exhausted() { + blk := c.pos / wandBlockSize + if c.idf*c.blockUBPre[blk]+otherUB > theta { + return true + } + // This block can't produce a winner; jump to the start of the next block. + c.pos = (blk + 1) * wandBlockSize + } + return false +} + +// scoredDoc holds a UID and its BM25 score for the min-heap. +type scoredDoc struct { + uid uint64 + score float64 +} + +// topKHeap is a min-heap of scored documents for top-k tracking. +type topKHeap struct { + docs []scoredDoc + k int +} + +func (h *topKHeap) Len() int { return len(h.docs) } +func (h *topKHeap) Less(i, j int) bool { return h.docs[i].score < h.docs[j].score } +func (h *topKHeap) Swap(i, j int) { h.docs[i], h.docs[j] = h.docs[j], h.docs[i] } +func (h *topKHeap) Push(x interface{}) { h.docs = append(h.docs, x.(scoredDoc)) } +func (h *topKHeap) Pop() interface{} { + old := h.docs + n := len(old) + item := old[n-1] + h.docs = old[:n-1] + return item +} + +// threshold returns the minimum score in the heap (the score to beat). +func (h *topKHeap) threshold() float64 { + if len(h.docs) < h.k { + return 0 + } + return h.docs[0].score +} + +// tryPush adds a doc if it beats the current threshold. +func (h *topKHeap) tryPush(uid uint64, score float64) { + if len(h.docs) < h.k { + heap.Push(h, scoredDoc{uid: uid, score: score}) + return + } + if score > h.docs[0].score { + h.docs[0] = scoredDoc{uid: uid, score: score} + heap.Fix(h, 0) + } +} + +// sorted returns all docs sorted by score descending, then UID ascending. +func (h *topKHeap) sorted() []scoredDoc { + result := make([]scoredDoc, len(h.docs)) + copy(result, h.docs) + sort.Slice(result, func(i, j int) bool { + if result[i].score != result[j].score { + return result[i].score > result[j].score + } + return result[i].uid < result[j].uid + }) + return result +} + +// bm25Score computes the BM25 contribution of a single term occurrence. +func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { + if avgDL <= 0 { + avgDL = 1 + } + if dl <= 0 { + dl = 1 + } + return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) +} + +// wandSearch performs a WAND / Block-Max WAND top-k BM25 search over standard +// posting lists. queryTokens must already carry the BM25 tokenizer identifier +// byte. getList reads a posting list for a key. If topK <= 0, every matching +// document is scored (no early termination). +func wandSearch(getList func(key []byte) (*posting.List, error), attr string, readTs uint64, + queryTokens []string, k, b, avgDL, N float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) ([]scoredDoc, error) { + + var cursors []*termCursor + for _, token := range queryTokens { + postings, err := posting.ReadBM25TermPostings(getList, attr, token, readTs) + if err != nil { + return nil, err + } + df := uint64(len(postings)) + if df == 0 { + continue + } + // N comes from bucketed stats and df from the term's posting list; if stats + // ever lag the postings, clamp N >= df for this term so the smoothed IDF + // stays non-negative and finite instead of producing a negative/NaN score. + dfN := float64(df) + nDocs := N + if nDocs < dfN { + nDocs = dfN + } + idf := math.Log1p((nDocs - dfN + 0.5) / (dfN + 0.5)) + cursors = append(cursors, newTermCursor(postings, idf, k, b, avgDL)) + } + + if len(cursors) == 0 { + return nil, nil + } + + if topK <= 0 { + return scoreAllDocs(cursors, k, b, avgDL, filterSet), nil + } + return wandTopK(cursors, k, b, avgDL, topK, filterSet, useBMW), nil +} + +// wandTopK runs the WAND / Block-Max WAND main loop over prepared cursors and +// returns the top-k documents sorted by score descending. It is the core scoring +// loop, separated from posting-list I/O so it can be exercised directly. +func wandTopK(cursors []*termCursor, k, b, avgDL float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) []scoredDoc { + + h := &topKHeap{k: topK} + heap.Init(h) + + for { + // Drop exhausted cursors. + active := cursors[:0] + for _, c := range cursors { + if !c.exhausted() { + active = append(active, c) + } + } + cursors = active + if len(cursors) == 0 { + break + } + + // Sort cursors by current document ascending. + sort.Slice(cursors, func(i, j int) bool { + return cursors[i].currentDoc() < cursors[j].currentDoc() + }) + + theta := h.threshold() + + // Find pivot: accumulate upper bounds until they exceed theta. + var sumUB float64 + pivot := -1 + var pivotDoc uint64 + for i, c := range cursors { + sumUB += c.remainingUB() + if sumUB > theta && pivot == -1 { + pivot = i + pivotDoc = c.currentDoc() + } + } + if pivot == -1 { + break // sum of all upper bounds can't beat theta + } + + // Advance all cursors before the pivot up to pivotDoc. + allAtPivot := true + for i := 0; i < pivot; i++ { + if cursors[i].currentDoc() < pivotDoc { + var ok bool + if useBMW { + otherUB := sumUB - cursors[i].remainingUB() + ok = cursors[i].skipToWithBMW(pivotDoc, theta, otherUB) + } else { + ok = cursors[i].skipTo(pivotDoc) + } + if !ok { + allAtPivot = false + break + } + if cursors[i].currentDoc() != pivotDoc { + allAtPivot = false + } + } + } + if !allAtPivot { + continue + } + + // Score the pivot document. + if filterSet != nil { + if _, ok := filterSet[pivotDoc]; !ok { + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() + } + } + continue + } + } + + var score float64 + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + dl := float64(c.currentDocLen()) + score += bm25Score(c.idf, float64(c.currentTF()), dl, avgDL, k, b) + } + } + h.tryPush(pivotDoc, score) + + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() + } + } + } + + return h.sorted() +} + +// scoreAllDocs scores every matching document without early termination. Used when +// no top-k limit is specified. +func scoreAllDocs(cursors []*termCursor, k, b, avgDL float64, + filterSet map[uint64]struct{}) []scoredDoc { + + scores := make(map[uint64]float64) + + for _, c := range cursors { + for !c.exhausted() { + uid := c.currentDoc() + if filterSet != nil { + if _, ok := filterSet[uid]; !ok { + c.next() + continue + } + } + scores[uid] += bm25Score(c.idf, float64(c.currentTF()), float64(c.currentDocLen()), + avgDL, k, b) + c.next() + } + } + + results := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + results = append(results, scoredDoc{uid: uid, score: s}) + } + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + return results +} diff --git a/worker/bm25wand_test.go b/worker/bm25wand_test.go new file mode 100644 index 00000000000..8f133a6ec30 --- /dev/null +++ b/worker/bm25wand_test.go @@ -0,0 +1,200 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "math/rand" + "sort" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting" +) + +func TestTopKHeapBasic(t *testing.T) { + h := &topKHeap{k: 3} + heap.Init(h) + + require.Equal(t, 0.0, h.threshold()) + + h.tryPush(1, 5.0) + h.tryPush(2, 3.0) + require.Equal(t, 0.0, h.threshold()) // not full yet + + h.tryPush(3, 7.0) + require.InEpsilon(t, 3.0, h.threshold(), 1e-9) // full, min is 3.0 + + h.tryPush(4, 4.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) // 3.0 evicted, min is now 4.0 + + // 2.0 shouldn't be accepted. + h.tryPush(5, 2.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) + + sorted := h.sorted() + require.Len(t, sorted, 3) + require.Equal(t, uint64(3), sorted[0].uid) // highest score (7.0) + require.Equal(t, uint64(1), sorted[1].uid) // 5.0 + require.Equal(t, uint64(4), sorted[2].uid) // 4.0 +} + +func TestTopKHeapTieBreaking(t *testing.T) { + h := &topKHeap{k: 5} + heap.Init(h) + + // Same score, different UIDs — should sort by UID ascending. + h.tryPush(10, 5.0) + h.tryPush(5, 5.0) + h.tryPush(15, 5.0) + + sorted := h.sorted() + require.Equal(t, uint64(5), sorted[0].uid) + require.Equal(t, uint64(10), sorted[1].uid) + require.Equal(t, uint64(15), sorted[2].uid) +} + +func TestBm25ScoreFunction(t *testing.T) { + k, b := 1.2, 0.75 + avgDL := 10.0 + + // idf * (k+1) * tf / (k*(1-b+b*dl/avgDL) + tf) + idf := 1.5 + tf := 3.0 + dl := 10.0 + + expected := idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) + got := bm25Score(idf, tf, dl, avgDL, k, b) + require.InEpsilon(t, expected, got, 1e-9) + + // With b=0: no length normalization. + expected0 := idf * (k + 1) * tf / (k + tf) + got0 := bm25Score(idf, tf, dl, avgDL, k, 0) + require.InEpsilon(t, expected0, got0, 1e-9) + + // Score should be positive for positive inputs. + require.Greater(t, bm25Score(1.0, 1.0, 5.0, 10.0, k, b), 0.0) + + // Higher tf should produce higher score (same dl). + s1 := bm25Score(idf, 1.0, dl, avgDL, k, b) + s3 := bm25Score(idf, 3.0, dl, avgDL, k, b) + require.Greater(t, s3, s1) + + // Shorter doc should score higher (same tf). + sShort := bm25Score(idf, tf, 5.0, avgDL, k, b) + sLong := bm25Score(idf, tf, 20.0, avgDL, k, b) + require.Greater(t, sShort, sLong) +} + +func TestBm25ScoreNaN(t *testing.T) { + // Ensure no NaN/Inf for edge-case inputs. + score := bm25Score(0.5, 1.0, 0.0, 10.0, 1.2, 0.75) + require.False(t, math.IsNaN(score)) + require.False(t, math.IsInf(score, 0)) + require.Greater(t, score, 0.0) +} + +// brute force scores every doc across all cursors (ground truth for WAND). +func bruteForceTopK(termPostings [][]posting.BM25Posting, idfs []float64, + k, b, avgDL float64, topK int) []scoredDoc { + scores := map[uint64]float64{} + dls := map[uint64]uint32{} + for ti, ps := range termPostings { + for _, p := range ps { + scores[p.Uid] += bm25Score(idfs[ti], float64(p.TF), float64(p.DocLen), avgDL, k, b) + dls[p.Uid] = p.DocLen + } + } + out := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + out = append(out, scoredDoc{uid: uid, score: s}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].score != out[j].score { + return out[i].score > out[j].score + } + return out[i].uid < out[j].uid + }) + if topK > 0 && len(out) > topK { + out = out[:topK] + } + return out +} + +// TestWandMatchesBruteForce checks that WAND and Block-Max WAND return exactly the +// same top-k documents and scores as exhaustive scoring, across many randomized +// posting lists. This is the core correctness guarantee: pruning must never change +// the result, only the work done. +func TestWandMatchesBruteForce(t *testing.T) { + rng := rand.New(rand.NewSource(42)) + k, b, avgDL := 1.2, 0.75, 12.0 + + for trial := 0; trial < 200; trial++ { + numTerms := 1 + rng.Intn(4) + termPostings := make([][]posting.BM25Posting, numTerms) + idfs := make([]float64, numTerms) + for ti := 0; ti < numTerms; ti++ { + n := rng.Intn(400) // spans multiple wandBlockSize blocks + seen := map[uint64]bool{} + var ps []posting.BM25Posting + for j := 0; j < n; j++ { + uid := uint64(1 + rng.Intn(500)) + if seen[uid] { + continue + } + seen[uid] = true + ps = append(ps, posting.BM25Posting{ + Uid: uid, + TF: uint32(1 + rng.Intn(10)), + DocLen: uint32(1 + rng.Intn(30)), + }) + } + sort.Slice(ps, func(i, j int) bool { return ps[i].Uid < ps[j].Uid }) + termPostings[ti] = ps + // Vary IDF per term so different terms carry different weight. + idfs[ti] = 0.5 + rng.Float64()*2 + } + + topK := 1 + rng.Intn(10) + want := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK) + // One extra result lets us detect a tie between the cutoff rank and the + // first excluded document (a boundary tie outside the top-k window). + wantPlus := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK+1) + + build := func() []*termCursor { + cs := make([]*termCursor, 0, numTerms) + for ti, ps := range termPostings { + if len(ps) == 0 { + continue + } + cs = append(cs, newTermCursor(ps, idfs[ti], k, b, avgDL)) + } + return cs + } + + for _, useBMW := range []bool{false, true} { + got := wandTopK(build(), k, b, avgDL, topK, nil, useBMW) + require.Lenf(t, got, len(want), "trial %d bmw=%v len", trial, useBMW) + for i := range want { + // The score at each rank must match exactly: WAND/BMW pruning must + // never change which scores make the top-k, only the work done. + require.InEpsilonf(t, want[i].score, got[i].score, 1e-9, + "trial %d bmw=%v rank %d score", trial, useBMW, i) + // The uid is only guaranteed when this rank's score is not tied with + // a neighbor (including the first excluded doc); tied-boundary docs + // are interchangeable in the ranking. + tied := (i > 0 && wantPlus[i].score == wantPlus[i-1].score) || + (i+1 < len(wantPlus) && wantPlus[i].score == wantPlus[i+1].score) + if !tied { + require.Equalf(t, want[i].uid, got[i].uid, + "trial %d bmw=%v rank %d uid", trial, useBMW, i) + } + } + } + } +} diff --git a/worker/mutation.go b/worker/mutation.go index fdac2a41c1b..076beed185f 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -410,6 +410,19 @@ func checkSchema(s *pb.SchemaUpdate) error { x.ParseAttr(s.Predicate)) } + // BM25 scores a single document (one value) per UID: per-document length and + // corpus statistics are not well-defined for a list predicate, and the bucketed + // stats maintenance relies on conflict detection that a list predicate's + // value-dependent conflict key would not provide. Reject @index(bm25) on lists. + if s.List { + for _, tokenizer := range s.Tokenizer { + if tokenizer == "bm25" { + return errors.Errorf("Tokenizer 'bm25' cannot be applied to list predicate: %s", + x.ParseAttr(s.Predicate)) + } + } + } + // If schema update has upsert directive, it should have index directive. if s.Upsert && len(s.Tokenizer) == 0 && !s.Unique { return errors.Errorf("Index tokenizer is mandatory for: [%s] when specifying @upsert directive", diff --git a/worker/mutation_integration_test.go b/worker/mutation_integration_test.go index 99a2a1eed01..f1f4b81695b 100644 --- a/worker/mutation_integration_test.go +++ b/worker/mutation_integration_test.go @@ -93,6 +93,25 @@ func TestCheckSchema(t *testing.T) { } require.NoError(t, checkSchema(s1)) + // bm25 on a scalar string predicate is allowed. + s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("bio"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + } + require.NoError(t, checkSchema(s1)) + + // bm25 on a list predicate is rejected. + s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("tags"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + List: true, + } + require.Error(t, checkSchema(s1)) + s1 = &pb.SchemaUpdate{ Predicate: x.AttrInRootNamespace("friend"), ValueType: pb.Posting_UID, diff --git a/worker/task.go b/worker/task.go index 409ec3f0fc4..24b5f4dbe17 100644 --- a/worker/task.go +++ b/worker/task.go @@ -7,6 +7,7 @@ package worker import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -224,6 +225,7 @@ const ( customIndexFn matchFn similarToFn + bm25SearchFn standardFn = 100 ) @@ -266,6 +268,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) { return uidInFn, f case "similar_to": return similarToFn, f + case "bm25": + return bm25SearchFn, f case "anyof", "allof": return customIndexFn, f case "match": @@ -292,6 +296,8 @@ func needsIndex(fnType FuncType, uidList *pb.List) bool { return true case similarToFn: return true + case bm25SearchFn: + return true } return false } @@ -314,7 +320,7 @@ type funcArgs struct { // The function tells us whether we want to fetch value posting lists or uid posting lists. func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn, similarToFn: + case aggregatorFn, passwordFn, similarToFn, bm25SearchFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -351,11 +357,15 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er attribute.String("srcFn", x.SafeUTF8(fmt.Sprintf("%+v", args.srcFn))))) switch srcFn.fnType { - case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn: + case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn, bm25SearchFn: default: return errors.Errorf("Unhandled function in handleValuePostings: %s", srcFn.fname) } + if srcFn.fnType == bm25SearchFn { + return qs.handleBM25Search(ctx, args) + } + if srcFn.fnType == similarToFn { numNeighbors, err := strconv.ParseInt(q.SrcFunc.Args[0], 10, 32) if err != nil { @@ -375,6 +385,7 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er return err } var nnUids []uint64 + var nnScores []float64 // Build optional search options if provided filter := index.AcceptAll[float32] opts := index.VectorIndexOptions[float32]{Filter: filter} @@ -385,27 +396,63 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er opts.DistanceThreshold = srcFn.vsDistanceThreshold } hasOptions := opts.EfOverride > 0 || opts.DistanceThreshold != nil - if o, ok := indexer.(index.OptionalSearchOptions[float32]); ok && hasOptions { - if srcFn.vectorInfo != nil { - nnUids, err = o.SearchWithOptions(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) - } else { - nnUids, err = o.SearchWithUidAndOptions(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + // Use the scored search path so the per-uid similarity score can be bound to a + // value variable (powering native hybrid search / fuse()). The scored variants + // mirror their unscored counterparts exactly — SearchScored/SearchWithUidScored + // preserve the plain-query neighbor selection, and the *Options* variants are + // used only when the query supplies ef/distance-threshold — so adding scoring + // does not change which neighbors existing queries return. Indexes that don't + // implement scoring fall back to the unscored path (no scores surfaced). + if so, ok := indexer.(index.ScoredSearchOptions[float32]); ok { + switch { + case hasOptions && srcFn.vectorInfo != nil: + nnUids, nnScores, err = so.SearchWithOptionsScored(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) + case hasOptions: + nnUids, nnScores, err = so.SearchWithUidAndOptionsScored(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + case srcFn.vectorInfo != nil: + nnUids, nnScores, err = so.SearchScored(ctx, qc, srcFn.vectorInfo, int(numNeighbors), index.AcceptAll[float32]) + default: + nnUids, nnScores, err = so.SearchWithUidScored(ctx, qc, srcFn.vectorUid, int(numNeighbors), index.AcceptAll[float32]) } + } else if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) } else { - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) - } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) - } + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) } if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { return err } - sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: nnUids}) + + // Emit uids ascending (required by the query pipeline) with positionally + // aligned similarity scores in ValueMatrix. The query layer binds these to a + // value variable so callers can order by and project the score via val(), and + // so vector results can serve as a fuse() channel. Scores are higher-is-better. + order := make([]int, len(nnUids)) + for i := range order { + order[i] = i + } + sort.Slice(order, func(i, j int) bool { return nnUids[order[i]] < nnUids[order[j]] }) + sortedUids := make([]uint64, len(nnUids)) + for i, idx := range order { + sortedUids[i] = nnUids[idx] + } + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: sortedUids}) + + if len(nnScores) == len(nnUids) && len(nnScores) > 0 { + scoreBuf := make([]byte, len(nnUids)*8) + scoreValues := make([]*pb.ValueList, len(nnUids)) + for i, idx := range order { + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(nnScores[idx])) + scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_FLOAT}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) + } return nil } @@ -1219,6 +1266,122 @@ func needsStringFiltering(srcFn *functionContext, langs []string, attr string) b srcFn.fnType == customIndexFn || srcFn.fnType == ngramFn) } +func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error { + q := args.q + attr := q.Attr + + // 1. Parse args: query text, optional k (default 1.2), b (default 0.75). + if len(q.SrcFunc.Args) < 1 { + return errors.Errorf("bm25 requires at least 1 argument (query text)") + } + queryText := q.SrcFunc.Args[0] + k := 1.2 + b := 0.75 + if len(q.SrcFunc.Args) >= 2 { + var err error + k, err = strconv.ParseFloat(q.SrcFunc.Args[1], 64) + if err != nil { + return errors.Errorf("bm25: invalid k parameter: %s", q.SrcFunc.Args[1]) + } + } + if len(q.SrcFunc.Args) >= 3 { + var err error + b, err = strconv.ParseFloat(q.SrcFunc.Args[2], 64) + if err != nil { + return errors.Errorf("bm25: invalid b parameter: %s", q.SrcFunc.Args[2]) + } + } + if math.IsNaN(k) || math.IsInf(k, 0) || k <= 0 { + return errors.Errorf("bm25: k must be a positive finite number, got %v", k) + } + if math.IsNaN(b) || math.IsInf(b, 0) || b < 0 || b > 1 { + return errors.Errorf("bm25: b must be between 0 and 1, got %v", b) + } + + // 2. Tokenize query (deduplicated) using the fulltext pipeline. The returned + // tokens already carry the BM25 tokenizer identifier byte. + lang := langForFunc(q.Langs) + queryTokens, err := tok.GetBM25QueryTokens([]string{queryText}, lang) + if err != nil { + return err + } + if len(queryTokens) == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + + // 3. Read bucketed corpus stats and derive N and the average document length. + docCount, totalTerms, err := posting.ReadBM25Stats(qs.cache.Get, attr, q.ReadTs) + if err != nil { + return err + } + if docCount == 0 || totalTerms == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + avgDL := float64(totalTerms) / float64(docCount) + N := float64(docCount) + + // Build a filter set if bm25 is used as a filter (@filter(bm25(...))). + var filterSet map[uint64]struct{} + if q.UidList != nil && len(q.UidList.Uids) > 0 { + filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) + for _, uid := range q.UidList.Uids { + filterSet[uid] = struct{}{} + } + } + + // 4. Use WAND top-k early termination only when first is set without an offset; + // otherwise score all matching documents and paginate afterwards. + topK := 0 + if q.First > 0 && q.Offset == 0 { + topK = int(q.First) + } + + // 5. Run WAND / Block-Max WAND over the standard posting lists. + results, err := wandSearch(qs.cache.Get, attr, q.ReadTs, queryTokens, k, b, avgDL, N, + topK, filterSet, true) + if err != nil { + return err + } + + // 6. Paginate score-sorted results when WAND did not already top-k them. + if topK <= 0 && (q.First > 0 || q.Offset > 0) { + offset := int(q.Offset) + if offset > len(results) { + offset = len(results) + } + results = results[offset:] + if q.First > 0 && int(q.First) < len(results) { + results = results[:int(q.First)] + } + } + + // 7. Emit UIDs ascending (required by the query pipeline) with positionally- + // aligned scores in ValueMatrix; the query layer binds these to a value + // variable so callers order by and project the score via val(). + sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) + uids := make([]uint64, len(results)) + for i, r := range results { + uids[i] = r.uid + } + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + + scoreBuf := make([]byte, len(results)*8) + scoreValues := make([]*pb.ValueList, len(results)) + for i, r := range results { + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score)) + // Three-index slice caps capacity at 8 so a downstream append can't corrupt + // adjacent scores in the shared backing array. + scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_FLOAT}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) + return nil +} + func (qs *queryState) handleCompareScalarFunction(ctx context.Context, arg funcArgs) error { attr := arg.q.Attr if ok := schema.State().HasCount(ctx, attr); !ok { @@ -2167,6 +2330,18 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { return nil, err } checkRoot(q, fc) + case bm25SearchFn: + // bm25(pred, "query text") or bm25(pred, "query text", "k", "b") + if len(q.SrcFunc.Args) < 1 || len(q.SrcFunc.Args) > 3 { + return nil, errors.Errorf("Function 'bm25' requires 1-3 arguments (query [, k, b]), but got %d", + len(q.SrcFunc.Args)) + } + required, found := verifyStringIndex(ctx, attr, fnType) + if !found { + return nil, errors.Errorf("Attribute %s is not indexed with type %s", x.ParseAttr(attr), + required) + } + checkRoot(q, fc) case similarToFn: // similar_to accepts 2 mandatory args: k, vector_or_uid followed by optional key:value pairs // Example: similar_to(vpred, 3, $vec, ef: 64, distance_threshold: 0.5) diff --git a/worker/tokens.go b/worker/tokens.go index 2740d29f447..b8c85a22816 100644 --- a/worker/tokens.go +++ b/worker/tokens.go @@ -25,6 +25,8 @@ func verifyStringIndex(ctx context.Context, attr string, funcType FuncType) (str requiredTokenizer = tok.NGramTokenizer{} case fullTextSearchFn: requiredTokenizer = tok.FullTextTokenizer{} + case bm25SearchFn: + requiredTokenizer = tok.BM25Tokenizer{} case matchFn: requiredTokenizer = tok.TrigramTokenizer{} default: @@ -65,6 +67,9 @@ func getStringTokens(funcArgs []string, lang string, funcType FuncType, query bo if funcType == fullTextSearchFn { return tok.GetFullTextTokens(funcArgs, lang) } + if funcType == bm25SearchFn { + return tok.GetBM25QueryTokens(funcArgs, lang) + } if funcType == ngramFn { if query { return tok.GetNGramQueryTokens(funcArgs, lang) diff --git a/x/keys.go b/x/keys.go index 94112d07c03..a88dc4640a1 100644 --- a/x/keys.go +++ b/x/keys.go @@ -291,6 +291,30 @@ func CountKey(attr string, count uint32, reverse bool) []byte { return buf } +// BM25IndexKey generates the index key for a BM25 term posting list. The +// encodedToken already carries the BM25 tokenizer identifier byte, so BM25 term +// postings live at the same standard index key as every other tokenizer — +// IndexKey(attr, identifier || term) — and inherit rollup, splits, backup, and +// index-rebuild handling for free. This is a thin alias of IndexKey so the index +// write path and the query read path share one definition. +func BM25IndexKey(attr string, encodedToken string) []byte { + return IndexKey(attr, encodedToken) +} + +// bm25StatsPrefix namespaces the BM25 corpus-statistics keys. These hold the +// document count and total term count (used to derive the average document +// length); they are auxiliary metadata, not term postings, so they use a reserved +// token that cannot collide with any stemmed BM25 term. +const bm25StatsPrefix = "\x00_bm25stats_" + +// BM25StatsKey generates the key for one bucket of BM25 corpus statistics. Stats +// are sharded across buckets (keyed by uid%numBuckets) to spread write contention. +func BM25StatsKey(attr string, bucket int) []byte { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], uint16(bucket)) + return IndexKey(attr, bm25StatsPrefix+string(buf[:])) +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string