From 4dcb5ed2b5ba73c7a42a7f37f98ff3d4e91efcb9 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 16:26:46 -0500 Subject: [PATCH 01/19] feat(search): add BM25 ranked text search Add BM25 relevance-ranked text search to Dgraph, enabling users to query text predicates and receive results ordered by relevance score instead of boolean matching. Implementation: - New BM25 tokenizer using the fulltext pipeline (normalize, stopwords, stem) that preserves term frequencies for TF counting - BM25-specific index storage: per-term TF posting lists, doc length lists, and corpus statistics (doc count, total terms) - Query execution with full BM25 scoring: score = IDF * (k+1) * tf / (k * (1 - b + b * dl/avgDL) + tf) IDF = log1p((N - df + 0.5) / (df + 0.5)) - DQL syntax: bm25(predicate, "query" [, "k", "b"]) as root func or filter - Schema syntax: @index(bm25) - Parameter validation (k > 0, 0 <= b <= 1) - Early UID intersection for filter-mode performance - All-stopword document and query handling Co-Authored-By: Claude Opus 4.6 --- dql/parser.go | 2 +- posting/index.go | 183 +++++++++++++++++++++++++++++++++ query/common_test.go | 17 +++ query/query_bm25_test.go | 214 ++++++++++++++++++++++++++++++++++++++ tok/tok.go | 43 ++++++++ tok/tok_test.go | 140 +++++++++++++++++++++++++ tok/tokens.go | 25 +++++ worker/task.go | 216 ++++++++++++++++++++++++++++++++++++++- worker/tokens.go | 5 + x/keys.go | 19 ++++ 10 files changed, 861 insertions(+), 3 deletions(-) create mode 100644 query/query_bm25_test.go diff --git a/dql/parser.go b/dql/parser.go index 0dd6e1db7ac..666c3eacaab 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -1701,7 +1701,7 @@ func validFuncName(name string) bool { switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return false diff --git a/posting/index.go b/posting/index.go index ae6c3352a44..88c0e5920a9 100644 --- a/posting/index.go +++ b/posting/index.go @@ -68,6 +68,10 @@ func indexTokens(ctx context.Context, info *indexMutationInfo) ([]string, error) var tokens []string for _, it := range info.tokenizers { + // BM25 tokenizer is handled separately in addBM25IndexMutations. + if it.Identifier() == tok.IdentBM25 { + continue + } toks, err := tok.BuildTokens(sv.Value, tok.GetTokenizerForLang(it, lang)) if err != nil { return tokens, err @@ -179,6 +183,17 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) } } + // Check if any tokenizer is BM25 and handle separately. + for _, it := range info.tokenizers { + if _, ok := tok.GetTokenizerForLang(it, info.edge.GetLang()).(tok.BM25Tokenizer); ok { + if err := txn.addBM25IndexMutations(ctx, info); err != nil { + return []*pb.DirectedEdge{}, err + } + // Continue to process remaining non-BM25 tokenizers below. + continue + } + } + tokens, err := indexTokens(ctx, info) if err != nil { // This data is not indexable @@ -215,6 +230,174 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return nil } +// addBM25IndexMutations handles index mutations for the BM25 tokenizer. +// It stores term frequencies, document lengths, and corpus statistics. +func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { + attr := info.edge.Attr + uid := info.edge.Entity + lang := info.edge.GetLang() + + schemaType, err := schema.State().TypeOf(attr) + if err != nil || !schemaType.IsScalar() { + return errors.Errorf("Cannot BM25 index attribute %s of type object.", attr) + } + + sv, err := types.Convert(info.val, schemaType) + if err != nil { + return err + } + + bm25Tok := tok.BM25Tokenizer{} + termFreqs, docLen, err := bm25Tok.TokensWithFrequency(sv.Value, lang) + if err != nil { + return err + } + + // Skip documents that tokenize to zero terms (e.g., all stopwords). + if docLen == 0 { + return nil + } + + if info.op == pb.DirectedEdge_DEL { + // For DELETE: remove uid from all BM25 term posting lists, doc length list, + // and decrement corpus stats. + for term := range termFreqs { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_DEL, + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + } + // Remove doc length entry. + dlKey := x.BM25DocLenKey(attr) + dlPlist, err := txn.cache.GetFromDelta(dlKey) + if err != nil { + return err + } + dlEdge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_DEL, + } + if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { + return err + } + + // Update corpus stats: decrement doc count and total terms. + return txn.updateBM25Stats(ctx, attr, -1, -int64(docLen)) + } + + // For SET: store term frequencies, doc length, and update corpus stats. + for term, tf := range termFreqs { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + // Store uid in the posting list. The TF is encoded in the Value field. + tfBuf := make([]byte, 4) + binary.BigEndian.PutUint32(tfBuf, tf) + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Value: tfBuf, + ValueType: pb.Posting_INT, + Op: pb.DirectedEdge_SET, + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + } + + // Store document length. + dlKey := x.BM25DocLenKey(attr) + dlPlist, err := txn.cache.GetFromDelta(dlKey) + if err != nil { + return err + } + dlBuf := make([]byte, 4) + binary.BigEndian.PutUint32(dlBuf, docLen) + dlEdge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Value: dlBuf, + ValueType: pb.Posting_INT, + Op: pb.DirectedEdge_SET, + } + if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { + return err + } + + // Update corpus stats: increment doc count by 1 and total terms by docLen. + return txn.updateBM25Stats(ctx, attr, 1, int64(docLen)) +} + +// updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, +// applies the given deltas, and writes back. +func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta int64, totalTermsDelta int64) error { + statsKey := x.BM25StatsKey(attr) + plist, err := txn.cache.GetFromDelta(statsKey) + if err != nil { + return err + } + + // Read existing stats from posting with uid=1. + var docCount, totalTerms uint64 + val, err := plist.Value(txn.StartTs) + if err == nil && val.Value != nil { + data, ok := val.Value.([]byte) + if ok && len(data) == 16 { + docCount = binary.BigEndian.Uint64(data[0:8]) + totalTerms = binary.BigEndian.Uint64(data[8:16]) + } + } + + // Apply deltas. + if docCountDelta >= 0 { + docCount += uint64(docCountDelta) + } else { + dec := uint64(-docCountDelta) + if dec > docCount { + docCount = 0 + } else { + docCount -= dec + } + } + if totalTermsDelta >= 0 { + totalTerms += uint64(totalTermsDelta) + } else { + dec := uint64(-totalTermsDelta) + if dec > totalTerms { + totalTerms = 0 + } else { + totalTerms -= dec + } + } + + // Write back stats. + statsBuf := make([]byte, 16) + binary.BigEndian.PutUint64(statsBuf[0:8], docCount) + binary.BigEndian.PutUint64(statsBuf[8:16], totalTerms) + edge := &pb.DirectedEdge{ + Entity: 1, + Attr: attr, + Value: statsBuf, + ValueType: pb.Posting_ValType(0), + Op: pb.DirectedEdge_SET, + } + return plist.addMutation(ctx, txn, edge) +} + // countParams is sent to updateCount function. It is used to update the count index. // It deletes the uid from the key corresponding to and adds it // to . diff --git a/query/common_test.go b/query/common_test.go index e36211f7a18..32a3e65a81b 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -390,6 +390,11 @@ func populateCluster(dc dgraphapi.Cluster) { testSchema += "\ndescription: string @index(ngram) ." } + // BM25 indexing - uses same version gate as ngram for now + if ngramSupport { + testSchema += "\ndescription_bm25: string @index(bm25) ." + } + setSchema(testSchema) err = addTriplesToCluster(` @@ -1007,4 +1012,16 @@ func populateCluster(dc dgraphapi.Cluster) { <415> "Linguistic analysis helps understand text meaning" . `) x.Panic(err) + + // Add data for BM25 tests - uses separate predicate to avoid conflicts + err = addTriplesToCluster(` + <501> "The quick brown fox jumps over the lazy dog" . + <502> "A quick brown fox leaps over a sleeping dog" . + <503> "fox fox fox" . + <504> "The lazy dog sleeps under the warm sun all day long in the garden" . + <505> "Dogs are loyal companions to humans and families everywhere" . + <506> "Quick movements help foxes catch their prey in the wild" . + <507> "Brown foxes are quick and agile animals in the forest" . + `) + x.Panic(err) } diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go new file mode 100644 index 00000000000..f0a3a0c16a9 --- /dev/null +++ b/query/query_bm25_test.go @@ -0,0 +1,214 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBM25Basic(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick brown fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return documents containing "quick", "brown", or "fox" + require.Contains(t, js, "quick brown fox jumps") + require.Contains(t, js, "quick brown fox leaps") +} + +func TestBM25Ordering(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Document 503 has "fox fox fox" (tf=3, short doc) so should rank highest. + // Verify it appears before other fox-containing documents in the output. + foxFoxFoxIdx := strings.Index(js, "fox fox fox") + quickBrownIdx := strings.Index(js, "quick brown fox jumps") + require.Greater(t, foxFoxFoxIdx, -1, "should contain 'fox fox fox'") + require.Greater(t, quickBrownIdx, -1, "should contain 'quick brown fox jumps'") + require.Less(t, foxFoxFoxIdx, quickBrownIdx, + "'fox fox fox' (higher tf, shorter doc) should rank before 'quick brown fox jumps'") +} + +func TestBM25WithParams(t *testing.T) { + // Custom k and b parameters + query := ` + { + me(func: bm25(description_bm25, "fox", "1.5", "0.5")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") +} + +func TestBM25InvalidParams(t *testing.T) { + // Negative k should be rejected. + query := ` + { + me(func: bm25(description_bm25, "fox", "-1.0", "0.75")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: k must be a positive finite number") + + // b > 1 should be rejected. + query2 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "1.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query2) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") + + // b < 0 should be rejected. + query3 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "-0.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query3) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") +} + +func TestBM25AsFilter(t *testing.T) { + query := ` + { + me(func: has(description_bm25)) @filter(bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") + // Should not contain documents without "fox" + require.NotContains(t, js, "Dogs are loyal") +} + +func TestBM25NoResults(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "xyznonexistent")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25SingleTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "dog")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "dog") +} + +func TestBM25MultiTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick lazy")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should find docs with "quick" or "lazy" (scores accumulate). + // Doc 501 has both "quick" and "lazy", so it should rank high. + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25AllStopwords(t *testing.T) { + // A query consisting entirely of stopwords should return no results. + query := ` + { + me(func: bm25(description_bm25, "the a an")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25EmptyPredicate(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "")) { + uid + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25WithCount(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + count(uid) + } + } + ` + js := processQueryNoErr(t, query) + // Should have at least 2 results (docs with "fox") + require.Contains(t, js, "count") +} + +func TestBM25Pagination(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // With first:1, should return exactly one result (the highest-scoring). + // Doc 503 "fox fox fox" should be the top result. + require.Contains(t, js, "fox fox fox") +} diff --git a/tok/tok.go b/tok/tok.go index c1da3e991d7..cb50b0a369e 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -50,6 +50,7 @@ const ( IdentBigFloat = 0xD IdentVFloat = 0xE IdentNGram = 0xF + IdentBM25 = 0x10 IdentCustom = 0x80 IdentDelimiter = 0x1f // ASCII 31 - Unit separator ) @@ -101,6 +102,7 @@ func init() { registerTokenizer(TermTokenizer{}) registerTokenizer(FullTextTokenizer{}) registerTokenizer(NGramTokenizer{}) + registerTokenizer(BM25Tokenizer{}) registerTokenizer(Sha256Tokenizer{}) setupBleve() } @@ -576,6 +578,47 @@ func (t FullTextTokenizer) Identifier() byte { return IdentFullText } func (t FullTextTokenizer) IsSortable() bool { return false } func (t FullTextTokenizer) IsLossy() bool { return true } +// BM25Tokenizer generates tokens for BM25 ranked text search. +// It uses the same pipeline as FullTextTokenizer (normalize, stopwords, stem) +// but preserves duplicates for term frequency counting. +type BM25Tokenizer struct{ lang string } + +func (t BM25Tokenizer) Name() string { return "bm25" } +func (t BM25Tokenizer) Type() string { return "string" } +func (t BM25Tokenizer) Tokens(v interface{}) ([]string, error) { + str, ok := v.(string) + if !ok || str == "" { + return []string{}, nil + } + lang := LangBase(t.lang) + tokens := fulltextAnalyzer.Analyze([]byte(str)) + tokens = filterStopwords(lang, tokens) + tokens = filterStemmers(lang, tokens) + // Return all tokens with duplicates preserved (for TF counting). + result := make([]string, 0, len(tokens)) + for _, t := range tokens { + result = append(result, string(t.Term)) + } + return result, nil +} +func (t BM25Tokenizer) Identifier() byte { return IdentBM25 } +func (t BM25Tokenizer) IsSortable() bool { return false } +func (t BM25Tokenizer) IsLossy() bool { return true } + +// TokensWithFrequency tokenizes the input and returns term frequencies and doc length. +func (t BM25Tokenizer) TokensWithFrequency(v interface{}, lang string) (map[string]uint32, uint32, error) { + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(v) + if err != nil { + return nil, 0, err + } + termFreqs := make(map[string]uint32, len(allTokens)) + for _, t := range allTokens { + termFreqs[t]++ + } + return termFreqs, uint32(len(allTokens)), nil +} + // Sha256Tokenizer generates tokens for the sha256 hash part from string data. type Sha256Tokenizer struct{ _ string } diff --git a/tok/tok_test.go b/tok/tok_test.go index 4c95094e577..b9fbc4dd1a5 100644 --- a/tok/tok_test.go +++ b/tok/tok_test.go @@ -652,6 +652,146 @@ func TestNGramTokenizerNonStringInput(t *testing.T) { require.Equal(t, 0, len(tokens2), "Expected empty tokens for nil input") } +func TestBM25Tokenizer(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + require.NotNil(t, tokenizer) + require.Equal(t, "bm25", tokenizer.Name()) + require.Equal(t, "string", tokenizer.Type()) + require.Equal(t, byte(IdentBM25), tokenizer.Identifier()) + require.True(t, tokenizer.IsLossy()) + require.False(t, tokenizer.IsSortable()) +} + +func TestBM25TokensPreservesDuplicates(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("fox fox fox dog") + require.NoError(t, err) + // "fox" should appear 3 times (duplicates preserved), "dog" once + foxCount := 0 + dogCount := 0 + for _, token := range tokens { + if token == "fox" { + foxCount++ + } + if token == "dog" { + dogCount++ + } + } + require.Equal(t, 3, foxCount, "Expected 3 occurrences of 'fox'") + require.Equal(t, 1, dogCount, "Expected 1 occurrence of 'dog'") +} + +func TestBM25TokensWithFrequency(t *testing.T) { + tok := BM25Tokenizer{} + termFreqs, docLen, err := tok.TokensWithFrequency("the quick brown fox fox fox", "en") + require.NoError(t, err) + // "the" is a stopword and should be removed + _, hasThe := termFreqs["the"] + require.False(t, hasThe, "'the' should be removed as stopword") + // "fox" should have tf=3 + require.Equal(t, uint32(3), termFreqs["fox"]) + // "quick" -> "quick" (stemmed) + require.Contains(t, termFreqs, "quick") + require.Equal(t, uint32(1), termFreqs["quick"]) + // "brown" -> "brown" (stemmed) + require.Contains(t, termFreqs, "brown") + require.Equal(t, uint32(1), termFreqs["brown"]) + // docLen should be total tokens after stopword removal + require.Equal(t, uint32(5), docLen) +} + +func TestBM25TokensEmpty(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestBM25TokensSingleWord(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("hello") + require.NoError(t, err) + require.Equal(t, 1, len(tokens)) + require.Equal(t, "hello", tokens[0]) +} + +func TestBM25TokensStemming(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("running jumping swimming") + require.NoError(t, err) + require.Equal(t, 3, len(tokens)) + require.Contains(t, tokens, "run") + require.Contains(t, tokens, "jump") + require.Contains(t, tokens, "swim") +} + +func TestGetBM25QueryTokens(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"quick brown fox fox"}, "en") + require.NoError(t, err) + // Query tokens should be deduplicated + require.Equal(t, 3, len(tokens)) + // Each token should be encoded with the BM25 identifier prefix + for _, token := range tokens { + require.Equal(t, byte(IdentBM25), token[0], "Token should start with BM25 identifier") + } +} + +func TestGetBM25QueryTokensEmpty(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{""}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestBM25TokenizerForLang(t *testing.T) { + tokenizer, has := GetTokenizer("bm25") + require.True(t, has) + langTok := GetTokenizerForLang(tokenizer, "de") + bm25Tok, ok := langTok.(BM25Tokenizer) + require.True(t, ok) + // German: "Katzen" -> "katz" (stemmed) + tokens, err := bm25Tok.Tokens("Katzen und Katzen") + require.NoError(t, err) + // "und" is a German stopword + katzCount := 0 + for _, token := range tokens { + if token == "katz" { + katzCount++ + } + } + require.Equal(t, 2, katzCount, "Expected 2 occurrences of stemmed 'katz'") +} + +func TestBM25AllStopwords(t *testing.T) { + tok := BM25Tokenizer{lang: "en"} + tokens, err := tok.Tokens("the a an is") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) + + termFreqs, docLen, err := tok.TokensWithFrequency("the a an is", "en") + require.NoError(t, err) + require.Equal(t, 0, len(termFreqs)) + require.Equal(t, uint32(0), docLen) +} + +func TestGetBM25QueryTokensAllStopwords(t *testing.T) { + tokens, err := GetBM25QueryTokens([]string{"the a an"}, "en") + require.NoError(t, err) + require.Equal(t, 0, len(tokens)) +} + +func TestGetBM25QueryTokensWrongArgCount(t *testing.T) { + _, err := GetBM25QueryTokens([]string{}, "en") + require.Error(t, err) + _, err = GetBM25QueryTokens([]string{"a", "b"}, "en") + require.Error(t, err) +} + func BenchmarkTermTokenizer(b *testing.B) { b.Skip() // tmp } diff --git a/tok/tokens.go b/tok/tokens.go index bda9a04e743..f089a3f4344 100644 --- a/tok/tokens.go +++ b/tok/tokens.go @@ -25,6 +25,8 @@ func GetTokenizerForLang(t Tokenizer, lang string) Tokenizer { // We must return a new instance because another goroutine might be calling this // with a different lang. return FullTextTokenizer{lang: lang} + case BM25Tokenizer: + return BM25Tokenizer{lang: lang} case TermTokenizer: return TermTokenizer{lang: lang} case ExactTokenizer: @@ -67,6 +69,29 @@ func GetNGramQueryTokens(funcArgs []string, lang string) ([]string, error) { return BuildNGramQueryTokens(funcArgs[0], NGramTokenizer{lang: lang}) } +// GetBM25QueryTokens tokenizes the query text using the fulltext pipeline, +// deduplicates, and encodes with the BM25 identifier prefix. +func GetBM25QueryTokens(funcArgs []string, lang string) ([]string, error) { + if l := len(funcArgs); l != 1 { + return nil, errors.Errorf("Function requires 1 arguments, but got %d", l) + } + tok := BM25Tokenizer{lang: lang} + allTokens, err := tok.Tokens(funcArgs[0]) + if err != nil { + return nil, err + } + // Deduplicate for query + seen := make(map[string]struct{}, len(allTokens)) + var unique []string + for _, t := range allTokens { + if _, ok := seen[t]; !ok { + seen[t] = struct{}{} + unique = append(unique, encodeToken(t, tok.Identifier())) + } + } + return unique, nil +} + // GetFullTextTokens returns the full-text tokens for the given value. func GetFullTextTokens(funcArgs []string, lang string) ([]string, error) { if l := len(funcArgs); l != 1 { diff --git a/worker/task.go b/worker/task.go index 409ec3f0fc4..1da128c76bc 100644 --- a/worker/task.go +++ b/worker/task.go @@ -7,6 +7,7 @@ package worker import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -224,6 +225,7 @@ const ( customIndexFn matchFn similarToFn + bm25SearchFn standardFn = 100 ) @@ -266,6 +268,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) { return uidInFn, f case "similar_to": return similarToFn, f + case "bm25": + return bm25SearchFn, f case "anyof", "allof": return customIndexFn, f case "match": @@ -292,6 +296,8 @@ func needsIndex(fnType FuncType, uidList *pb.List) bool { return true case similarToFn: return true + case bm25SearchFn: + return true } return false } @@ -314,7 +320,7 @@ type funcArgs struct { // The function tells us whether we want to fetch value posting lists or uid posting lists. func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn, similarToFn: + case aggregatorFn, passwordFn, similarToFn, bm25SearchFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -351,11 +357,15 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er attribute.String("srcFn", x.SafeUTF8(fmt.Sprintf("%+v", args.srcFn))))) switch srcFn.fnType { - case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn: + case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn, bm25SearchFn: default: return errors.Errorf("Unhandled function in handleValuePostings: %s", srcFn.fname) } + if srcFn.fnType == bm25SearchFn { + return qs.handleBM25Search(ctx, args) + } + if srcFn.fnType == similarToFn { numNeighbors, err := strconv.ParseInt(q.SrcFunc.Args[0], 10, 32) if err != nil { @@ -1219,6 +1229,196 @@ func needsStringFiltering(srcFn *functionContext, langs []string, attr string) b srcFn.fnType == customIndexFn || srcFn.fnType == ngramFn) } +func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error { + q := args.q + attr := q.Attr + + // 1. Parse args: query text, optional k (default 1.2), b (default 0.75). + if len(q.SrcFunc.Args) < 1 { + return errors.Errorf("bm25 requires at least 1 argument (query text)") + } + queryText := q.SrcFunc.Args[0] + k := 1.2 + b := 0.75 + if len(q.SrcFunc.Args) >= 2 { + var err error + k, err = strconv.ParseFloat(q.SrcFunc.Args[1], 64) + if err != nil { + return errors.Errorf("bm25: invalid k parameter: %s", q.SrcFunc.Args[1]) + } + } + if len(q.SrcFunc.Args) >= 3 { + var err error + b, err = strconv.ParseFloat(q.SrcFunc.Args[2], 64) + if err != nil { + return errors.Errorf("bm25: invalid b parameter: %s", q.SrcFunc.Args[2]) + } + } + if math.IsNaN(k) || math.IsInf(k, 0) || k <= 0 { + return errors.Errorf("bm25: k must be a positive finite number, got %v", k) + } + if math.IsNaN(b) || math.IsInf(b, 0) || b < 0 || b > 1 { + return errors.Errorf("bm25: b must be between 0 and 1, got %v", b) + } + + // 2. Tokenize query (deduplicated) using fulltext pipeline. + lang := langForFunc(q.Langs) + queryTokens, err := tok.GetBM25QueryTokens([]string{queryText}, lang) + if err != nil { + return err + } + if len(queryTokens) == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + + // 3. Read corpus stats. + statsKey := x.BM25StatsKey(attr) + statsPl, err := qs.cache.Get(statsKey) + if err != nil { + // No stats means no documents indexed yet. + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + statsVal, err := statsPl.Value(q.ReadTs) + if err != nil || statsVal.Value == nil { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + statsData, ok := statsVal.Value.([]byte) + if !ok || len(statsData) != 16 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + docCount := binary.BigEndian.Uint64(statsData[0:8]) + totalTerms := binary.BigEndian.Uint64(statsData[8:16]) + if docCount == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + if totalTerms == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + avgDL := float64(totalTerms) / float64(docCount) + N := float64(docCount) + + // Build filter set early if used as a filter, for efficient intersection during iteration. + var filterSet map[uint64]struct{} + if q.UidList != nil && len(q.UidList.Uids) > 0 { + filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) + for _, uid := range q.UidList.Uids { + filterSet[uid] = struct{}{} + } + } + + // 4. For each query token, read the posting list and collect term info. + type termInfo struct { + idf float64 + uidTFs map[uint64]uint32 + } + termInfos := make(map[string]*termInfo) + + for _, token := range queryTokens { + key := x.BM25IndexKey(attr, token) + pl, err := qs.cache.Get(key) + if err != nil { + continue + } + + ti := &termInfo{uidTFs: make(map[uint64]uint32)} + var df float64 + err = pl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { + df++ + // When used as filter, only collect TF for UIDs in the filter set. + if filterSet != nil { + if _, ok := filterSet[p.Uid]; !ok { + return nil + } + } + tf := uint32(1) + if len(p.Value) >= 4 { + tf = binary.BigEndian.Uint32(p.Value[:4]) + } + ti.uidTFs[p.Uid] = tf + return nil + }) + if err != nil { + continue + } + ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) + termInfos[token] = ti + } + + // 5. Read doc lengths for all UIDs seen. + allUids := make(map[uint64]struct{}) + for _, ti := range termInfos { + for uid := range ti.uidTFs { + allUids[uid] = struct{}{} + } + } + + docLens := make(map[uint64]uint32) + dlKey := x.BM25DocLenKey(attr) + dlPl, err := qs.cache.Get(dlKey) + if err == nil { + remaining := len(allUids) + _ = dlPl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { + if remaining == 0 { + return posting.ErrStopIteration + } + if _, needed := allUids[p.Uid]; needed { + dl := uint32(1) + if len(p.Value) >= 4 { + dl = binary.BigEndian.Uint32(p.Value[:4]) + } + docLens[p.Uid] = dl + remaining-- + } + return nil + }) + } + + // 6. Compute final BM25 scores. + scores := make(map[uint64]float64) + for _, ti := range termInfos { + for uid, tf := range ti.uidTFs { + dl := float64(1) + if v, ok := docLens[uid]; ok { + dl = float64(v) + } + tfFloat := float64(tf) + score := ti.idf * (k + 1) * tfFloat / (k*(1-b+b*dl/avgDL) + tfFloat) + scores[uid] += score + } + } + + // 7. Sort by score descending. + type uidScore struct { + uid uint64 + score float64 + } + results := make([]uidScore, 0, len(scores)) + for uid, score := range scores { + results = append(results, uidScore{uid: uid, score: score}) + } + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + + // Build output UIDs. + uids := make([]uint64, len(results)) + for i, r := range results { + uids[i] = r.uid + } + + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + return nil +} + func (qs *queryState) handleCompareScalarFunction(ctx context.Context, arg funcArgs) error { attr := arg.q.Attr if ok := schema.State().HasCount(ctx, attr); !ok { @@ -2167,6 +2367,18 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { return nil, err } checkRoot(q, fc) + case bm25SearchFn: + // bm25(pred, "query text") or bm25(pred, "query text", "k", "b") + if len(q.SrcFunc.Args) < 1 || len(q.SrcFunc.Args) > 3 { + return nil, errors.Errorf("Function 'bm25' requires 1-3 arguments (query [, k, b]), but got %d", + len(q.SrcFunc.Args)) + } + required, found := verifyStringIndex(ctx, attr, fnType) + if !found { + return nil, errors.Errorf("Attribute %s is not indexed with type %s", x.ParseAttr(attr), + required) + } + checkRoot(q, fc) case similarToFn: // similar_to accepts 2 mandatory args: k, vector_or_uid followed by optional key:value pairs // Example: similar_to(vpred, 3, $vec, ef: 64, distance_threshold: 0.5) diff --git a/worker/tokens.go b/worker/tokens.go index 2740d29f447..b8c85a22816 100644 --- a/worker/tokens.go +++ b/worker/tokens.go @@ -25,6 +25,8 @@ func verifyStringIndex(ctx context.Context, attr string, funcType FuncType) (str requiredTokenizer = tok.NGramTokenizer{} case fullTextSearchFn: requiredTokenizer = tok.FullTextTokenizer{} + case bm25SearchFn: + requiredTokenizer = tok.BM25Tokenizer{} case matchFn: requiredTokenizer = tok.TrigramTokenizer{} default: @@ -65,6 +67,9 @@ func getStringTokens(funcArgs []string, lang string, funcType FuncType, query bo if funcType == fullTextSearchFn { return tok.GetFullTextTokens(funcArgs, lang) } + if funcType == bm25SearchFn { + return tok.GetBM25QueryTokens(funcArgs, lang) + } if funcType == ngramFn { if query { return tok.GetNGramQueryTokens(funcArgs, lang) diff --git a/x/keys.go b/x/keys.go index 94112d07c03..23196fd89c9 100644 --- a/x/keys.go +++ b/x/keys.go @@ -291,6 +291,25 @@ func CountKey(attr string, count uint32, reverse bool) []byte { return buf } +// BM25Prefix is the prefix used for BM25 index keys to prevent collision +// with regular fulltext index tokens. +const BM25Prefix = "\x00_bm25_" + +// BM25IndexKey generates an index key for a BM25 term posting list. +func BM25IndexKey(attr string, token string) []byte { + return IndexKey(attr, BM25Prefix+token) +} + +// BM25DocLenKey generates the key for the BM25 document length posting list. +func BM25DocLenKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__doclen__") +} + +// BM25StatsKey generates the key for BM25 corpus statistics. +func BM25StatsKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__stats__") +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string From ed720da20d9ba4984bfad6cc61d22656a50c3a6c Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 17:18:15 -0500 Subject: [PATCH 02/19] fix(bm25): store TF/doclen in facets and fix query pipeline integration Three critical bugs fixed: 1. REF postings lose Value during rollup: The posting list encode/rollup cycle strips the Value field from REF postings without facets (list.go:1630). BM25 term frequencies and doc lengths were stored in Value and lost. Fix: Store TF and doclen as facets on REF postings, which are preserved. 2. Missing function validation: query/query.go has a separate isValidFuncName check from dql/parser.go. "bm25" was only added to the parser, causing "Invalid function name: bm25" at query time. 3. Unsorted UIDs break query pipeline: BM25 returned UIDs sorted by score, but the query pipeline (algo.MergeSorted, child predicate fetching) requires UID-ascending order. Fix: Sort UIDs ascending in UidMatrix, apply first/offset pagination on score-sorted results before UID sorting. Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 22 +++++++++++----------- query/query.go | 2 +- query/query_bm25_test.go | 27 ++++++++++++++++++--------- worker/task.go | 30 +++++++++++++++++++++++++----- 4 files changed, 55 insertions(+), 26 deletions(-) diff --git a/posting/index.go b/posting/index.go index 88c0e5920a9..a24f0bac2e6 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,6 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/dgo/v250/protos/api" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/tok" @@ -304,15 +305,15 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn if err != nil { return err } - // Store uid in the posting list. The TF is encoded in the Value field. + // Store uid in the posting list. TF is stored as a facet so it survives + // the rollup cycle (REF postings without facets lose their Value field). tfBuf := make([]byte, 4) binary.BigEndian.PutUint32(tfBuf, tf) edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Value: tfBuf, - ValueType: pb.Posting_INT, - Op: pb.DirectedEdge_SET, + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_SET, + Facets: []*api.Facet{{Key: "tf", Value: tfBuf, ValType: api.Facet_INT}}, } if err := plist.addMutation(ctx, txn, edge); err != nil { return err @@ -328,11 +329,10 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn dlBuf := make([]byte, 4) binary.BigEndian.PutUint32(dlBuf, docLen) dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Value: dlBuf, - ValueType: pb.Posting_INT, - Op: pb.DirectedEdge_SET, + ValueId: uid, + Attr: attr, + Op: pb.DirectedEdge_SET, + Facets: []*api.Facet{{Key: "dl", Value: dlBuf, ValType: api.Facet_INT}}, } if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { return err diff --git a/query/query.go b/query/query.go index 6926e2ac6ed..3025033e1e0 100644 --- a/query/query.go +++ b/query/query.go @@ -2751,7 +2751,7 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return isInequalityFn(f) || types.IsGeoFunc(f) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index f0a3a0c16a9..dcceece7428 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -10,7 +10,6 @@ package query import ( "context" - "strings" "testing" "github.com/stretchr/testify/require" @@ -32,6 +31,8 @@ func TestBM25Basic(t *testing.T) { } func TestBM25Ordering(t *testing.T) { + // BM25 returns all matching documents. Use first:1 to verify the highest-scored + // document is "fox fox fox" (tf=3, short doc). query := ` { me(func: bm25(description_bm25, "fox")) { @@ -41,14 +42,22 @@ func TestBM25Ordering(t *testing.T) { } ` js := processQueryNoErr(t, query) - // Document 503 has "fox fox fox" (tf=3, short doc) so should rank highest. - // Verify it appears before other fox-containing documents in the output. - foxFoxFoxIdx := strings.Index(js, "fox fox fox") - quickBrownIdx := strings.Index(js, "quick brown fox jumps") - require.Greater(t, foxFoxFoxIdx, -1, "should contain 'fox fox fox'") - require.Greater(t, quickBrownIdx, -1, "should contain 'quick brown fox jumps'") - require.Less(t, foxFoxFoxIdx, quickBrownIdx, - "'fox fox fox' (higher tf, shorter doc) should rank before 'quick brown fox jumps'") + // Should contain all fox-mentioning documents. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + + // first:1 should return the top-ranked document. + topQuery := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + topJs := processQueryNoErr(t, topQuery) + require.Contains(t, topJs, "fox fox fox", + "top-1 BM25 result for 'fox' should be 'fox fox fox' (highest tf, shortest doc)") } func TestBM25WithParams(t *testing.T) { diff --git a/worker/task.go b/worker/task.go index 1da128c76bc..2fbd65acca4 100644 --- a/worker/task.go +++ b/worker/task.go @@ -1337,8 +1337,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } tf := uint32(1) - if len(p.Value) >= 4 { - tf = binary.BigEndian.Uint32(p.Value[:4]) + for _, f := range p.Facets { + if f.Key == "tf" && len(f.Value) >= 4 { + tf = binary.BigEndian.Uint32(f.Value[:4]) + break + } } ti.uidTFs[p.Uid] = tf return nil @@ -1369,8 +1372,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } if _, needed := allUids[p.Uid]; needed { dl := uint32(1) - if len(p.Value) >= 4 { - dl = binary.BigEndian.Uint32(p.Value[:4]) + for _, f := range p.Facets { + if f.Key == "dl" && len(f.Value) >= 4 { + dl = binary.BigEndian.Uint32(f.Value[:4]) + break + } } docLens[p.Uid] = dl remaining-- @@ -1402,6 +1408,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for uid, score := range scores { results = append(results, uidScore{uid: uid, score: score}) } + // Sort by score descending for ordering, then collect UIDs. sort.Slice(results, func(i, j int) bool { if results[i].score != results[j].score { return results[i].score > results[j].score @@ -1409,11 +1416,24 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return results[i].uid < results[j].uid }) - // Build output UIDs. + // Apply first/offset pagination on score-sorted results before returning UIDs. + if q.First > 0 || q.Offset > 0 { + offset := int(q.Offset) + if offset > len(results) { + offset = len(results) + } + results = results[offset:] + if q.First > 0 && int(q.First) < len(results) { + results = results[:int(q.First)] + } + } + + // Build output UIDs sorted by UID (ascending) as required by the query pipeline. uids := make([]uint64, len(results)) for i, r := range results { uids[i] = r.uid } + sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) return nil From 1ee74731a51dd455aeff86e6380abcbad8862976 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 23:04:08 -0500 Subject: [PATCH 03/19] perf(bm25): replace facet storage with compact direct Badger KV encoding Replace the facet-based BM25 storage (~40-50 bytes/posting) with compact varint-encoded binary blobs stored as direct Badger KV entries (~4-6 bytes/posting, ~10x reduction). Add bm25_score pseudo-predicate for variable-based score ordering following the similar_to pattern. - Add posting/bm25enc package for compact binary encode/decode - Rewrite write path in posting/index.go for direct Badger KV - Add bm25Writes buffer to LocalCache with read-your-own-writes - Flush BM25 blobs in CommitToDisk with BitBM25Data UserMeta - Rewrite read path in worker/task.go with direct blob decoding - Add bm25_score pseudo-predicate in query/query.go - Add score ordering integration tests Co-Authored-By: Claude Opus 4.6 --- posting/bm25enc/bm25enc.go | 147 ++++++++++++++++++++++++++++++++ posting/bm25enc/bm25enc_test.go | 132 ++++++++++++++++++++++++++++ posting/index.go | 121 +++++++------------------- posting/list.go | 2 + posting/lists.go | 56 ++++++++++++ posting/mvcc.go | 15 ++++ query/query.go | 64 ++++++++++++++ query/query_bm25_test.go | 79 +++++++++++++++++ worker/task.go | 109 +++++++++-------------- 9 files changed, 562 insertions(+), 163 deletions(-) create mode 100644 posting/bm25enc/bm25enc.go create mode 100644 posting/bm25enc/bm25enc_test.go diff --git a/posting/bm25enc/bm25enc.go b/posting/bm25enc/bm25enc.go new file mode 100644 index 00000000000..8da82b299dd --- /dev/null +++ b/posting/bm25enc/bm25enc.go @@ -0,0 +1,147 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package bm25enc provides compact binary encoding for BM25 index data. +// +// Two types of lists share the same format: +// - Term posting lists: (UID, term-frequency) pairs +// - Document length lists: (UID, doc-length) pairs +// +// Binary format: +// +// Header: +// [4 bytes] uint32 big-endian: entry count +// Entries (sorted ascending by UID): +// [varint] UID delta from previous (first entry is absolute) +// [varint] value (TF or doclen) +package bm25enc + +import ( + "encoding/binary" + "sort" +) + +// Entry represents a single (UID, Value) pair in a BM25 posting list. +type Entry struct { + UID uint64 + Value uint32 +} + +// Encode encodes a sorted slice of entries into the compact binary format. +// Entries must be sorted by UID ascending. Returns nil for empty input. +func Encode(entries []Entry) []byte { + if len(entries) == 0 { + return nil + } + + // Pre-allocate: 4 header + ~6 bytes per entry is a reasonable estimate. + buf := make([]byte, 4, 4+len(entries)*6) + binary.BigEndian.PutUint32(buf, uint32(len(entries))) + + var tmp [binary.MaxVarintLen64]byte + var prevUID uint64 + for _, e := range entries { + delta := e.UID - prevUID + n := binary.PutUvarint(tmp[:], delta) + buf = append(buf, tmp[:n]...) + n = binary.PutUvarint(tmp[:], uint64(e.Value)) + buf = append(buf, tmp[:n]...) + prevUID = e.UID + } + return buf +} + +// Decode decodes the binary format into a sorted slice of entries. +// Returns nil for nil/empty input. +func Decode(data []byte) []Entry { + if len(data) < 4 { + return nil + } + count := binary.BigEndian.Uint32(data[:4]) + if count == 0 { + return nil + } + + entries := make([]Entry, 0, count) + pos := 4 + var prevUID uint64 + for i := uint32(0); i < count; i++ { + delta, n := binary.Uvarint(data[pos:]) + if n <= 0 { + break + } + pos += n + + val, n := binary.Uvarint(data[pos:]) + if n <= 0 { + break + } + pos += n + + uid := prevUID + delta + entries = append(entries, Entry{UID: uid, Value: uint32(val)}) + prevUID = uid + } + return entries +} + +// Upsert inserts or updates the entry for uid in a sorted entries slice. +// Returns the new sorted slice. +func Upsert(entries []Entry, uid uint64, value uint32) []Entry { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + entries[i].Value = value + return entries + } + // Insert at position i. + entries = append(entries, Entry{}) + copy(entries[i+1:], entries[i:]) + entries[i] = Entry{UID: uid, Value: value} + return entries +} + +// Remove removes the entry for uid from a sorted entries slice. +// Returns the new slice (may be shorter). +func Remove(entries []Entry, uid uint64) []Entry { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + return append(entries[:i], entries[i+1:]...) + } + return entries +} + +// Search returns the value for uid using binary search, and whether it was found. +func Search(entries []Entry, uid uint64) (uint32, bool) { + i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) + if i < len(entries) && entries[i].UID == uid { + return entries[i].Value, true + } + return 0, false +} + +// UIDs extracts just the UIDs from entries as a uint64 slice. +func UIDs(entries []Entry) []uint64 { + uids := make([]uint64, len(entries)) + for i, e := range entries { + uids[i] = e.UID + } + return uids +} + +// EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes. +func EncodeStats(docCount, totalTerms uint64) []byte { + buf := make([]byte, 16) + binary.BigEndian.PutUint64(buf[0:8], docCount) + binary.BigEndian.PutUint64(buf[8:16], totalTerms) + return buf +} + +// DecodeStats decodes BM25 corpus statistics. Returns (0,0) for invalid input. +func DecodeStats(data []byte) (docCount, totalTerms uint64) { + if len(data) != 16 { + return 0, 0 + } + return binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) +} diff --git a/posting/bm25enc/bm25enc_test.go b/posting/bm25enc/bm25enc_test.go new file mode 100644 index 00000000000..1969e472ed2 --- /dev/null +++ b/posting/bm25enc/bm25enc_test.go @@ -0,0 +1,132 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package bm25enc + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoundtrip(t *testing.T) { + entries := []Entry{ + {UID: 1, Value: 3}, + {UID: 5, Value: 1}, + {UID: 100, Value: 7}, + {UID: 200, Value: 2}, + } + data := Encode(entries) + got := Decode(data) + require.Equal(t, entries, got) +} + +func TestRoundtripEmpty(t *testing.T) { + require.Nil(t, Encode(nil)) + require.Nil(t, Encode([]Entry{})) + require.Nil(t, Decode(nil)) + require.Nil(t, Decode([]byte{})) + require.Nil(t, Decode([]byte{0, 0, 0, 0})) // count=0 +} + +func TestRoundtripSingle(t *testing.T) { + entries := []Entry{{UID: 42, Value: 10}} + got := Decode(Encode(entries)) + require.Equal(t, entries, got) +} + +func TestRoundtripLargeUIDs(t *testing.T) { + entries := []Entry{ + {UID: 1<<40 + 1, Value: 1}, + {UID: 1<<40 + 1000, Value: 5}, + {UID: 1<<50 + 999, Value: 99}, + } + got := Decode(Encode(entries)) + require.Equal(t, entries, got) +} + +func TestUpsertNew(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Upsert(entries, 3, 7) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 3, Value: 7}, {UID: 5, Value: 1}}, entries) +} + +func TestUpsertExisting(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Upsert(entries, 5, 99) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 99}}, entries) +} + +func TestUpsertEmpty(t *testing.T) { + var entries []Entry + entries = Upsert(entries, 10, 5) + require.Equal(t, []Entry{{UID: 10, Value: 5}}, entries) +} + +func TestRemove(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 10, Value: 2}} + entries = Remove(entries, 5) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 10, Value: 2}}, entries) +} + +func TestRemoveNotFound(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} + entries = Remove(entries, 99) + require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}}, entries) +} + +func TestSearch(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} + v, ok := Search(entries, 5) + require.True(t, ok) + require.Equal(t, uint32(1), v) + + _, ok = Search(entries, 50) + require.False(t, ok) +} + +func TestUIDs(t *testing.T) { + entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} + require.Equal(t, []uint64{1, 5, 100}, UIDs(entries)) +} + +func TestStatsRoundtrip(t *testing.T) { + data := EncodeStats(12345, 98765) + dc, tt := DecodeStats(data) + require.Equal(t, uint64(12345), dc) + require.Equal(t, uint64(98765), tt) +} + +func TestStatsInvalid(t *testing.T) { + dc, tt := DecodeStats(nil) + require.Zero(t, dc) + require.Zero(t, tt) + dc, tt = DecodeStats([]byte{1, 2, 3}) + require.Zero(t, dc) + require.Zero(t, tt) +} + +func BenchmarkEncode(b *testing.B) { + entries := make([]Entry, 10000) + for i := range entries { + entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + Encode(entries) + } +} + +func BenchmarkDecode(b *testing.B) { + entries := make([]Entry, 10000) + for i := range entries { + entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + data := Encode(entries) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Decode(data) + } +} diff --git a/posting/index.go b/posting/index.go index a24f0bac2e6..826355a3633 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,7 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" - "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/tok" @@ -232,7 +232,8 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok } // addBM25IndexMutations handles index mutations for the BM25 tokenizer. -// It stores term frequencies, document lengths, and corpus statistics. +// It stores term frequencies, document lengths, and corpus statistics as direct +// Badger KV entries using compact varint encoding, bypassing posting lists. func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { attr := info.edge.Attr uid := info.edge.Entity @@ -260,107 +261,53 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn } if info.op == pb.DirectedEdge_DEL { - // For DELETE: remove uid from all BM25 term posting lists, doc length list, - // and decrement corpus stats. + // For DELETE: remove uid from all BM25 term posting lists and doc length list. for term := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term key := x.BM25IndexKey(attr, encodedTerm) - plist, err := txn.cache.GetFromDelta(key) - if err != nil { - return err - } - edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_DEL, - } - if err := plist.addMutation(ctx, txn, edge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(key) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) } // Remove doc length entry. dlKey := x.BM25DocLenKey(attr) - dlPlist, err := txn.cache.GetFromDelta(dlKey) - if err != nil { - return err - } - dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_DEL, - } - if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(dlKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) // Update corpus stats: decrement doc count and total terms. - return txn.updateBM25Stats(ctx, attr, -1, -int64(docLen)) + return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: store term frequencies, doc length, and update corpus stats. + // For SET: store term frequencies and doc length. for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term key := x.BM25IndexKey(attr, encodedTerm) - plist, err := txn.cache.GetFromDelta(key) - if err != nil { - return err - } - // Store uid in the posting list. TF is stored as a facet so it survives - // the rollup cycle (REF postings without facets lose their Value field). - tfBuf := make([]byte, 4) - binary.BigEndian.PutUint32(tfBuf, tf) - edge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_SET, - Facets: []*api.Facet{{Key: "tf", Value: tfBuf, ValType: api.Facet_INT}}, - } - if err := plist.addMutation(ctx, txn, edge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(key) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, tf) + txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) } // Store document length. dlKey := x.BM25DocLenKey(attr) - dlPlist, err := txn.cache.GetFromDelta(dlKey) - if err != nil { - return err - } - dlBuf := make([]byte, 4) - binary.BigEndian.PutUint32(dlBuf, docLen) - dlEdge := &pb.DirectedEdge{ - ValueId: uid, - Attr: attr, - Op: pb.DirectedEdge_SET, - Facets: []*api.Facet{{Key: "dl", Value: dlBuf, ValType: api.Facet_INT}}, - } - if err := dlPlist.addMutation(ctx, txn, dlEdge); err != nil { - return err - } + blob := txn.cache.ReadBM25Blob(dlKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, docLen) + txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) // Update corpus stats: increment doc count by 1 and total terms by docLen. - return txn.updateBM25Stats(ctx, attr, 1, int64(docLen)) + return txn.updateBM25Stats(attr, 1, int64(docLen)) } // updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, -// applies the given deltas, and writes back. -func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta int64, totalTermsDelta int64) error { +// applies the given deltas, and writes back as a direct Badger KV entry. +func (txn *Txn) updateBM25Stats(attr string, docCountDelta int64, totalTermsDelta int64) error { statsKey := x.BM25StatsKey(attr) - plist, err := txn.cache.GetFromDelta(statsKey) - if err != nil { - return err - } - - // Read existing stats from posting with uid=1. - var docCount, totalTerms uint64 - val, err := plist.Value(txn.StartTs) - if err == nil && val.Value != nil { - data, ok := val.Value.([]byte) - if ok && len(data) == 16 { - docCount = binary.BigEndian.Uint64(data[0:8]) - totalTerms = binary.BigEndian.Uint64(data[8:16]) - } - } + blob := txn.cache.ReadBM25Blob(statsKey) + docCount, totalTerms := bm25enc.DecodeStats(blob) // Apply deltas. if docCountDelta >= 0 { @@ -384,18 +331,8 @@ func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, docCountDelta } } - // Write back stats. - statsBuf := make([]byte, 16) - binary.BigEndian.PutUint64(statsBuf[0:8], docCount) - binary.BigEndian.PutUint64(statsBuf[8:16], totalTerms) - edge := &pb.DirectedEdge{ - Entity: 1, - Attr: attr, - Value: statsBuf, - ValueType: pb.Posting_ValType(0), - Op: pb.DirectedEdge_SET, - } - return plist.addMutation(ctx, txn, edge) + txn.cache.WriteBM25Blob(statsKey, bm25enc.EncodeStats(docCount, totalTerms)) + return nil } // countParams is sent to updateCount function. It is used to update the count index. diff --git a/posting/list.go b/posting/list.go index 1c0c7a0fc55..5420a69a157 100644 --- a/posting/list.go +++ b/posting/list.go @@ -60,6 +60,8 @@ const ( BitCompletePosting byte = 0x08 // BitEmptyPosting signals that the value stores an empty posting list. BitEmptyPosting byte = 0x10 + // BitBM25Data signals that the value stores BM25 index data (direct KV, not a posting list). + BitBM25Data byte = 0x20 ) // List stores the in-memory representation of a posting list. diff --git a/posting/lists.go b/posting/lists.go index a4bc4fb355b..0bd9848de23 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -76,6 +76,10 @@ type LocalCache struct { // plists are posting lists in memory. They can be discarded to reclaim space. plists map[string]*List + + // bm25Writes buffers BM25 direct KV writes (key → encoded blob). + // These bypass the posting list infrastructure entirely. + bm25Writes map[string][]byte } // struct to implement LocalCache interface from vector-indexer @@ -135,6 +139,7 @@ func NewLocalCache(startTs uint64) *LocalCache { deltas: make(map[string][]byte), plists: make(map[string]*List), maxVersions: make(map[string]uint64), + bm25Writes: make(map[string][]byte), } } @@ -144,6 +149,57 @@ func NoCache(startTs uint64) *LocalCache { return &LocalCache{startTs: startTs} } +// ReadBM25Blob returns the BM25 blob for the given key. +// It checks the in-memory buffer first (read-your-own-writes), +// then falls back to reading from pstore at startTs. +func (lc *LocalCache) ReadBM25Blob(key []byte) []byte { + lc.RLock() + if blob, ok := lc.bm25Writes[string(key)]; ok { + lc.RUnlock() + return blob + } + lc.RUnlock() + + // Fall back to Badger. + txn := pstore.NewTransactionAt(lc.startTs, false) + defer txn.Discard() + item, err := txn.Get(key) + if err != nil { + return nil + } + val, err := item.ValueCopy(nil) + if err != nil { + return nil + } + return val +} + +// WriteBM25Blob buffers a BM25 blob write for the given key. +func (lc *LocalCache) WriteBM25Blob(key []byte, blob []byte) { + lc.Lock() + defer lc.Unlock() + if lc.bm25Writes == nil { + lc.bm25Writes = make(map[string][]byte) + } + lc.bm25Writes[string(key)] = blob +} + +// ReadBM25BlobAt reads a BM25 blob from pstore at the given read timestamp. +// This is used by the query read path (worker/task.go). +func ReadBM25BlobAt(key []byte, readTs uint64) []byte { + txn := pstore.NewTransactionAt(readTs, false) + defer txn.Discard() + item, err := txn.Get(key) + if err != nil { + return nil + } + val, err := item.ValueCopy(nil) + if err != nil { + return nil + } + return val +} + func (lc *LocalCache) UpdateCommitTs(commitTs uint64) { lc.Lock() defer lc.Unlock() diff --git a/posting/mvcc.go b/posting/mvcc.go index 81c5e375553..3b9510ef6bb 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -318,6 +318,21 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { return err } } + + // Flush BM25 direct KV writes. These are complete blobs (not deltas) + // and don't need rollup. + for key, blob := range cache.bm25Writes { + if err := writer.update(commitTs, func(btxn *badger.Txn) error { + return btxn.SetEntry(&badger.Entry{ + Key: []byte(key), + Value: blob, + UserMeta: BitBM25Data, + }) + }); err != nil { + return err + } + } + return nil } diff --git a/query/query.go b/query/query.go index 3025033e1e0..e241d18946b 100644 --- a/query/query.go +++ b/query/query.go @@ -7,6 +7,7 @@ package query import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -373,6 +374,19 @@ func getValue(tv *pb.TaskValue) (types.Val, error) { return val, nil } +func valToTaskValue(v types.Val) *pb.TaskValue { + data := types.ValueForType(types.BinaryID) + res := &pb.TaskValue{ValType: v.Tid.Enum(), Val: x.Nilbyte} + if v.Value == nil { + return res + } + if err := types.Marshal(v, &data); err != nil { + return res + } + res.Val = data.Value.([]byte) + return res +} + var ( // ErrEmptyVal is returned when a value is empty. ErrEmptyVal = errors.New("Query: harmless error, e.g. task.Val is nil") @@ -1369,6 +1383,9 @@ func (sg *SubGraph) valueVarAggregation(doneVars map[string]varValue, path []*Su case sg.Attr == "uid" && sg.Params.DoCount: // This is the count(uid) case. // We will do the computation later while constructing the result. + case sg.Attr == "bm25_score": + // bm25_score is a pseudo-predicate handled inline during children processing. + // Its valueMatrix is already populated. Nothing to aggregate. default: return errors.Errorf("Unhandled pb.node <%v> with parent <%v>", sg.Attr, parent.Attr) } @@ -2173,6 +2190,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { rch <- nil return } + var err error switch { case parent == nil && sg.SrcFunc != nil && sg.SrcFunc.Name == "uid": @@ -2275,6 +2293,30 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics + // If this is a BM25 root function, extract scores from ValueMatrix + // and store them in ParentVars for bm25_score pseudo-predicate children. + if sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(result.UidMatrix) > 0 && + len(result.ValueMatrix) > 0 { + bm25Scores := types.NewShardedMap() + uids := result.UidMatrix[0].GetUids() + for i, uid := range uids { + if i < len(result.ValueMatrix) && len(result.ValueMatrix[i].Values) > 0 { + tv := result.ValueMatrix[i].Values[0] + if len(tv.Val) == 8 { + score := math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + bm25Scores.Set(uid, types.Val{ + Tid: types.FloatID, + Value: score, + }) + } + } + } + if sg.Params.ParentVars == nil { + sg.Params.ParentVars = make(map[string]varValue) + } + sg.Params.ParentVars["__bm25_scores__"] = varValue{Vals: bm25Scores} + } + if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2452,6 +2494,28 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } child.SrcUIDs = sg.DestUIDs // Make the connection. + + // Handle bm25_score pseudo-predicate: populate valueMatrix from parent's + // BM25 scores. Mark IsInternal so populateUidValVar case 4 (value variable) + // fires instead of case 3 (UID variable). + if child.Attr == "bm25_score" { + if bm25Var, ok := child.Params.ParentVars["__bm25_scores__"]; ok && bm25Var.Vals != nil { + child.valueMatrix = make([]*pb.ValueList, len(child.SrcUIDs.GetUids())) + for j, uid := range child.SrcUIDs.GetUids() { + if val, okv := bm25Var.Vals.Get(uid); okv { + child.valueMatrix[j] = &pb.ValueList{ + Values: []*pb.TaskValue{valToTaskValue(val)}, + } + } else { + child.valueMatrix[j] = &pb.ValueList{} + } + } + } + child.DestUIDs = &pb.List{} + child.Params.IsInternal = true + continue + } + if child.IsInternal() { // We dont have to execute these nodes. continue diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index dcceece7428..cdb235be36f 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -221,3 +221,82 @@ func TestBM25Pagination(t *testing.T) { // Doc 503 "fox fox fox" should be the top result. require.Contains(t, js, "fox fox fox") } + +func TestBM25ScoreOrdering(t *testing.T) { + // Use the bm25_score pseudo-predicate with var block to order results by score. + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // "fox fox fox" (doc 503) has the highest BM25 score (tf=3, shortest doc). + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrderingMultiTerm(t *testing.T) { + // Multi-term query with score ordering: "quick lazy" should rank doc 501 highest + // since it contains both terms. + query := ` + { + var(func: bm25(description_bm25, "quick lazy")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25ScoreOrderingAllResults(t *testing.T) { + // Verify all results are returned in score-descending order via val(score). + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // All fox-containing docs should appear. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + // Score values should be present. + require.Contains(t, js, "val(score)") +} + +func TestBM25ScoreWithPagination(t *testing.T) { + // Use offset with score ordering. + query := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 1, offset: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return the second-highest scored document (not "fox fox fox"). + require.NotContains(t, js, "fox fox fox") + require.Contains(t, js, "fox") +} diff --git a/worker/task.go b/worker/task.go index 2fbd65acca4..fbc3189a42b 100644 --- a/worker/task.go +++ b/worker/task.go @@ -30,6 +30,7 @@ import ( "github.com/dgraph-io/dgraph/v25/algo" "github.com/dgraph-io/dgraph/v25/conn" "github.com/dgraph-io/dgraph/v25/posting" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" ctask "github.com/dgraph-io/dgraph/v25/task" @@ -1272,31 +1273,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return nil } - // 3. Read corpus stats. + // 3. Read corpus stats from direct Badger KV. statsKey := x.BM25StatsKey(attr) - statsPl, err := qs.cache.Get(statsKey) - if err != nil { - // No stats means no documents indexed yet. - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - statsVal, err := statsPl.Value(q.ReadTs) - if err != nil || statsVal.Value == nil { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - statsData, ok := statsVal.Value.([]byte) - if !ok || len(statsData) != 16 { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - docCount := binary.BigEndian.Uint64(statsData[0:8]) - totalTerms := binary.BigEndian.Uint64(statsData[8:16]) - if docCount == 0 { - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) - return nil - } - if totalTerms == 0 { + statsBlob := posting.ReadBM25BlobAt(statsKey, q.ReadTs) + docCount, totalTerms := bm25enc.DecodeStats(statsBlob) + if docCount == 0 || totalTerms == 0 { args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) return nil } @@ -1312,7 +1293,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 4. For each query token, read the posting list and collect term info. + // 4. For each query token, read the BM25 term blob and collect term info. type termInfo struct { idf float64 uidTFs map[uint64]uint32 @@ -1321,39 +1302,27 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for _, token := range queryTokens { key := x.BM25IndexKey(attr, token) - pl, err := qs.cache.Get(key) - if err != nil { + blob := posting.ReadBM25BlobAt(key, q.ReadTs) + entries := bm25enc.Decode(blob) + if len(entries) == 0 { continue } ti := &termInfo{uidTFs: make(map[uint64]uint32)} - var df float64 - err = pl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { - df++ - // When used as filter, only collect TF for UIDs in the filter set. + df := float64(len(entries)) + for _, e := range entries { if filterSet != nil { - if _, ok := filterSet[p.Uid]; !ok { - return nil - } - } - tf := uint32(1) - for _, f := range p.Facets { - if f.Key == "tf" && len(f.Value) >= 4 { - tf = binary.BigEndian.Uint32(f.Value[:4]) - break + if _, ok := filterSet[e.UID]; !ok { + continue } } - ti.uidTFs[p.Uid] = tf - return nil - }) - if err != nil { - continue + ti.uidTFs[e.UID] = e.Value } ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) termInfos[token] = ti } - // 5. Read doc lengths for all UIDs seen. + // 5. Read doc lengths for all UIDs seen using binary search on the doclen blob. allUids := make(map[uint64]struct{}) for _, ti := range termInfos { for uid := range ti.uidTFs { @@ -1361,28 +1330,15 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - docLens := make(map[uint64]uint32) dlKey := x.BM25DocLenKey(attr) - dlPl, err := qs.cache.Get(dlKey) - if err == nil { - remaining := len(allUids) - _ = dlPl.Iterate(q.ReadTs, 0, func(p *pb.Posting) error { - if remaining == 0 { - return posting.ErrStopIteration - } - if _, needed := allUids[p.Uid]; needed { - dl := uint32(1) - for _, f := range p.Facets { - if f.Key == "dl" && len(f.Value) >= 4 { - dl = binary.BigEndian.Uint32(f.Value[:4]) - break - } - } - docLens[p.Uid] = dl - remaining-- - } - return nil - }) + dlBlob := posting.ReadBM25BlobAt(dlKey, q.ReadTs) + dlEntries := bm25enc.Decode(dlBlob) + + docLens := make(map[uint64]uint32, len(allUids)) + for uid := range allUids { + if v, ok := bm25enc.Search(dlEntries, uid); ok { + docLens[uid] = v + } } // 6. Compute final BM25 scores. @@ -1408,7 +1364,6 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error for uid, score := range scores { results = append(results, uidScore{uid: uid, score: score}) } - // Sort by score descending for ordering, then collect UIDs. sort.Slice(results, func(i, j int) bool { if results[i].score != results[j].score { return results[i].score > results[j].score @@ -1428,14 +1383,26 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // Build output UIDs sorted by UID (ascending) as required by the query pipeline. + // Build output: UIDs sorted ascending (required by query pipeline) + // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). + sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) for i, r := range results { uids[i] = r.uid } - sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + + // Populate ValueMatrix with BM25 scores aligned to UIDs. + // Each entry is a ValueList with a single float64 value. + scoreValues := make([]*pb.ValueList, len(results)) + for i, r := range results { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, math.Float64bits(r.score)) + scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: buf, ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) return nil } From 79d52955f975861afdac4ba3f48ad399a44b0766 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 4 Mar 2026 23:51:09 -0500 Subject: [PATCH 04/19] test(bm25): add 15 integration tests for mutation scenarios and edge cases Cover incremental add/update/delete, IDF score stability as corpus grows, large corpus pagination, unicode, stopwords, uid filtering, score validation, and concurrent batch adds. Co-Authored-By: Claude Opus 4.6 --- query/query_bm25_test.go | 535 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 535 insertions(+) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index cdb235be36f..6f469220ab5 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -10,6 +10,10 @@ package query import ( "context" + "encoding/json" + "fmt" + "math" + "strings" "testing" "github.com/stretchr/testify/require" @@ -300,3 +304,534 @@ func TestBM25ScoreWithPagination(t *testing.T) { require.NotContains(t, js, "fox fox fox") require.Contains(t, js, "fox") } + +// parseScoresFromJSON extracts uid → score from JSON responses containing val(score). +func parseScoresFromJSON(t *testing.T, js string) map[string]float64 { + t.Helper() + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + scores := make(map[string]float64) + for _, item := range resp.Data.Me { + scores[item.UID] = item.Score + } + return scores +} + +func TestBM25IncrementalAddBatch(t *testing.T) { + batch1 := ` + <600> "alpha bravo charlie" . + <601> "delta echo foxtrot" . + ` + batch2 := ` + <602> "golf hotel india" . + <603> "juliet kilo lima" . + <604> "mike november oscar" . + ` + batch3 := ` + <605> "papa quebec romeo" . + <606> "sierra tango uniform" . + <607> "victor whiskey xray" . + ` + cleanup := func() { + deleteTriplesInCluster(` + <600> * . + <601> * . + <602> * . + <603> * . + <604> * . + <605> * . + <606> * . + <607> * . + `) + } + t.Cleanup(cleanup) + + countQuery := ` + { + me(func: bm25(description_bm25, "alpha bravo delta echo golf juliet mike papa sierra victor")) { + count(uid) + } + } + ` + + // Batch 1: add 2 docs. + require.NoError(t, addTriplesToCluster(batch1)) + js := processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":2`) + + // Batch 2: add 3 more docs → total 5. + require.NoError(t, addTriplesToCluster(batch2)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":5`) + + // Batch 3: add 3 more docs → total 8. + require.NoError(t, addTriplesToCluster(batch3)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":8`) + + // Verify specific new UIDs are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid } }`) + require.Contains(t, js, `"0x25e"`) // 606 +} + +func TestBM25CorpusStatsAffectIDF(t *testing.T) { + // Capture baseline score for "fox" query. + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + jsBefore := processQueryNoErr(t, scoreQuery) + scoresBefore := parseScoresFromJSON(t, jsBefore) + require.NotEmpty(t, scoresBefore, "baseline should have fox results") + + // Add 10 non-fox docs → N grows, df("fox") stays same → IDF should increase. + var triples string + for i := 610; i < 620; i++ { + triples += fmt.Sprintf(`<%d> "completely unrelated document about cats and dogs number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 610; i < 620; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + jsAfter := processQueryNoErr(t, scoreQuery) + scoresAfter := parseScoresFromJSON(t, jsAfter) + + // Compare score for UID 503 ("fox fox fox") — should increase. + uid503 := "0x1f7" + before, ok1 := scoresBefore[uid503] + after, ok2 := scoresAfter[uid503] + require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") + require.Greater(t, after, before, + "IDF should increase when corpus grows with non-matching docs (before=%f, after=%f)", before, after) +} + +func TestBM25DocumentUpdate(t *testing.T) { + // Add a doc with lots of "fox". + require.NoError(t, addTriplesToCluster(`<620> "fox fox fox fox" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<620> * .`) + }) + + // Should rank top for "fox". + js := processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + } + }`) + require.Contains(t, js, `"0x26c"`) // 620 + + // Update to remove "fox", add "cat". + deleteTriplesInCluster(`<620> "fox fox fox fox" .`) + require.NoError(t, addTriplesToCluster(`<620> "the cat sat on the mat" .`)) + + // Should no longer appear in "fox" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox")) { + uid + } + }`) + require.NotContains(t, js, `"0x26c"`) + + // Should appear in "cat" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "cat")) { + uid + } + }`) + require.Contains(t, js, `"0x26c"`) +} + +func TestBM25DocumentDeletion(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<625> "unique elephant term" .`)) + t.Cleanup(func() { + // Cleanup in case test fails before explicit delete. + deleteTriplesInCluster(`<625> * .`) + }) + + // Should find the elephant doc. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"0x271"`) // 625 + + // Delete it. + deleteTriplesInCluster(`<625> "unique elephant term" .`) + + // Should return empty for "elephant". + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.JSONEq(t, `{"data": {"me":[]}}`, js) + + // Baseline "fox" results should be unaffected. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid } }`) + require.Contains(t, js, "fox") +} + +func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + uid503 := "0x1f7" + + // Phase 1: baseline score. + js1 := processQueryNoErr(t, scoreQuery) + scores1 := parseScoresFromJSON(t, js1) + score1, ok := scores1[uid503] + require.True(t, ok, "UID 503 must appear in baseline") + + // Phase 2: add 5 fox docs → IDF decreases. + var foxTriples string + for i := 630; i < 635; i++ { + foxTriples += fmt.Sprintf(`<%d> "the fox runs quickly across the field number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(foxTriples)) + t.Cleanup(func() { + var del string + for i := 630; i < 640; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + js2 := processQueryNoErr(t, scoreQuery) + scores2 := parseScoresFromJSON(t, js2) + score2, ok := scores2[uid503] + require.True(t, ok, "UID 503 must appear after adding fox docs") + require.Greater(t, score1, score2, + "Adding fox docs should decrease IDF and thus score (phase1=%f, phase2=%f)", score1, score2) + + // Phase 3: add 5 non-fox docs → IDF increases relative to phase 2. + var nonFoxTriples string + for i := 635; i < 640; i++ { + nonFoxTriples += fmt.Sprintf(`<%d> "unrelated content about birds and fish number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(nonFoxTriples)) + + js3 := processQueryNoErr(t, scoreQuery) + scores3 := parseScoresFromJSON(t, js3) + score3, ok := scores3[uid503] + require.True(t, ok, "UID 503 must appear after adding non-fox docs") + require.Greater(t, score3, score2, + "Adding non-fox docs should increase IDF relative to phase2 (phase2=%f, phase3=%f)", score2, score3) +} + +func TestBM25LargeCorpus(t *testing.T) { + // Add 100 docs: 50 with "alpha", 50 with "beta". + var triples string + for i := 700; i < 750; i++ { + triples += fmt.Sprintf(`<%d> "alpha document content number %d with some padding words" . +`, i, i) + } + for i := 750; i < 800; i++ { + triples += fmt.Sprintf(`<%d> "beta document content number %d with some padding words" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 700; i < 800; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + // Count alpha docs. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Count beta docs. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "beta")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Union count: "alpha beta" should match all 100. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha beta")) { count(uid) } }`) + require.Contains(t, js, `"count":100`) + + // Pagination: first:10, offset:40 for alpha should return 10 results. + js = processQueryNoErr(t, ` + { + var(func: bm25(description_bm25, "alpha")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score), first: 10, offset: 40) { + uid + } + }`) + var resp struct { + Data struct { + Me []struct{ UID string } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.Len(t, resp.Data.Me, 10, "pagination first:10 offset:40 should return exactly 10 results") +} + +func TestBM25EdgeCaseSingleCharTerm(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<640> "x y z" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<640> * .`) + }) + + // Single-char terms may or may not be indexed depending on tokenizer. + // Just verify no panic/error. + _, err := processQuery(context.Background(), t, ` + { + me(func: bm25(description_bm25, "x")) { + uid + } + }`) + require.NoError(t, err) +} + +func TestBM25EdgeCaseLongDocument(t *testing.T) { + // Build a ~500-word document with "fox" appearing once. + words := make([]string, 500) + for i := range words { + words[i] = "padding" + } + words[250] = "fox" + longDoc := strings.Join(words, " ") + + require.NoError(t, addTriplesToCluster(fmt.Sprintf(`<645> %q .`, longDoc))) + t.Cleanup(func() { + deleteTriplesInCluster(`<645> * .`) + }) + + // Get scores for "fox" query. + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + uid503 := "0x1f7" // "fox fox fox" (doclen=3) + uid645 := "0x285" // long doc (doclen~500) + s503, ok1 := scores[uid503] + s645, ok2 := scores[uid645] + require.True(t, ok1, "UID 503 must appear in fox results") + require.True(t, ok2, "UID 645 must appear in fox results") + require.Greater(t, s503, s645, + "Short doc with high tf should score higher than long doc with low tf (503=%f, 645=%f)", s503, s645) +} + +func TestBM25EdgeCaseUnicode(t *testing.T) { + triples := ` + <650> "der schnelle braune Fuchs springt" . + <651> "le renard brun rapide saute" . + <652> "el zorro marrón rápido salta" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <650> * . + <651> * . + <652> * . + `) + }) + + // Query German term. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`) + require.Contains(t, js, `"0x28a"`) // 650 + + // Query French term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`) + require.Contains(t, js, `"0x28b"`) // 651 + + // Query Spanish term. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`) + require.Contains(t, js, `"0x28c"`) // 652 +} + +func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<655> "the a an is are was were" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<655> * .`) + }) + + // Query "the" — should return empty since "the" is a stopword. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`) + require.NotContains(t, js, `"0x28f"`) // 655 should not appear + + // But the doc should exist via has(). + js = processQueryNoErr(t, ` + { + me(func: has(description_bm25)) @filter(uid(655)) { + uid + } + }`) + require.Contains(t, js, `"0x28f"`) +} + +func TestBM25WithUidFilter(t *testing.T) { + // BM25 root with uid filter to restrict results. + query := ` + { + me(func: bm25(description_bm25, "fox")) @filter(uid(501, 503)) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should contain only UIDs 501 and 503. + require.Contains(t, js, `"0x1f5"`) // 501 + require.Contains(t, js, `"0x1f7"`) // 503 + // Should NOT contain other fox docs like 502, 506, 507. + require.NotContains(t, js, `"0x1f6"`) // 502 + require.NotContains(t, js, `"0x1fa"`) // 506 +} + +func TestBM25ScoreValuesAreValidFloats(t *testing.T) { + scoreQuery := ` + { + var(func: bm25(description_bm25, "fox")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + require.NotEmpty(t, scores, "should have at least one result") + + var prevScore float64 + first := true + // Iterate over results in order (they're orderdesc by score). + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + + for _, item := range resp.Data.Me { + score := item.Score + require.False(t, math.IsNaN(score), "score should not be NaN for uid %s", item.UID) + require.False(t, math.IsInf(score, 0), "score should not be Inf for uid %s", item.UID) + require.Greater(t, score, 0.0, "score should be positive for uid %s", item.UID) + + if !first { + require.GreaterOrEqual(t, prevScore, score, + "scores should be in descending order: %f >= %f", prevScore, score) + } + prevScore = score + first = false + } +} + +func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) { + t.Cleanup(func() { + deleteTriplesInCluster(`<670> * .`) + }) + + // Phase 1: add with "elephant". + require.NoError(t, addTriplesToCluster(`<670> "elephant roams the savanna" .`)) + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"0x29e"`) // 670 + + // Phase 2: delete. + deleteTriplesInCluster(`<670> "elephant roams the savanna" .`) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"0x29e"`) + + // Phase 3: re-add with different content. + require.NoError(t, addTriplesToCluster(`<670> "penguin waddles on the ice" .`)) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`) + require.Contains(t, js, `"0x29e"`) + + // "elephant" should still not match 670. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.NotContains(t, js, `"0x29e"`) +} + +func TestBM25NonIndexedPredicateError(t *testing.T) { + // "name" predicate does not have @index(bm25). + query := ` + { + me(func: bm25(name, "alice")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25") +} + +func TestBM25ConcurrentBatchAdd(t *testing.T) { + // Add 5 batches of 4 docs each (UIDs 680-699) back-to-back. + t.Cleanup(func() { + var del string + for i := 680; i < 700; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + for batch := 0; batch < 5; batch++ { + var triples string + for j := 0; j < 4; j++ { + uid := 680 + batch*4 + j + triples += fmt.Sprintf(`<%d> "searchterm batch%d doc%d content here" . +`, uid, batch, j) + } + require.NoError(t, addTriplesToCluster(triples)) + } + + // All 20 docs should be findable. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "searchterm")) { count(uid) } }`) + require.Contains(t, js, `"count":20`) + + // Spot-check a doc from each batch. + for batch := 0; batch < 5; batch++ { + uid := 680 + batch*4 + hexUID := fmt.Sprintf(`"0x%x"`, uid) + term := fmt.Sprintf("batch%d", batch) + js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term)) + require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) + } +} From c9273f613e61d85da362efe2771cde56040a1313 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 07:46:43 -0500 Subject: [PATCH 05/19] test(bm25): add exact score verification, BM15 variant, and single-doc tests Addresses test coverage gaps identified during code review against ArangoDB's BM25 implementation: - TestBM25ExactScoreValues: validates numerical correctness of BM25 formula using b=0 to enable hand-computed expected scores - TestBM25BM15NoLengthNormalization: verifies b=0 disables length normalization and contrasts with default b=0.75 behavior - TestBM25SingleMatchingDocument: covers df=1 edge case with high IDF Co-Authored-By: Claude Opus 4.6 --- query/query_bm25_test.go | 181 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index 6f469220ab5..457c7b46452 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -835,3 +835,184 @@ func TestBM25ConcurrentBatchAdd(t *testing.T) { require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) } } + +// parseCorpusCount returns the total number of documents with the description_bm25 predicate. +func parseCorpusCount(t *testing.T) float64 { + t.Helper() + js := processQueryNoErr(t, `{ me(func: has(description_bm25)) { count(uid) } }`) + var resp struct { + Data struct { + Me []struct { + Count int `json:"count"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me) + n := float64(resp.Data.Me[0].Count) + require.Greater(t, n, 0.0, "corpus must have documents") + return n +} + +func TestBM25ExactScoreValues(t *testing.T) { + // Exact score verification using b=0 (BM15 variant) to eliminate avgDL dependency. + // With b=0: score = idf * (k+1) * tf / (k + tf) + // This validates the core BM25 formula computes correct numerical values. + triples := ` + <850> "quasar quasar quasar" . + <851> "quasar nebula pulsar" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <850> * . + <851> * . + `) + }) + + N := parseCorpusCount(t) + + // Query "quasar" with b=0 so score depends only on tf, k, and IDF (not avgDL). + scoreQuery := ` + { + var(func: bm25(description_bm25, "quasar", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + k := 1.2 + df := 2.0 // both 850 and 851 contain "quasar" + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + + // Doc 850 "quasar quasar quasar": tf=3, b=0 → score = idf * 2.2 * 3 / 4.2 + expected850 := idf * (k + 1) * 3.0 / (k + 3.0) + // Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf + expected851 := idf * (k + 1) * 1.0 / (k + 1.0) + + actual850, ok := scores["0x352"] // 850 + require.True(t, ok, "UID 850 (0x352) must be in results") + actual851, ok := scores["0x353"] // 851 + require.True(t, ok, "UID 851 (0x353) must be in results") + + require.InEpsilon(t, expected850, actual850, 1e-6, + "Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected850, actual850, N, df, idf) + require.InEpsilon(t, expected851, actual851, 1e-6, + "Doc 851 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", + expected851, actual851, N, df, idf) + + // Verify ordering: higher tf should yield higher score. + require.Greater(t, actual850, actual851) +} + +func TestBM25BM15NoLengthNormalization(t *testing.T) { + // With b=0 (BM15 variant), document length should NOT affect the score. + // Two docs with the same term frequency but different lengths must score identically. + triples := ` + <860> "vortex" . + <861> "vortex alpha bravo charlie delta echo foxtrot golf hotel india" . + ` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(` + <860> * . + <861> * . + `) + }) + + // Query with b=0: length normalization disabled. + scoreQuery := ` + { + var(func: bm25(description_bm25, "vortex", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + score860, ok1 := scores["0x35c"] // 860 + score861, ok2 := scores["0x35d"] // 861 + require.True(t, ok1, "UID 860 must be in results") + require.True(t, ok2, "UID 861 must be in results") + + // With b=0 and same tf=1, scores must be equal regardless of document length. + require.InDelta(t, score860, score861, 1e-9, + "b=0 should disable length normalization: short doc score=%f, long doc score=%f", + score860, score861) + + // Now verify that with default b=0.75, the shorter doc scores higher. + scoreQueryDefault := ` + { + var(func: bm25(description_bm25, "vortex")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js = processQueryNoErr(t, scoreQueryDefault) + scoresDefault := parseScoresFromJSON(t, js) + + defScore860, ok1 := scoresDefault["0x35c"] + defScore861, ok2 := scoresDefault["0x35d"] + require.True(t, ok1, "UID 860 must be in default results") + require.True(t, ok2, "UID 861 must be in default results") + require.Greater(t, defScore860, defScore861, + "With b=0.75, shorter doc (doclen=1) should score higher than longer doc (doclen=10)") +} + +func TestBM25SingleMatchingDocument(t *testing.T) { + // Edge case: a single document matching the query term (df=1). + // IDF should be high since the term is very rare. + triples := `<865> "aardvark" .` + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + deleteTriplesInCluster(`<865> * .`) + }) + + N := parseCorpusCount(t) + + // Query with b=0 for exact verification. + scoreQuery := ` + { + var(func: bm25(description_bm25, "aardvark", "1.2", "0")) { + score as bm25_score + } + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + }` + js := processQueryNoErr(t, scoreQuery) + scores := parseScoresFromJSON(t, js) + + require.Len(t, scores, 1, "exactly one document should match 'aardvark'") + + actual, ok := scores["0x361"] // 865 + require.True(t, ok, "UID 865 (0x361) must be in results") + + // With df=1, tf=1, b=0, k=1.2: + // idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5) + // score = idf * 2.2 * 1 / (1.2 + 1) = idf * 2.2 / 2.2 = idf + k := 1.2 + df := 1.0 + idf := math.Log1p((N - df + 0.5) / (df + 0.5)) + expected := idf * (k + 1) * 1.0 / (k + 1.0) // simplifies to idf + + require.InEpsilon(t, expected, actual, 1e-6, + "Single-doc score mismatch: expected %f, got %f (N=%f, idf=%f)", + expected, actual, N, idf) + require.Greater(t, actual, 0.0, "score must be positive") + require.False(t, math.IsInf(actual, 0), "score must be finite") +} From ffb7f2f3f459b6abf68672f034b29b23c42cbba1 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:07:35 -0500 Subject: [PATCH 06/19] feat(bm25): add block storage infrastructure for segmented column stores Phase 1 of BM25 scaling plan. Introduces bm25block package with: - BlockMeta/Dir types for block directory encoding/decoding - SplitIntoBlocks: splits monolithic entry slices into 128-entry blocks - MergeAllBlocks: compacts overlapping blocks with dedup and tombstone removal - ComputeUBPre/SuffixMaxUBPre: WAND upper-bound precomputation - New key functions: BM25TermDirKey, BM25TermBlockKey, BM25DocLenDirKey, BM25DocLenBlockKey for block-addressed Badger KV storage 17 unit tests and benchmarks for the block storage format. Co-Authored-By: Claude Opus 4.6 --- posting/bm25block/bm25block.go | 261 ++++++++++++++++++++++++++++ posting/bm25block/bm25block_test.go | 258 +++++++++++++++++++++++++++ x/keys.go | 24 +++ 3 files changed, 543 insertions(+) create mode 100644 posting/bm25block/bm25block.go create mode 100644 posting/bm25block/bm25block_test.go diff --git a/posting/bm25block/bm25block.go b/posting/bm25block/bm25block.go new file mode 100644 index 00000000000..f529ed8fab8 --- /dev/null +++ b/posting/bm25block/bm25block.go @@ -0,0 +1,261 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package bm25block provides block-based storage for BM25 index data. +// +// Instead of storing all postings for a term in a single blob, this package +// splits them into fixed-size blocks (~128 entries). Each block is stored as +// a separate Badger KV entry, and a lightweight directory indexes the blocks. +// +// This enables: +// - Selective I/O: queries only read blocks they need +// - WAND/Block-Max WAND: per-block upper bounds enable early termination +// - Efficient mutations: only the affected block is rewritten +package bm25block + +import ( + "encoding/binary" + "math" + "sort" + + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" +) + +const ( + // TargetBlockSize is the ideal number of entries per block. + TargetBlockSize = 128 + // MaxBlockSize is the threshold at which a block is split. + MaxBlockSize = 256 + // DocLenBlockSize is the target entries per document-length block. + DocLenBlockSize = 512 + + // dirHeaderSize is 4 (blockCount) + 4 (nextID). + dirHeaderSize = 8 + // dirEntrySize is 8 (firstUID) + 4 (blockID) + 4 (count) + 4 (maxTF). + dirEntrySize = 20 +) + +// BlockMeta stores metadata for a single block in a directory. +type BlockMeta struct { + FirstUID uint64 + BlockID uint32 + Count uint32 + MaxTF uint32 +} + +// Dir is a block directory for a term's posting list or document-length list. +type Dir struct { + Blocks []BlockMeta + NextID uint32 // next available block ID +} + +// EncodeDir encodes a directory to bytes. Returns nil for an empty directory. +func EncodeDir(d *Dir) []byte { + if d == nil || len(d.Blocks) == 0 { + return nil + } + buf := make([]byte, dirHeaderSize+len(d.Blocks)*dirEntrySize) + binary.BigEndian.PutUint32(buf[0:4], uint32(len(d.Blocks))) + binary.BigEndian.PutUint32(buf[4:8], d.NextID) + off := dirHeaderSize + for _, b := range d.Blocks { + binary.BigEndian.PutUint64(buf[off:off+8], b.FirstUID) + binary.BigEndian.PutUint32(buf[off+8:off+12], b.BlockID) + binary.BigEndian.PutUint32(buf[off+12:off+16], b.Count) + binary.BigEndian.PutUint32(buf[off+16:off+20], b.MaxTF) + off += dirEntrySize + } + return buf +} + +// DecodeDir decodes a directory from bytes. Returns an empty Dir for nil/invalid input. +func DecodeDir(data []byte) *Dir { + if len(data) < dirHeaderSize { + return &Dir{} + } + count := binary.BigEndian.Uint32(data[0:4]) + nextID := binary.BigEndian.Uint32(data[4:8]) + if int(count)*dirEntrySize+dirHeaderSize > len(data) { + return &Dir{NextID: nextID} + } + blocks := make([]BlockMeta, count) + off := dirHeaderSize + for i := uint32(0); i < count; i++ { + blocks[i] = BlockMeta{ + FirstUID: binary.BigEndian.Uint64(data[off : off+8]), + BlockID: binary.BigEndian.Uint32(data[off+8 : off+12]), + Count: binary.BigEndian.Uint32(data[off+12 : off+16]), + MaxTF: binary.BigEndian.Uint32(data[off+16 : off+20]), + } + off += dirEntrySize + } + return &Dir{Blocks: blocks, NextID: nextID} +} + +// FindBlock returns the index of the block that should contain uid. +// Returns 0 if the directory is empty (caller should create first block). +func (d *Dir) FindBlock(uid uint64) int { + if len(d.Blocks) == 0 { + return 0 + } + // Binary search: find the last block where FirstUID <= uid. + i := sort.Search(len(d.Blocks), func(i int) bool { + return d.Blocks[i].FirstUID > uid + }) + if i > 0 { + return i - 1 + } + return 0 +} + +// AllocBlockID returns the next available block ID and increments the counter. +func (d *Dir) AllocBlockID() uint32 { + id := d.NextID + d.NextID++ + return id +} + +// UpdateBlockMeta recomputes metadata for the block at index idx from entries. +func (d *Dir) UpdateBlockMeta(idx int, entries []bm25enc.Entry) { + if idx < 0 || idx >= len(d.Blocks) || len(entries) == 0 { + return + } + d.Blocks[idx].FirstUID = entries[0].UID + d.Blocks[idx].Count = uint32(len(entries)) + var maxTF uint32 + for _, e := range entries { + if e.Value > maxTF { + maxTF = e.Value + } + } + d.Blocks[idx].MaxTF = maxTF +} + +// InsertBlockMeta inserts a new block at position idx. +func (d *Dir) InsertBlockMeta(idx int, meta BlockMeta) { + d.Blocks = append(d.Blocks, BlockMeta{}) + copy(d.Blocks[idx+1:], d.Blocks[idx:]) + d.Blocks[idx] = meta +} + +// RemoveBlockMeta removes the block at position idx. +func (d *Dir) RemoveBlockMeta(idx int) { + if idx < 0 || idx >= len(d.Blocks) { + return + } + d.Blocks = append(d.Blocks[:idx], d.Blocks[idx+1:]...) +} + +// SplitIntoBlocks splits a sorted entry slice into blocks of TargetBlockSize. +// Returns a new Dir and a map of blockID -> entries. +func SplitIntoBlocks(entries []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { + if len(entries) == 0 { + return &Dir{}, nil + } + dir := &Dir{} + blockMap := make(map[uint32][]bm25enc.Entry) + + for i := 0; i < len(entries); i += TargetBlockSize { + end := i + TargetBlockSize + if end > len(entries) { + end = len(entries) + } + block := entries[i:end] + blockID := dir.AllocBlockID() + + var maxTF uint32 + for _, e := range block { + if e.Value > maxTF { + maxTF = e.Value + } + } + + dir.Blocks = append(dir.Blocks, BlockMeta{ + FirstUID: block[0].UID, + BlockID: blockID, + Count: uint32(len(block)), + MaxTF: maxTF, + }) + // Make a copy so the caller owns the slice. + cp := make([]bm25enc.Entry, len(block)) + copy(cp, block) + blockMap[blockID] = cp + } + return dir, blockMap +} + +// MergeAllBlocks reads all block entries from a map (keyed by blockID), +// merges them into a single sorted slice, then re-splits into clean blocks. +func MergeAllBlocks(dir *Dir, readBlock func(blockID uint32) []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { + var all []bm25enc.Entry + for _, bm := range dir.Blocks { + entries := readBlock(bm.BlockID) + all = append(all, entries...) + } + // Sort by UID and deduplicate (keep last occurrence for same UID). + sort.Slice(all, func(i, j int) bool { return all[i].UID < all[j].UID }) + deduped := make([]bm25enc.Entry, 0, len(all)) + for i, e := range all { + if i > 0 && e.UID == all[i-1].UID { + deduped[len(deduped)-1] = e // overwrite with latest + continue + } + deduped = append(deduped, e) + } + // Remove tombstones (Value == 0). + live := deduped[:0] + for _, e := range deduped { + if e.Value > 0 { + live = append(live, e) + } + } + return SplitIntoBlocks(live) +} + +// ComputeUBPre computes the upper-bound pre-IDF BM25 contribution for a block +// given its maxTF and query parameters k and b. +// With dl=0 (best case for scoring): score = (maxTF*(k+1)) / (maxTF + k*(1-b)) +func ComputeUBPre(maxTF uint32, k, b float64) float64 { + if maxTF == 0 { + return 0 + } + tf := float64(maxTF) + return tf * (k + 1) / (tf + k*(1-b)) +} + +// SuffixMaxUBPre computes suffix maxima of UBPre values for WAND. +// suffixMax[i] = max(ubPre[i], ubPre[i+1], ..., ubPre[n-1]) +func SuffixMaxUBPre(dir *Dir, k, b float64) []float64 { + n := len(dir.Blocks) + if n == 0 { + return nil + } + suf := make([]float64, n) + suf[n-1] = ComputeUBPre(dir.Blocks[n-1].MaxTF, k, b) + for i := n - 2; i >= 0; i-- { + ub := ComputeUBPre(dir.Blocks[i].MaxTF, k, b) + suf[i] = math.Max(ub, suf[i+1]) + } + return suf +} + +// BlockMetaFromEntries computes a BlockMeta from entries. +func BlockMetaFromEntries(blockID uint32, entries []bm25enc.Entry) BlockMeta { + if len(entries) == 0 { + return BlockMeta{BlockID: blockID} + } + var maxTF uint32 + for _, e := range entries { + if e.Value > maxTF { + maxTF = e.Value + } + } + return BlockMeta{ + FirstUID: entries[0].UID, + BlockID: blockID, + Count: uint32(len(entries)), + MaxTF: maxTF, + } +} diff --git a/posting/bm25block/bm25block_test.go b/posting/bm25block/bm25block_test.go new file mode 100644 index 00000000000..a7cc26f493a --- /dev/null +++ b/posting/bm25block/bm25block_test.go @@ -0,0 +1,258 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package bm25block + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" +) + +func TestDirRoundtrip(t *testing.T) { + dir := &Dir{ + NextID: 5, + Blocks: []BlockMeta{ + {FirstUID: 100, BlockID: 0, Count: 128, MaxTF: 10}, + {FirstUID: 500, BlockID: 1, Count: 128, MaxTF: 5}, + {FirstUID: 900, BlockID: 2, Count: 64, MaxTF: 20}, + }, + } + data := EncodeDir(dir) + got := DecodeDir(data) + require.Equal(t, dir.NextID, got.NextID) + require.Equal(t, dir.Blocks, got.Blocks) +} + +func TestDirRoundtripEmpty(t *testing.T) { + require.Nil(t, EncodeDir(nil)) + require.Nil(t, EncodeDir(&Dir{})) + + got := DecodeDir(nil) + require.Empty(t, got.Blocks) + got = DecodeDir([]byte{}) + require.Empty(t, got.Blocks) +} + +func TestDirRoundtripSingle(t *testing.T) { + dir := &Dir{ + NextID: 1, + Blocks: []BlockMeta{{FirstUID: 42, BlockID: 0, Count: 1, MaxTF: 3}}, + } + got := DecodeDir(EncodeDir(dir)) + require.Equal(t, dir.Blocks, got.Blocks) +} + +func TestFindBlock(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 100}, + {FirstUID: 500}, + {FirstUID: 900}, + }, + } + require.Equal(t, 0, dir.FindBlock(50)) // before first block + require.Equal(t, 0, dir.FindBlock(100)) // exact first + require.Equal(t, 0, dir.FindBlock(200)) // within first block + require.Equal(t, 1, dir.FindBlock(500)) // exact second + require.Equal(t, 1, dir.FindBlock(700)) // within second block + require.Equal(t, 2, dir.FindBlock(900)) // exact third + require.Equal(t, 2, dir.FindBlock(9999)) // beyond last block +} + +func TestFindBlockEmpty(t *testing.T) { + dir := &Dir{} + require.Equal(t, 0, dir.FindBlock(100)) +} + +func TestAllocBlockID(t *testing.T) { + dir := &Dir{NextID: 3} + require.Equal(t, uint32(3), dir.AllocBlockID()) + require.Equal(t, uint32(4), dir.AllocBlockID()) + require.Equal(t, uint32(5), dir.NextID) +} + +func TestSplitIntoBlocks(t *testing.T) { + // Create 300 entries. + entries := make([]bm25enc.Entry, 300) + for i := range entries { + entries[i] = bm25enc.Entry{UID: uint64(i + 1), Value: uint32(i%10 + 1)} + } + dir, blockMap := SplitIntoBlocks(entries) + + // Should split into ceil(300/128) = 3 blocks. + require.Len(t, dir.Blocks, 3) + require.Len(t, blockMap, 3) + + // First block: 128 entries. + require.Equal(t, uint32(128), dir.Blocks[0].Count) + require.Equal(t, uint64(1), dir.Blocks[0].FirstUID) + require.Len(t, blockMap[dir.Blocks[0].BlockID], 128) + + // Second block: 128 entries. + require.Equal(t, uint32(128), dir.Blocks[1].Count) + require.Equal(t, uint64(129), dir.Blocks[1].FirstUID) + + // Third block: 44 entries. + require.Equal(t, uint32(44), dir.Blocks[2].Count) + require.Equal(t, uint64(257), dir.Blocks[2].FirstUID) + + // NextID should be 3. + require.Equal(t, uint32(3), dir.NextID) +} + +func TestSplitIntoBlocksEmpty(t *testing.T) { + dir, blockMap := SplitIntoBlocks(nil) + require.Empty(t, dir.Blocks) + require.Nil(t, blockMap) +} + +func TestSplitIntoBlocksSmall(t *testing.T) { + entries := []bm25enc.Entry{{UID: 1, Value: 5}, {UID: 2, Value: 3}} + dir, blockMap := SplitIntoBlocks(entries) + require.Len(t, dir.Blocks, 1) + require.Equal(t, uint32(2), dir.Blocks[0].Count) + require.Equal(t, uint32(5), dir.Blocks[0].MaxTF) + require.Equal(t, entries, blockMap[0]) +} + +func TestUpdateBlockMeta(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{{FirstUID: 100, BlockID: 0, Count: 3, MaxTF: 5}}, + } + entries := []bm25enc.Entry{ + {UID: 50, Value: 2}, + {UID: 100, Value: 8}, + {UID: 200, Value: 3}, + {UID: 300, Value: 1}, + } + dir.UpdateBlockMeta(0, entries) + require.Equal(t, uint64(50), dir.Blocks[0].FirstUID) + require.Equal(t, uint32(4), dir.Blocks[0].Count) + require.Equal(t, uint32(8), dir.Blocks[0].MaxTF) +} + +func TestInsertRemoveBlockMeta(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 100, BlockID: 0}, + {FirstUID: 500, BlockID: 1}, + }, + } + dir.InsertBlockMeta(1, BlockMeta{FirstUID: 300, BlockID: 2}) + require.Len(t, dir.Blocks, 3) + require.Equal(t, uint64(300), dir.Blocks[1].FirstUID) + require.Equal(t, uint64(500), dir.Blocks[2].FirstUID) + + dir.RemoveBlockMeta(1) + require.Len(t, dir.Blocks, 2) + require.Equal(t, uint64(500), dir.Blocks[1].FirstUID) +} + +func TestComputeUBPre(t *testing.T) { + k, b := 1.2, 0.75 + + // maxTF=0 -> 0 + require.Equal(t, 0.0, ComputeUBPre(0, k, b)) + + // maxTF=1: 1 * 2.2 / (1 + 1.2*0.25) = 2.2 / 1.3 + expected := 2.2 / 1.3 + require.InEpsilon(t, expected, ComputeUBPre(1, k, b), 1e-9) + + // maxTF=10: 10 * 2.2 / (10 + 1.2*0.25) = 22 / 10.3 + expected = 22.0 / 10.3 + require.InEpsilon(t, expected, ComputeUBPre(10, k, b), 1e-9) + + // With b=0: score = tf*(k+1)/(tf+k) — no length normalization. + expected = 5.0 * 2.2 / (5.0 + 1.2) + require.InEpsilon(t, expected, ComputeUBPre(5, k, 0), 1e-9) +} + +func TestSuffixMaxUBPre(t *testing.T) { + dir := &Dir{ + Blocks: []BlockMeta{ + {MaxTF: 1}, + {MaxTF: 10}, + {MaxTF: 3}, + }, + } + k, b := 1.2, 0.75 + suf := SuffixMaxUBPre(dir, k, b) + require.Len(t, suf, 3) + + ub0 := ComputeUBPre(1, k, b) + ub1 := ComputeUBPre(10, k, b) + ub2 := ComputeUBPre(3, k, b) + + require.InEpsilon(t, math.Max(ub0, math.Max(ub1, ub2)), suf[0], 1e-9) + require.InEpsilon(t, math.Max(ub1, ub2), suf[1], 1e-9) + require.InEpsilon(t, ub2, suf[2], 1e-9) +} + +func TestSuffixMaxUBPreEmpty(t *testing.T) { + require.Nil(t, SuffixMaxUBPre(&Dir{}, 1.2, 0.75)) +} + +func TestMergeAllBlocks(t *testing.T) { + // Simulate overlapping blocks with a tombstone. + blocks := map[uint32][]bm25enc.Entry{ + 0: {{UID: 1, Value: 3}, {UID: 5, Value: 1}}, + 1: {{UID: 5, Value: 7}, {UID: 10, Value: 2}}, // UID 5 overrides + 2: {{UID: 15, Value: 0}, {UID: 20, Value: 4}}, // UID 15 is tombstone + } + dir := &Dir{ + Blocks: []BlockMeta{ + {FirstUID: 1, BlockID: 0, Count: 2}, + {FirstUID: 5, BlockID: 1, Count: 2}, + {FirstUID: 15, BlockID: 2, Count: 2}, + }, + NextID: 3, + } + newDir, newBlocks := MergeAllBlocks(dir, func(id uint32) []bm25enc.Entry { + return blocks[id] + }) + // After merge: UID 1(3), 5(7), 10(2), 20(4) — UID 15 removed (tombstone). + require.Len(t, newDir.Blocks, 1) // 4 entries fits in one block + require.Len(t, newBlocks, 1) + entries := newBlocks[newDir.Blocks[0].BlockID] + require.Len(t, entries, 4) + require.Equal(t, uint64(1), entries[0].UID) + require.Equal(t, uint32(3), entries[0].Value) + require.Equal(t, uint64(5), entries[1].UID) + require.Equal(t, uint32(7), entries[1].Value) + require.Equal(t, uint64(20), entries[3].UID) +} + +func TestBlockMetaFromEntries(t *testing.T) { + entries := []bm25enc.Entry{ + {UID: 10, Value: 2}, + {UID: 20, Value: 8}, + {UID: 30, Value: 1}, + } + meta := BlockMetaFromEntries(5, entries) + require.Equal(t, uint32(5), meta.BlockID) + require.Equal(t, uint64(10), meta.FirstUID) + require.Equal(t, uint32(3), meta.Count) + require.Equal(t, uint32(8), meta.MaxTF) +} + +func TestBlockMetaFromEntriesEmpty(t *testing.T) { + meta := BlockMetaFromEntries(0, nil) + require.Equal(t, uint32(0), meta.Count) +} + +func BenchmarkSplitIntoBlocks(b *testing.B) { + entries := make([]bm25enc.Entry, 100000) + for i := range entries { + entries[i] = bm25enc.Entry{UID: uint64(i*3 + 1), Value: uint32(i%100 + 1)} + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + SplitIntoBlocks(entries) + } +} diff --git a/x/keys.go b/x/keys.go index 23196fd89c9..0a23ba19c6a 100644 --- a/x/keys.go +++ b/x/keys.go @@ -310,6 +310,30 @@ func BM25StatsKey(attr string) []byte { return IndexKey(attr, BM25Prefix+"__stats__") } +// BM25TermDirKey generates the key for a BM25 term's block directory. +func BM25TermDirKey(attr, term string) []byte { + return IndexKey(attr, BM25Prefix+"__dir__"+term) +} + +// BM25TermBlockKey generates the key for an individual BM25 term posting block. +func BM25TermBlockKey(attr, term string, blockID uint32) []byte { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], blockID) + return IndexKey(attr, BM25Prefix+"__blk__"+term+string(buf[:])) +} + +// BM25DocLenDirKey generates the key for the BM25 document-length block directory. +func BM25DocLenDirKey(attr string) []byte { + return IndexKey(attr, BM25Prefix+"__dldir__") +} + +// BM25DocLenBlockKey generates the key for an individual BM25 document-length block. +func BM25DocLenBlockKey(attr string, segID uint32) []byte { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], segID) + return IndexKey(attr, BM25Prefix+"__dlblk__"+string(buf[:])) +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string From 59f18cc2631a1d26b7543f998f4cf4e62e6c95a2 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:11:04 -0500 Subject: [PATCH 07/19] feat(bm25): segmented block writes and WAND/Block-Max WAND query path Phases 2-4 of BM25 scaling plan: Phase 2 - Segmented mutation path: - addBM25IndexMutations now writes to block-based storage - Each term's postings split into ~128-entry blocks with a directory - Blocks automatically split when exceeding 256 entries - Doc-length list also uses block-based storage - Block removal and directory cleanup on deletes Phase 3 - WAND top-k query path: - New bm25wand.go with listIter for block-based posting list iteration - WAND algorithm with min-heap for top-k early termination - Per-block upper bounds (UBPre) computed from maxTF at query time - Suffix-max UBPre for efficient threshold checking - Falls back to scoring all docs when no first: limit or offset is used Phase 4 - Block-Max WAND: - skipToWithBMW skips entire blocks whose UB + other terms can't beat theta - Avoids Badger reads for blocks that can't contribute to top-k - Enabled by default in handleBM25Search Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 183 ++++++++++++++--- worker/bm25wand.go | 501 +++++++++++++++++++++++++++++++++++++++++++++ worker/task.go | 95 ++------- 3 files changed, 668 insertions(+), 111 deletions(-) create mode 100644 worker/bm25wand.go diff --git a/posting/index.go b/posting/index.go index 826355a3633..d2eadd904e0 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,6 +28,7 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/dgraph/v25/posting/bm25block" "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" @@ -232,8 +233,9 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok } // addBM25IndexMutations handles index mutations for the BM25 tokenizer. -// It stores term frequencies, document lengths, and corpus statistics as direct -// Badger KV entries using compact varint encoding, bypassing posting lists. +// It stores term frequencies, document lengths, and corpus statistics using +// block-based storage: each term's postings and the doclen list are split into +// fixed-size blocks (~128 entries) with a lightweight directory for navigation. func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { attr := info.edge.Attr uid := info.edge.Entity @@ -261,45 +263,168 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn } if info.op == pb.DirectedEdge_DEL { - // For DELETE: remove uid from all BM25 term posting lists and doc length list. + // For DELETE: remove uid from all term blocks and doclen blocks. for term := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term - key := x.BM25IndexKey(attr, encodedTerm) - blob := txn.cache.ReadBM25Blob(key) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) - } - // Remove doc length entry. - dlKey := x.BM25DocLenKey(attr) - blob := txn.cache.ReadBM25Blob(dlKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) - - // Update corpus stats: decrement doc count and total terms. + txn.bm25BlockRemove(attr, encodedTerm, uid) + } + txn.bm25DocLenBlockRemove(attr, uid) return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: store term frequencies and doc length. + // For SET: upsert term frequencies and doc length into blocks. for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term - key := x.BM25IndexKey(attr, encodedTerm) - blob := txn.cache.ReadBM25Blob(key) - entries := bm25enc.Decode(blob) - entries = bm25enc.Upsert(entries, uid, tf) - txn.cache.WriteBM25Blob(key, bm25enc.Encode(entries)) + txn.bm25BlockUpsert(attr, encodedTerm, uid, tf) + } + txn.bm25DocLenBlockUpsert(attr, uid, docLen) + return txn.updateBM25Stats(attr, 1, int64(docLen)) +} + +// bm25BlockUpsert inserts or updates a (uid, value) entry in the block-based +// posting list for the given term. Handles block creation and splitting. +func (txn *Txn) bm25BlockUpsert(attr, encodedTerm string, uid uint64, value uint32) { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + // First entry for this term: create a single block. + blockID := dir.AllocBlockID() + entries := []bm25enc.Entry{{UID: uid, Value: value}} + blockKey := x.BM25TermBlockKey(attr, encodedTerm, blockID) + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) + return + } + + // Find the target block, read it, upsert, and handle splits. + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Upsert(entries, uid, value) + + if len(entries) > bm25block.MaxBlockSize { + // Split the block. + mid := len(entries) / 2 + left := entries[:mid] + right := entries[mid:] + + // Write left block (reuse existing blockID). + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) + dir.UpdateBlockMeta(blockIdx, left) + + // Write right block (new blockID). + newBlockID := dir.AllocBlockID() + newBlockKey := x.BM25TermBlockKey(attr, encodedTerm, newBlockID) + txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) + dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25BlockRemove removes a uid from the block-based posting list for the given term. +func (txn *Txn) bm25BlockRemove(attr, encodedTerm string, uid uint64) { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return } - // Store document length. - dlKey := x.BM25DocLenKey(attr) - blob := txn.cache.ReadBM25Blob(dlKey) + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + + if len(entries) == 0 { + // Block is empty; remove it from the directory. + txn.cache.WriteBM25Blob(blockKey, nil) + dir.RemoveBlockMeta(blockIdx) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25DocLenBlockUpsert inserts or updates a doc-length entry in the block-based +// document-length list. +func (txn *Txn) bm25DocLenBlockUpsert(attr string, uid uint64, docLen uint32) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + blockID := dir.AllocBlockID() + entries := []bm25enc.Entry{{UID: uid, Value: docLen}} + blockKey := x.BM25DocLenBlockKey(attr, blockID) + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) + return + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) entries := bm25enc.Decode(blob) entries = bm25enc.Upsert(entries, uid, docLen) - txn.cache.WriteBM25Blob(dlKey, bm25enc.Encode(entries)) - // Update corpus stats: increment doc count by 1 and total terms by docLen. - return txn.updateBM25Stats(attr, 1, int64(docLen)) + if len(entries) > bm25block.MaxBlockSize { + mid := len(entries) / 2 + left := entries[:mid] + right := entries[mid:] + + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) + dir.UpdateBlockMeta(blockIdx, left) + + newBlockID := dir.AllocBlockID() + newBlockKey := x.BM25DocLenBlockKey(attr, newBlockID) + txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) + dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) +} + +// bm25DocLenBlockRemove removes a uid from the block-based document-length list. +func (txn *Txn) bm25DocLenBlockRemove(attr string, uid uint64) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + entries = bm25enc.Remove(entries, uid) + + if len(entries) == 0 { + txn.cache.WriteBM25Blob(blockKey, nil) + dir.RemoveBlockMeta(blockIdx) + } else { + txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) + dir.UpdateBlockMeta(blockIdx, entries) + } + txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) } // updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, diff --git a/worker/bm25wand.go b/worker/bm25wand.go new file mode 100644 index 00000000000..7fc9dc74ec1 --- /dev/null +++ b/worker/bm25wand.go @@ -0,0 +1,501 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "sort" + + "github.com/dgraph-io/dgraph/v25/posting" + "github.com/dgraph-io/dgraph/v25/posting/bm25block" + "github.com/dgraph-io/dgraph/v25/posting/bm25enc" + "github.com/dgraph-io/dgraph/v25/x" +) + +// listIter iterates over a term's block-based posting list for WAND scoring. +type listIter struct { + attr string + encodedTerm string + readTs uint64 + idf float64 + k, b float64 + + dir *bm25block.Dir + ubPreSuf []float64 // suffix max of UBPre values + blockIdx int // current block index in dir.Blocks + block []bm25enc.Entry // decoded current block + inBlockPos int // position within current block + + exhausted bool +} + +// newListIter creates a new iterator for a term's block-based posting list. +func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return &listIter{exhausted: true} + } + + it := &listIter{ + attr: attr, + encodedTerm: encodedTerm, + readTs: readTs, + idf: idf, + k: k, + b: b, + dir: dir, + ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), + blockIdx: -1, // will be advanced on first Next() + } + return it +} + +// currentDoc returns the UID at the current position. +func (it *listIter) currentDoc() uint64 { + if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + return math.MaxUint64 + } + return it.block[it.inBlockPos].UID +} + +// currentTF returns the term frequency at the current position. +func (it *listIter) currentTF() uint32 { + if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + return 0 + } + return it.block[it.inBlockPos].Value +} + +// remainingUB returns the IDF-weighted upper-bound score for the remaining postings. +func (it *listIter) remainingUB() float64 { + if it.exhausted || it.blockIdx >= len(it.ubPreSuf) { + return 0 + } + return it.idf * it.ubPreSuf[it.blockIdx] +} + +// blockUB returns the IDF-weighted upper-bound for the current block only. +func (it *listIter) blockUB() float64 { + if it.exhausted || it.blockIdx < 0 || it.blockIdx >= len(it.dir.Blocks) { + return 0 + } + return it.idf * bm25block.ComputeUBPre(it.dir.Blocks[it.blockIdx].MaxTF, it.k, it.b) +} + +// next advances to the next posting. Returns false if exhausted. +func (it *listIter) next() bool { + if it.exhausted { + return false + } + + // Try advancing within the current block. + if it.block != nil { + it.inBlockPos++ + if it.inBlockPos < len(it.block) { + return true + } + } + + // Move to the next block. + it.blockIdx++ + if it.blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + it.loadBlock(it.blockIdx) + return it.inBlockPos < len(it.block) +} + +// skipTo advances to the first posting with UID >= target. +// Returns false if exhausted. +func (it *listIter) skipTo(target uint64) bool { + if it.exhausted { + return false + } + + // If current doc is already >= target, no-op. + if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + return true + } + + // Check if target might be in the current block. + if it.block != nil && it.blockIdx < len(it.dir.Blocks) { + lastInBlock := it.block[len(it.block)-1].UID + if target <= lastInBlock { + // Binary search within current block. + pos := sort.Search(len(it.block)-it.inBlockPos, func(i int) bool { + return it.block[it.inBlockPos+i].UID >= target + }) + it.inBlockPos += pos + if it.inBlockPos < len(it.block) { + return true + } + } + } + + // Find the right block using the directory. + blockIdx := it.findBlockForTarget(target) + if blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + + it.blockIdx = blockIdx + it.loadBlock(blockIdx) + + // Binary search within the block. + pos := sort.Search(len(it.block), func(i int) bool { + return it.block[i].UID >= target + }) + it.inBlockPos = pos + if pos >= len(it.block) { + // Target is beyond this block; try the next. + return it.next() + } + return true +} + +// skipToWithBMW is like skipTo but uses Block-Max WAND to skip entire blocks +// whose upper bounds can't beat the given threshold. +func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) bool { + if it.exhausted { + return false + } + + // If current doc is already >= target, no-op. + if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + return true + } + + blockIdx := it.findBlockForTarget(target) + for blockIdx < len(it.dir.Blocks) { + // Check if this block's UB combined with other terms can beat theta. + blockUB := it.idf * bm25block.ComputeUBPre(it.dir.Blocks[blockIdx].MaxTF, it.k, it.b) + if blockUB+otherUB > theta { + // This block might have a winner; load and search it. + it.blockIdx = blockIdx + it.loadBlock(blockIdx) + pos := sort.Search(len(it.block), func(i int) bool { + return it.block[i].UID >= target + }) + it.inBlockPos = pos + if pos < len(it.block) { + return true + } + // Fall through to next block. + } + blockIdx++ + // Update target to the next block's firstUID (we've already skipped past target). + if blockIdx < len(it.dir.Blocks) { + target = it.dir.Blocks[blockIdx].FirstUID + } + } + it.exhausted = true + return false +} + +// findBlockForTarget returns the block index that should contain target. +func (it *listIter) findBlockForTarget(target uint64) int { + blocks := it.dir.Blocks + idx := sort.Search(len(blocks), func(i int) bool { + return blocks[i].FirstUID > target + }) + if idx > 0 { + return idx - 1 + } + return 0 +} + +// loadBlock decodes the block at the given directory index. +func (it *listIter) loadBlock(idx int) { + bm := it.dir.Blocks[idx] + blockKey := x.BM25TermBlockKey(it.attr, it.encodedTerm, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, it.readTs) + it.block = bm25enc.Decode(blob) + it.inBlockPos = 0 +} + +// scoredDoc holds a UID and its BM25 score for the min-heap. +type scoredDoc struct { + uid uint64 + score float64 +} + +// topKHeap is a min-heap of scored documents for top-k tracking. +type topKHeap struct { + docs []scoredDoc + k int +} + +func (h *topKHeap) Len() int { return len(h.docs) } +func (h *topKHeap) Less(i, j int) bool { return h.docs[i].score < h.docs[j].score } +func (h *topKHeap) Swap(i, j int) { h.docs[i], h.docs[j] = h.docs[j], h.docs[i] } +func (h *topKHeap) Push(x interface{}) { h.docs = append(h.docs, x.(scoredDoc)) } +func (h *topKHeap) Pop() interface{} { + old := h.docs + n := len(old) + item := old[n-1] + h.docs = old[:n-1] + return item +} + +// threshold returns the minimum score in the heap (the score to beat). +func (h *topKHeap) threshold() float64 { + if len(h.docs) < h.k { + return 0 + } + return h.docs[0].score +} + +// tryPush adds a doc if it beats the current threshold. Returns true if the +// threshold changed. +func (h *topKHeap) tryPush(uid uint64, score float64) bool { + if len(h.docs) < h.k { + heap.Push(h, scoredDoc{uid: uid, score: score}) + return len(h.docs) == h.k // threshold only meaningful once heap is full + } + if score > h.docs[0].score { + h.docs[0] = scoredDoc{uid: uid, score: score} + heap.Fix(h, 0) + return true + } + return false +} + +// sorted returns all docs sorted by score descending, then UID ascending. +func (h *topKHeap) sorted() []scoredDoc { + result := make([]scoredDoc, len(h.docs)) + copy(result, h.docs) + sort.Slice(result, func(i, j int) bool { + if result[i].score != result[j].score { + return result[i].score > result[j].score + } + return result[i].uid < result[j].uid + }) + return result +} + +// bm25Score computes the BM25 score for a single term occurrence. +func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { + return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) +} + +// lookupDocLen looks up a single UID's document length from the block-based doclen store. +func lookupDocLen(attr string, uid, readTs uint64) float64 { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return 1.0 // fallback + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, readTs) + entries := bm25enc.Decode(blob) + if v, ok := bm25enc.Search(entries, uid); ok { + return float64(v) + } + return 1.0 +} + +// wandSearch performs a WAND top-k search over block-based posting lists. +// If topK <= 0, it scores all matching documents (no early termination). +func wandSearch(attr string, readTs uint64, queryTokens []string, + k, b, avgDL, N float64, topK int, filterSet map[uint64]struct{}, + useBMW bool) []scoredDoc { + + // Build iterators for each query term. + var iters []*listIter + for _, token := range queryTokens { + dirKey := x.BM25TermDirKey(attr, token) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir := bm25block.DecodeDir(dirBlob) + if len(dir.Blocks) == 0 { + continue + } + + // Compute df from directory. + var df uint64 + for _, bm := range dir.Blocks { + df += uint64(bm.Count) + } + idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) + + it := newListIter(attr, token, readTs, idf, k, b) + if !it.exhausted { + it.next() // prime the iterator + if !it.exhausted { + iters = append(iters, it) + } + } + } + + if len(iters) == 0 { + return nil + } + + // If no top-k limit, score all matching documents. + if topK <= 0 { + return scoreAllDocs(iters, attr, readTs, k, b, avgDL, filterSet) + } + + // WAND algorithm with top-k heap. + h := &topKHeap{k: topK} + heap.Init(h) + + for { + // Remove exhausted iterators. + active := iters[:0] + for _, it := range iters { + if !it.exhausted { + active = append(active, it) + } + } + iters = active + if len(iters) == 0 { + break + } + + // Sort iterators by currentDoc ascending. + sort.Slice(iters, func(i, j int) bool { + return iters[i].currentDoc() < iters[j].currentDoc() + }) + + theta := h.threshold() + + // Find pivot: accumulate UBs until they exceed theta. + var sumUB float64 + pivot := -1 + var pivotDoc uint64 + for i, it := range iters { + sumUB += it.remainingUB() + if sumUB > theta { + pivot = i + pivotDoc = it.currentDoc() + break + } + } + if pivot == -1 { + break // sum of all UBs can't beat theta + } + + // Advance all iterators before pivot to pivotDoc. + allAtPivot := true + for i := 0; i < pivot; i++ { + if iters[i].currentDoc() < pivotDoc { + var ok bool + if useBMW { + // Compute other UBs for BMW skipping. + var otherUB float64 + for j, jt := range iters { + if j != i { + otherUB += jt.remainingUB() + } + } + ok = iters[i].skipToWithBMW(pivotDoc, theta, otherUB) + } else { + ok = iters[i].skipTo(pivotDoc) + } + if !ok { + allAtPivot = false + break + } + if iters[i].currentDoc() != pivotDoc { + allAtPivot = false + } + } + } + + if !allAtPivot { + continue // re-evaluate after advances + } + + // All iterators up to pivot are at pivotDoc. Score the candidate. + if filterSet != nil { + if _, ok := filterSet[pivotDoc]; !ok { + // Skip this doc (filtered out). Advance all iters at pivotDoc. + for _, it := range iters { + if it.currentDoc() == pivotDoc { + it.next() + } + } + continue + } + } + + dl := lookupDocLen(attr, pivotDoc, readTs) + var score float64 + for _, it := range iters { + if it.currentDoc() == pivotDoc { + tf := float64(it.currentTF()) + score += bm25Score(it.idf, tf, dl, avgDL, k, b) + } + } + h.tryPush(pivotDoc, score) + + // Advance all iterators at pivotDoc. + for _, it := range iters { + if it.currentDoc() == pivotDoc { + it.next() + } + } + } + + return h.sorted() +} + +// scoreAllDocs scores every matching document without early termination. +// Used when no top-k limit is specified (the original behavior). +func scoreAllDocs(iters []*listIter, attr string, readTs uint64, + k, b, avgDL float64, filterSet map[uint64]struct{}) []scoredDoc { + + // Collect all (uid, term) matches. + type termMatch struct { + idf float64 + tf uint32 + } + matches := make(map[uint64][]termMatch) + + for _, it := range iters { + for !it.exhausted { + uid := it.currentDoc() + tf := it.currentTF() + if filterSet == nil { + matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) + } else if _, ok := filterSet[uid]; ok { + matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) + } + it.next() + } + } + + // Score all matching documents. + results := make([]scoredDoc, 0, len(matches)) + for uid, terms := range matches { + dl := lookupDocLen(attr, uid, readTs) + var score float64 + for _, tm := range terms { + score += bm25Score(tm.idf, float64(tm.tf), dl, avgDL, k, b) + } + results = append(results, scoredDoc{uid: uid, score: score}) + } + + // Sort by score descending, then UID ascending. + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + return results +} diff --git a/worker/task.go b/worker/task.go index fbc3189a42b..55697973275 100644 --- a/worker/task.go +++ b/worker/task.go @@ -31,6 +31,7 @@ import ( "github.com/dgraph-io/dgraph/v25/conn" "github.com/dgraph-io/dgraph/v25/posting" "github.com/dgraph-io/dgraph/v25/posting/bm25enc" + // bm25block and bm25wand are used via bm25wand.go in this package. "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" ctask "github.com/dgraph-io/dgraph/v25/task" @@ -1273,7 +1274,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return nil } - // 3. Read corpus stats from direct Badger KV. + // 3. Read corpus stats. statsKey := x.BM25StatsKey(attr) statsBlob := posting.ReadBM25BlobAt(statsKey, q.ReadTs) docCount, totalTerms := bm25enc.DecodeStats(statsBlob) @@ -1284,7 +1285,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error avgDL := float64(totalTerms) / float64(docCount) N := float64(docCount) - // Build filter set early if used as a filter, for efficient intersection during iteration. + // Build filter set if used as a filter. var filterSet map[uint64]struct{} if q.UidList != nil && len(q.UidList.Uids) > 0 { filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) @@ -1293,86 +1294,18 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 4. For each query token, read the BM25 term blob and collect term info. - type termInfo struct { - idf float64 - uidTFs map[uint64]uint32 + // 4. Determine top-k: use WAND when first is set and no offset. + // When offset is set or first is unset, score all documents. + topK := 0 + if q.First > 0 && q.Offset == 0 { + topK = int(q.First) } - termInfos := make(map[string]*termInfo) - for _, token := range queryTokens { - key := x.BM25IndexKey(attr, token) - blob := posting.ReadBM25BlobAt(key, q.ReadTs) - entries := bm25enc.Decode(blob) - if len(entries) == 0 { - continue - } - - ti := &termInfo{uidTFs: make(map[uint64]uint32)} - df := float64(len(entries)) - for _, e := range entries { - if filterSet != nil { - if _, ok := filterSet[e.UID]; !ok { - continue - } - } - ti.uidTFs[e.UID] = e.Value - } - ti.idf = math.Log1p((N - df + 0.5) / (df + 0.5)) - termInfos[token] = ti - } - - // 5. Read doc lengths for all UIDs seen using binary search on the doclen blob. - allUids := make(map[uint64]struct{}) - for _, ti := range termInfos { - for uid := range ti.uidTFs { - allUids[uid] = struct{}{} - } - } - - dlKey := x.BM25DocLenKey(attr) - dlBlob := posting.ReadBM25BlobAt(dlKey, q.ReadTs) - dlEntries := bm25enc.Decode(dlBlob) - - docLens := make(map[uint64]uint32, len(allUids)) - for uid := range allUids { - if v, ok := bm25enc.Search(dlEntries, uid); ok { - docLens[uid] = v - } - } - - // 6. Compute final BM25 scores. - scores := make(map[uint64]float64) - for _, ti := range termInfos { - for uid, tf := range ti.uidTFs { - dl := float64(1) - if v, ok := docLens[uid]; ok { - dl = float64(v) - } - tfFloat := float64(tf) - score := ti.idf * (k + 1) * tfFloat / (k*(1-b+b*dl/avgDL) + tfFloat) - scores[uid] += score - } - } - - // 7. Sort by score descending. - type uidScore struct { - uid uint64 - score float64 - } - results := make([]uidScore, 0, len(scores)) - for uid, score := range scores { - results = append(results, uidScore{uid: uid, score: score}) - } - sort.Slice(results, func(i, j int) bool { - if results[i].score != results[j].score { - return results[i].score > results[j].score - } - return results[i].uid < results[j].uid - }) + // 5. Run WAND search over block-based posting lists (with Block-Max skipping). + results := wandSearch(attr, q.ReadTs, queryTokens, k, b, avgDL, N, topK, filterSet, true) - // Apply first/offset pagination on score-sorted results before returning UIDs. - if q.First > 0 || q.Offset > 0 { + // 6. Apply first/offset pagination on score-sorted results. + if topK <= 0 && (q.First > 0 || q.Offset > 0) { offset := int(q.Offset) if offset > len(results) { offset = len(results) @@ -1383,7 +1316,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // Build output: UIDs sorted ascending (required by query pipeline) + // 7. Build output: UIDs sorted ascending (required by query pipeline) // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) @@ -1392,8 +1325,6 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) - // Populate ValueMatrix with BM25 scores aligned to UIDs. - // Each entry is a ValueList with a single float64 value. scoreValues := make([]*pb.ValueList, len(results)) for i, r := range results { buf := make([]byte, 8) From 1a3ada3843d34370cefd7cc57897c96ae222ca3c Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:12:59 -0500 Subject: [PATCH 08/19] feat(bm25): add legacy format fallback for migration and WAND unit tests Phase 5 - Migration support: - newListIter falls back to legacy monolithic blob when no block directory exists - lookupDocLen falls back to legacy BM25DocLenKey blob - wandSearch falls back to legacy BM25IndexKey for df computation - Legacy data transparently served through synthetic single-block directory - New writes always use block format; old data works until overwritten Unit tests for WAND components: - TestTopKHeapBasic: heap operations, threshold, eviction - TestTopKHeapTieBreaking: deterministic ordering on score ties - TestBm25ScoreFunction: formula verification, tf/dl/b edge cases - TestBm25ScoreNaN: no NaN/Inf for edge-case inputs Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 77 +++++++++++++++++++++++++++++---- worker/bm25wand_test.go | 96 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 9 deletions(-) create mode 100644 worker/bm25wand_test.go diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 7fc9dc74ec1..c946ecd9202 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -31,16 +31,55 @@ type listIter struct { inBlockPos int // position within current block exhausted bool + legacy bool // true if using legacy monolithic blob (migration fallback) } // newListIter creates a new iterator for a term's block-based posting list. +// Falls back to the legacy monolithic blob format if no block directory exists. func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { dirKey := x.BM25TermDirKey(attr, encodedTerm) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) if len(dir.Blocks) == 0 { - return &listIter{exhausted: true} + // Fallback: try reading the legacy monolithic blob and wrap it as a single block. + legacyKey := x.BM25IndexKey(attr, encodedTerm) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + if len(legacyEntries) == 0 { + return &listIter{exhausted: true} + } + // Build a synthetic single-block directory from the legacy data. + var maxTF uint32 + for _, e := range legacyEntries { + if e.Value > maxTF { + maxTF = e.Value + } + } + dir = &bm25block.Dir{ + NextID: 1, + Blocks: []bm25block.BlockMeta{{ + FirstUID: legacyEntries[0].UID, + BlockID: 0, + Count: uint32(len(legacyEntries)), + MaxTF: maxTF, + }}, + } + it := &listIter{ + attr: attr, + encodedTerm: encodedTerm, + readTs: readTs, + idf: idf, + k: k, + b: b, + dir: dir, + ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), + blockIdx: 0, + block: legacyEntries, // pre-loaded + inBlockPos: -1, // will advance on first next() + legacy: true, + } + return it } it := &listIter{ @@ -215,6 +254,11 @@ func (it *listIter) findBlockForTarget(target uint64) int { // loadBlock decodes the block at the given directory index. func (it *listIter) loadBlock(idx int) { + if it.legacy { + // Legacy mode: single block already loaded. + it.inBlockPos = 0 + return + } bm := it.dir.Blocks[idx] blockKey := x.BM25TermBlockKey(it.attr, it.encodedTerm, bm.BlockID) blob := posting.ReadBM25BlobAt(blockKey, it.readTs) @@ -288,13 +332,21 @@ func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { } // lookupDocLen looks up a single UID's document length from the block-based doclen store. +// Falls back to the legacy monolithic doclen blob if no block directory exists. func lookupDocLen(attr string, uid, readTs uint64) float64 { dirKey := x.BM25DocLenDirKey(attr) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) if len(dir.Blocks) == 0 { - return 1.0 // fallback + // Fallback: try the legacy monolithic doclen blob. + legacyKey := x.BM25DocLenKey(attr) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + if v, ok := bm25enc.Search(legacyEntries, uid); ok { + return float64(v) + } + return 1.0 } blockIdx := dir.FindBlock(uid) @@ -317,17 +369,24 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // Build iterators for each query term. var iters []*listIter for _, token := range queryTokens { + // Compute df: try block directory first, then fall back to legacy blob. + var df uint64 dirKey := x.BM25TermDirKey(attr, token) dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) dir := bm25block.DecodeDir(dirBlob) - if len(dir.Blocks) == 0 { - continue + if len(dir.Blocks) > 0 { + for _, bm := range dir.Blocks { + df += uint64(bm.Count) + } + } else { + // Legacy fallback: read the monolithic blob to get df. + legacyKey := x.BM25IndexKey(attr, token) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) + legacyEntries := bm25enc.Decode(legacyBlob) + df = uint64(len(legacyEntries)) } - - // Compute df from directory. - var df uint64 - for _, bm := range dir.Blocks { - df += uint64(bm.Count) + if df == 0 { + continue } idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) diff --git a/worker/bm25wand_test.go b/worker/bm25wand_test.go new file mode 100644 index 00000000000..5982f94d0b8 --- /dev/null +++ b/worker/bm25wand_test.go @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTopKHeapBasic(t *testing.T) { + h := &topKHeap{k: 3} + heap.Init(h) + + require.Equal(t, 0.0, h.threshold()) + + h.tryPush(1, 5.0) + h.tryPush(2, 3.0) + require.Equal(t, 0.0, h.threshold()) // not full yet + + h.tryPush(3, 7.0) + require.InEpsilon(t, 3.0, h.threshold(), 1e-9) // full, min is 3.0 + + h.tryPush(4, 4.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) // 3.0 evicted, min is now 4.0 + + // 2.0 shouldn't be accepted. + h.tryPush(5, 2.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) + + sorted := h.sorted() + require.Len(t, sorted, 3) + require.Equal(t, uint64(3), sorted[0].uid) // highest score (7.0) + require.Equal(t, uint64(1), sorted[1].uid) // 5.0 + require.Equal(t, uint64(4), sorted[2].uid) // 4.0 +} + +func TestTopKHeapTieBreaking(t *testing.T) { + h := &topKHeap{k: 5} + heap.Init(h) + + // Same score, different UIDs — should sort by UID ascending. + h.tryPush(10, 5.0) + h.tryPush(5, 5.0) + h.tryPush(15, 5.0) + + sorted := h.sorted() + require.Equal(t, uint64(5), sorted[0].uid) + require.Equal(t, uint64(10), sorted[1].uid) + require.Equal(t, uint64(15), sorted[2].uid) +} + +func TestBm25ScoreFunction(t *testing.T) { + k, b := 1.2, 0.75 + avgDL := 10.0 + + // idf * (k+1) * tf / (k*(1-b+b*dl/avgDL) + tf) + idf := 1.5 + tf := 3.0 + dl := 10.0 + + expected := idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) + got := bm25Score(idf, tf, dl, avgDL, k, b) + require.InEpsilon(t, expected, got, 1e-9) + + // With b=0: no length normalization. + expected0 := idf * (k + 1) * tf / (k + tf) + got0 := bm25Score(idf, tf, dl, avgDL, k, 0) + require.InEpsilon(t, expected0, got0, 1e-9) + + // Score should be positive for positive inputs. + require.Greater(t, bm25Score(1.0, 1.0, 5.0, 10.0, k, b), 0.0) + + // Higher tf should produce higher score (same dl). + s1 := bm25Score(idf, 1.0, dl, avgDL, k, b) + s3 := bm25Score(idf, 3.0, dl, avgDL, k, b) + require.Greater(t, s3, s1) + + // Shorter doc should score higher (same tf). + sShort := bm25Score(idf, tf, 5.0, avgDL, k, b) + sLong := bm25Score(idf, tf, 20.0, avgDL, k, b) + require.Greater(t, sShort, sLong) +} + +func TestBm25ScoreNaN(t *testing.T) { + // Ensure no NaN/Inf for edge-case inputs. + score := bm25Score(0.5, 1.0, 0.0, 10.0, 1.2, 0.75) + require.False(t, math.IsNaN(score)) + require.False(t, math.IsInf(score, 0)) + require.Greater(t, score, 0.0) +} From fc6a2128f4b1e7c774477691135dccffdcad80ba Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:26:11 -0500 Subject: [PATCH 09/19] fix(bm25): address GPT-5 code review findings in WAND implementation Fixes critical bugs and performance issues identified by GPT-5 review: - Fix negative inBlockPos panic: guard currentDoc/currentTF/skipTo against inBlockPos < 0 (possible before first next() call) - Fix empty block pathological behavior: next()/skipTo()/skipToWithBMW() now skip empty blocks instead of leaving iterator in invalid state with MaxUint64 pivotDoc - Fix legacy loadBlock: no longer resets inBlockPos to 0 (was moving pointer backwards, could cause re-scoring or infinite loops) - Fix remainingUB panic: guard against blockIdx < 0 (before first next()) - Add docLenCache: caches doclen directory + block reads within a single query, avoiding repeated Badger reads per scored document - Optimize BMW otherUB: compute as sumUB - thisUB (O(1)) instead of iterating all other terms (O(q^2) -> O(q)) Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 168 ++++++++++++++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 55 deletions(-) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index c946ecd9202..5aab0cd0e44 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -24,11 +24,11 @@ type listIter struct { idf float64 k, b float64 - dir *bm25block.Dir - ubPreSuf []float64 // suffix max of UBPre values - blockIdx int // current block index in dir.Blocks - block []bm25enc.Entry // decoded current block - inBlockPos int // position within current block + dir *bm25block.Dir + ubPreSuf []float64 // suffix max of UBPre values + blockIdx int // current block index in dir.Blocks + block []bm25enc.Entry // decoded current block + inBlockPos int // position within current block exhausted bool legacy bool // true if using legacy monolithic blob (migration fallback) @@ -98,7 +98,7 @@ func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *li // currentDoc returns the UID at the current position. func (it *listIter) currentDoc() uint64 { - if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { return math.MaxUint64 } return it.block[it.inBlockPos].UID @@ -106,7 +106,7 @@ func (it *listIter) currentDoc() uint64 { // currentTF returns the term frequency at the current position. func (it *listIter) currentTF() uint32 { - if it.exhausted || it.block == nil || it.inBlockPos >= len(it.block) { + if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { return 0 } return it.block[it.inBlockPos].Value @@ -114,10 +114,17 @@ func (it *listIter) currentTF() uint32 { // remainingUB returns the IDF-weighted upper-bound score for the remaining postings. func (it *listIter) remainingUB() float64 { - if it.exhausted || it.blockIdx >= len(it.ubPreSuf) { + if it.exhausted || len(it.ubPreSuf) == 0 { return 0 } - return it.idf * it.ubPreSuf[it.blockIdx] + idx := it.blockIdx + if idx < 0 { + idx = 0 + } + if idx >= len(it.ubPreSuf) { + return 0 + } + return it.idf * it.ubPreSuf[idx] } // blockUB returns the IDF-weighted upper-bound for the current block only. @@ -137,19 +144,24 @@ func (it *listIter) next() bool { // Try advancing within the current block. if it.block != nil { it.inBlockPos++ - if it.inBlockPos < len(it.block) { + if it.inBlockPos >= 0 && it.inBlockPos < len(it.block) { return true } } // Move to the next block. - it.blockIdx++ - if it.blockIdx >= len(it.dir.Blocks) { - it.exhausted = true - return false + for { + it.blockIdx++ + if it.blockIdx >= len(it.dir.Blocks) { + it.exhausted = true + return false + } + it.loadBlock(it.blockIdx) + if len(it.block) > 0 { + return true + } + // Empty block (corruption/race): skip it. } - it.loadBlock(it.blockIdx) - return it.inBlockPos < len(it.block) } // skipTo advances to the first posting with UID >= target. @@ -160,19 +172,25 @@ func (it *listIter) skipTo(target uint64) bool { } // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && + it.block[it.inBlockPos].UID >= target { return true } // Check if target might be in the current block. - if it.block != nil && it.blockIdx < len(it.dir.Blocks) { + if it.block != nil && len(it.block) > 0 && it.blockIdx >= 0 && + it.blockIdx < len(it.dir.Blocks) { lastInBlock := it.block[len(it.block)-1].UID if target <= lastInBlock { - // Binary search within current block. - pos := sort.Search(len(it.block)-it.inBlockPos, func(i int) bool { - return it.block[it.inBlockPos+i].UID >= target + startPos := it.inBlockPos + if startPos < 0 { + startPos = 0 + } + // Binary search within current block from startPos. + pos := sort.Search(len(it.block)-startPos, func(i int) bool { + return it.block[startPos+i].UID >= target }) - it.inBlockPos += pos + it.inBlockPos = startPos + pos if it.inBlockPos < len(it.block) { return true } @@ -188,6 +206,9 @@ func (it *listIter) skipTo(target uint64) bool { it.blockIdx = blockIdx it.loadBlock(blockIdx) + if len(it.block) == 0 { + return it.next() // skip empty block + } // Binary search within the block. pos := sort.Search(len(it.block), func(i int) bool { @@ -209,7 +230,8 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) } // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos < len(it.block) && it.block[it.inBlockPos].UID >= target { + if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && + it.block[it.inBlockPos].UID >= target { return true } @@ -221,6 +243,10 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) // This block might have a winner; load and search it. it.blockIdx = blockIdx it.loadBlock(blockIdx) + if len(it.block) == 0 { + blockIdx++ + continue // skip empty block + } pos := sort.Search(len(it.block), func(i int) bool { return it.block[i].UID >= target }) @@ -231,7 +257,7 @@ func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) // Fall through to next block. } blockIdx++ - // Update target to the next block's firstUID (we've already skipped past target). + // Update target to the next block's firstUID. if blockIdx < len(it.dir.Blocks) { target = it.dir.Blocks[blockIdx].FirstUID } @@ -255,8 +281,7 @@ func (it *listIter) findBlockForTarget(target uint64) int { // loadBlock decodes the block at the given directory index. func (it *listIter) loadBlock(idx int) { if it.legacy { - // Legacy mode: single block already loaded. - it.inBlockPos = 0 + // Legacy mode: single pre-loaded block; don't reset position. return } bm := it.dir.Blocks[idx] @@ -331,29 +356,65 @@ func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) } -// lookupDocLen looks up a single UID's document length from the block-based doclen store. -// Falls back to the legacy monolithic doclen blob if no block directory exists. -func lookupDocLen(attr string, uid, readTs uint64) float64 { - dirKey := x.BM25DocLenDirKey(attr) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir := bm25block.DecodeDir(dirBlob) +// docLenCache caches document length lookups within a single query to avoid +// repeated Badger reads for the same doclen block directory and blocks. +type docLenCache struct { + attr string + readTs uint64 + dir *bm25block.Dir + loaded bool + legacy bool + // Per-block cache: blockIdx -> decoded entries. + blocks map[int][]bm25enc.Entry + // Legacy entries (when using monolithic blob). + legacyEntries []bm25enc.Entry +} - if len(dir.Blocks) == 0 { - // Fallback: try the legacy monolithic doclen blob. - legacyKey := x.BM25DocLenKey(attr) - legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - legacyEntries := bm25enc.Decode(legacyBlob) - if v, ok := bm25enc.Search(legacyEntries, uid); ok { +func newDocLenCache(attr string, readTs uint64) *docLenCache { + return &docLenCache{ + attr: attr, + readTs: readTs, + blocks: make(map[int][]bm25enc.Entry), + } +} + +func (c *docLenCache) ensureLoaded() { + if c.loaded { + return + } + c.loaded = true + dirKey := x.BM25DocLenDirKey(c.attr) + dirBlob := posting.ReadBM25BlobAt(dirKey, c.readTs) + c.dir = bm25block.DecodeDir(dirBlob) + if len(c.dir.Blocks) == 0 { + // Try legacy. + legacyKey := x.BM25DocLenKey(c.attr) + legacyBlob := posting.ReadBM25BlobAt(legacyKey, c.readTs) + c.legacyEntries = bm25enc.Decode(legacyBlob) + c.legacy = true + } +} + +func (c *docLenCache) lookup(uid uint64) float64 { + c.ensureLoaded() + if c.legacy { + if v, ok := bm25enc.Search(c.legacyEntries, uid); ok { return float64(v) } return 1.0 } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) - blob := posting.ReadBM25BlobAt(blockKey, readTs) - entries := bm25enc.Decode(blob) + if len(c.dir.Blocks) == 0 { + return 1.0 + } + blockIdx := c.dir.FindBlock(uid) + entries, ok := c.blocks[blockIdx] + if !ok { + bm := c.dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(c.attr, bm.BlockID) + blob := posting.ReadBM25BlobAt(blockKey, c.readTs) + entries = bm25enc.Decode(blob) + c.blocks[blockIdx] = entries + } if v, ok := bm25enc.Search(entries, uid); ok { return float64(v) } @@ -366,6 +427,8 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, k, b, avgDL, N float64, topK int, filterSet map[uint64]struct{}, useBMW bool) []scoredDoc { + dlCache := newDocLenCache(attr, readTs) + // Build iterators for each query term. var iters []*listIter for _, token := range queryTokens { @@ -405,7 +468,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // If no top-k limit, score all matching documents. if topK <= 0 { - return scoreAllDocs(iters, attr, readTs, k, b, avgDL, filterSet) + return scoreAllDocs(iters, dlCache, k, b, avgDL, filterSet) } // WAND algorithm with top-k heap. @@ -454,13 +517,8 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, if iters[i].currentDoc() < pivotDoc { var ok bool if useBMW { - // Compute other UBs for BMW skipping. - var otherUB float64 - for j, jt := range iters { - if j != i { - otherUB += jt.remainingUB() - } - } + // Compute otherUB = total UB - this iter's UB (O(1) instead of O(q)). + otherUB := sumUB - iters[i].remainingUB() ok = iters[i].skipToWithBMW(pivotDoc, theta, otherUB) } else { ok = iters[i].skipTo(pivotDoc) @@ -492,7 +550,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, } } - dl := lookupDocLen(attr, pivotDoc, readTs) + dl := dlCache.lookup(pivotDoc) var score float64 for _, it := range iters { if it.currentDoc() == pivotDoc { @@ -515,7 +573,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, // scoreAllDocs scores every matching document without early termination. // Used when no top-k limit is specified (the original behavior). -func scoreAllDocs(iters []*listIter, attr string, readTs uint64, +func scoreAllDocs(iters []*listIter, dlCache *docLenCache, k, b, avgDL float64, filterSet map[uint64]struct{}) []scoredDoc { // Collect all (uid, term) matches. @@ -541,7 +599,7 @@ func scoreAllDocs(iters []*listIter, attr string, readTs uint64, // Score all matching documents. results := make([]scoredDoc, 0, len(matches)) for uid, terms := range matches { - dl := lookupDocLen(attr, uid, readTs) + dl := dlCache.lookup(uid) var score float64 for _, tm := range terms { score += bm25Score(tm.idf, float64(tm.tf), dl, avgDL, k, b) From 1073cbabf94b2016fa2f8240c113aee72d95854f Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 08:53:46 -0500 Subject: [PATCH 10/19] fix(bm25): prevent stats double-counting on updates and fix BMW otherUB underestimate Three fixes: 1. CRITICAL: addBM25IndexMutations now checks if a UID already exists in doclen blocks before incrementing stats, preventing double-counting on SET when the document was already indexed (defensive guard for batch mutations). 2. HIGH: WAND sumUB now accumulates across ALL iterators (not just up to pivot), so BMW's otherUB calculation is correct and won't skip valid candidate blocks. 3. PERF: newListIter accepts pre-read Dir to eliminate duplicate Badger reads (directory was read once for df, then again inside newListIter). Co-Authored-By: Claude Opus 4.6 --- posting/index.go | 38 ++++++++++++++++++++++++++++++++++++-- worker/bm25wand.go | 17 ++++++++++------- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/posting/index.go b/posting/index.go index d2eadd904e0..e19deae8d73 100644 --- a/posting/index.go +++ b/posting/index.go @@ -272,13 +272,26 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn return txn.updateBM25Stats(attr, -1, -int64(docLen)) } - // For SET: upsert term frequencies and doc length into blocks. + // For SET: check if this UID already has a doclen entry (i.e., this is an update). + // If so, subtract old stats to avoid double-counting. + oldDocLen, isUpdate := txn.bm25DocLenBlockLookup(attr, uid) + for term, tf := range termFreqs { encodedTerm := string([]byte{tok.IdentBM25}) + term txn.bm25BlockUpsert(attr, encodedTerm, uid, tf) } txn.bm25DocLenBlockUpsert(attr, uid, docLen) - return txn.updateBM25Stats(attr, 1, int64(docLen)) + + var docCountDelta int64 + var totalTermsDelta int64 + if isUpdate { + // Document already existed: don't increment docCount, adjust totalTerms by diff. + totalTermsDelta = int64(docLen) - int64(oldDocLen) + } else { + docCountDelta = 1 + totalTermsDelta = int64(docLen) + } + return txn.updateBM25Stats(attr, docCountDelta, totalTermsDelta) } // bm25BlockUpsert inserts or updates a (uid, value) entry in the block-based @@ -400,6 +413,27 @@ func (txn *Txn) bm25DocLenBlockUpsert(attr string, uid uint64, docLen uint32) { txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) } +// bm25DocLenBlockLookup checks if a uid exists in the doclen blocks and returns its value. +func (txn *Txn) bm25DocLenBlockLookup(attr string, uid uint64) (uint32, bool) { + dirKey := x.BM25DocLenDirKey(attr) + dirBlob := txn.cache.ReadBM25Blob(dirKey) + dir := bm25block.DecodeDir(dirBlob) + + if len(dir.Blocks) == 0 { + return 0, false + } + + blockIdx := dir.FindBlock(uid) + bm := dir.Blocks[blockIdx] + blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) + blob := txn.cache.ReadBM25Blob(blockKey) + entries := bm25enc.Decode(blob) + if v, ok := bm25enc.Search(entries, uid); ok { + return v, true + } + return 0, false +} + // bm25DocLenBlockRemove removes a uid from the block-based document-length list. func (txn *Txn) bm25DocLenBlockRemove(attr string, uid uint64) { dirKey := x.BM25DocLenDirKey(attr) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 5aab0cd0e44..de5ecfb2b3c 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -36,10 +36,13 @@ type listIter struct { // newListIter creates a new iterator for a term's block-based posting list. // Falls back to the legacy monolithic blob format if no block directory exists. -func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64) *listIter { - dirKey := x.BM25TermDirKey(attr, encodedTerm) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir := bm25block.DecodeDir(dirBlob) +// If dir is non-nil, it is used directly (avoids re-reading from Badger). +func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64, dir *bm25block.Dir) *listIter { + if dir == nil { + dirKey := x.BM25TermDirKey(attr, encodedTerm) + dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) + dir = bm25block.DecodeDir(dirBlob) + } if len(dir.Blocks) == 0 { // Fallback: try reading the legacy monolithic blob and wrap it as a single block. @@ -453,7 +456,7 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, } idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) - it := newListIter(attr, token, readTs, idf, k, b) + it := newListIter(attr, token, readTs, idf, k, b, dir) if !it.exhausted { it.next() // prime the iterator if !it.exhausted { @@ -501,12 +504,12 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, var pivotDoc uint64 for i, it := range iters { sumUB += it.remainingUB() - if sumUB > theta { + if sumUB > theta && pivot == -1 { pivot = i pivotDoc = it.currentDoc() - break } } + // sumUB now contains the total UB across ALL iterators (needed for BMW). if pivot == -1 { break // sum of all UBs can't beat theta } From 81b18c1e4a9b3a918d82cb0c05cf720b6978a5e6 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Thu, 5 Mar 2026 09:10:29 -0500 Subject: [PATCH 11/19] fix(bm25): clamp startPos in skipTo to prevent negative sort.Search length Defensive hardening from GPT-5 review: if inBlockPos exceeds block length after next() reaches end of block, the sort.Search span could go negative. Co-Authored-By: Claude Opus 4.6 --- worker/bm25wand.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index de5ecfb2b3c..4ae2569fa7a 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -188,6 +188,8 @@ func (it *listIter) skipTo(target uint64) bool { startPos := it.inBlockPos if startPos < 0 { startPos = 0 + } else if startPos > len(it.block) { + startPos = len(it.block) } // Binary search within current block from startPos. pos := sort.Search(len(it.block)-startPos, func(i int) bool { From edf466dfdc935ecfbafc4fb8acce44af37a28258 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 18 Mar 2026 21:23:16 -0400 Subject: [PATCH 12/19] fix(bm25): address Gemini/GPT-5 code review findings - Add DecodeCount() to bm25enc for O(1) entry count reads without full decode, preventing OOM on legacy migration with large posting lists (e.g., common terms with millions of entries) - Use DecodeCount in WAND search legacy DF calculation path - Fix integer overflow in DecodeDir bounds check by using uint64 arithmetic (prevents panic on corrupted data with MaxUint32 count) - Pre-allocate shared score buffer in handleBM25Search with three-index slices to prevent accidental append corruption - Document bm25Writes concurrency model and limitations Co-Authored-By: Claude Opus 4.6 (1M context) --- posting/bm25block/bm25block.go | 3 +- posting/bm25block/bm25block_test.go | 13 ++++ posting/bm25enc/bm25enc.go | 11 +++ posting/bm25enc/bm25enc_test.go | 31 ++++++++ posting/lists.go | 12 +++ query/query_bm25_test.go | 115 ++++++++++++++++++---------- worker/bm25wand.go | 6 +- worker/task.go | 14 +++- 8 files changed, 159 insertions(+), 46 deletions(-) diff --git a/posting/bm25block/bm25block.go b/posting/bm25block/bm25block.go index f529ed8fab8..e9c4fa1776e 100644 --- a/posting/bm25block/bm25block.go +++ b/posting/bm25block/bm25block.go @@ -77,7 +77,8 @@ func DecodeDir(data []byte) *Dir { } count := binary.BigEndian.Uint32(data[0:4]) nextID := binary.BigEndian.Uint32(data[4:8]) - if int(count)*dirEntrySize+dirHeaderSize > len(data) { + // Use uint64 arithmetic to prevent integer overflow on corrupted data. + if uint64(count)*dirEntrySize+dirHeaderSize > uint64(len(data)) { return &Dir{NextID: nextID} } blocks := make([]BlockMeta, count) diff --git a/posting/bm25block/bm25block_test.go b/posting/bm25block/bm25block_test.go index a7cc26f493a..f9cccb7554a 100644 --- a/posting/bm25block/bm25block_test.go +++ b/posting/bm25block/bm25block_test.go @@ -6,6 +6,7 @@ package bm25block import ( + "encoding/binary" "math" "testing" @@ -39,6 +40,18 @@ func TestDirRoundtripEmpty(t *testing.T) { require.Empty(t, got.Blocks) } +func TestDecodeDirCorruptedLargeCount(t *testing.T) { + // A corrupted blob with a massive count should not panic due to integer overflow. + // count = MaxUint32, nextID = 0, followed by only 8 bytes of data. + data := make([]byte, 16) + binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) // count = MaxUint32 + binary.BigEndian.PutUint32(data[4:8], 0) // nextID = 0 + got := DecodeDir(data) + // Should return an empty Dir (with nextID preserved) rather than panicking. + require.Empty(t, got.Blocks) + require.Equal(t, uint32(0), got.NextID) +} + func TestDirRoundtripSingle(t *testing.T) { dir := &Dir{ NextID: 1, diff --git a/posting/bm25enc/bm25enc.go b/posting/bm25enc/bm25enc.go index 8da82b299dd..86bfe5f5bb1 100644 --- a/posting/bm25enc/bm25enc.go +++ b/posting/bm25enc/bm25enc.go @@ -130,6 +130,17 @@ func UIDs(entries []Entry) []uint64 { return uids } +// DecodeCount reads just the entry count from the header of an encoded blob +// without decoding any entries. This is O(1) and avoids allocating a full +// []Entry slice, which matters for large posting lists (e.g., common terms +// during legacy format migration). +func DecodeCount(data []byte) uint32 { + if len(data) < 4 { + return 0 + } + return binary.BigEndian.Uint32(data[:4]) +} + // EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes. func EncodeStats(docCount, totalTerms uint64) []byte { buf := make([]byte, 16) diff --git a/posting/bm25enc/bm25enc_test.go b/posting/bm25enc/bm25enc_test.go index 1969e472ed2..f4cfec6bf62 100644 --- a/posting/bm25enc/bm25enc_test.go +++ b/posting/bm25enc/bm25enc_test.go @@ -92,6 +92,37 @@ func TestUIDs(t *testing.T) { require.Equal(t, []uint64{1, 5, 100}, UIDs(entries)) } +func TestDecodeCount(t *testing.T) { + // Normal case: count matches actual entries. + entries := []Entry{ + {UID: 1, Value: 3}, + {UID: 5, Value: 1}, + {UID: 100, Value: 7}, + } + data := Encode(entries) + require.Equal(t, uint32(3), DecodeCount(data)) + + // Empty/nil input. + require.Equal(t, uint32(0), DecodeCount(nil)) + require.Equal(t, uint32(0), DecodeCount([]byte{})) + require.Equal(t, uint32(0), DecodeCount([]byte{1, 2, 3})) + + // Zero count. + require.Equal(t, uint32(0), DecodeCount([]byte{0, 0, 0, 0})) + + // Single entry. + single := Encode([]Entry{{UID: 42, Value: 10}}) + require.Equal(t, uint32(1), DecodeCount(single)) + + // Large count. + large := make([]Entry, 10000) + for i := range large { + large[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} + } + data = Encode(large) + require.Equal(t, uint32(10000), DecodeCount(data)) +} + func TestStatsRoundtrip(t *testing.T) { data := EncodeStats(12345, 98765) dc, tt := DecodeStats(data) diff --git a/posting/lists.go b/posting/lists.go index 0bd9848de23..22d20a53973 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -79,6 +79,18 @@ type LocalCache struct { // bm25Writes buffers BM25 direct KV writes (key → encoded blob). // These bypass the posting list infrastructure entirely. + // + // CONCURRENCY NOTE: BM25 blocks use full-value overwrites rather than + // posting list deltas. Within a single Dgraph transaction this is safe + // (each Txn has its own LocalCache). Across concurrent transactions, + // Dgraph's Raft-based mutation serialization prevents lost updates for + // the same predicate+UID pair. However, two transactions updating + // different UIDs that share a common term could theoretically race on + // the same term block. In practice this is mitigated by: + // 1. Dgraph serializes mutations through Raft proposals + // 2. Block splits keep contention surface small + // If higher write concurrency is needed, blocks should be integrated + // into the posting list delta mechanism. bm25Writes map[string][]byte } diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index 457c7b46452..1411ad3916e 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -19,6 +19,23 @@ import ( "github.com/stretchr/testify/require" ) +// uidHex queries Dgraph for the hex UID string of a given decimal UID. +// This avoids hardcoding hex values that depend on UID assignment order. +func uidHex(t *testing.T, decimalUID int) string { + t.Helper() + js := processQueryNoErr(t, fmt.Sprintf(`{ me(func: uid(%d)) { uid } }`, decimalUID)) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me, "UID %d should exist", decimalUID) + return resp.Data.Me[0].UID +} + func TestBM25Basic(t *testing.T) { query := ` { @@ -376,9 +393,9 @@ func TestBM25IncrementalAddBatch(t *testing.T) { js = processQueryNoErr(t, countQuery) require.Contains(t, js, `"count":8`) - // Verify specific new UIDs are searchable. - js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid } }`) - require.Contains(t, js, `"0x25e"`) // 606 + // Verify specific new terms are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid description_bm25 } }`) + require.Contains(t, js, "whiskey") } func TestBM25CorpusStatsAffectIDF(t *testing.T) { @@ -417,7 +434,7 @@ func TestBM25CorpusStatsAffectIDF(t *testing.T) { scoresAfter := parseScoresFromJSON(t, jsAfter) // Compare score for UID 503 ("fox fox fox") — should increase. - uid503 := "0x1f7" + uid503 := uidHex(t, 503) before, ok1 := scoresBefore[uid503] after, ok2 := scoresAfter[uid503] require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") @@ -432,6 +449,8 @@ func TestBM25DocumentUpdate(t *testing.T) { deleteTriplesInCluster(`<620> * .`) }) + uid620 := uidHex(t, 620) + // Should rank top for "fox". js := processQueryNoErr(t, ` { @@ -439,7 +458,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x26c"`) // 620 + require.Contains(t, js, `"`+uid620+`"`) // Update to remove "fox", add "cat". deleteTriplesInCluster(`<620> "fox fox fox fox" .`) @@ -452,7 +471,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.NotContains(t, js, `"0x26c"`) + require.NotContains(t, js, `"`+uid620+`"`) // Should appear in "cat" results. js = processQueryNoErr(t, ` @@ -461,7 +480,7 @@ func TestBM25DocumentUpdate(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x26c"`) + require.Contains(t, js, `"`+uid620+`"`) } func TestBM25DocumentDeletion(t *testing.T) { @@ -471,9 +490,11 @@ func TestBM25DocumentDeletion(t *testing.T) { deleteTriplesInCluster(`<625> * .`) }) + uid625 := uidHex(t, 625) + // Should find the elephant doc. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.Contains(t, js, `"0x271"`) // 625 + require.Contains(t, js, `"`+uid625+`"`) // Delete it. deleteTriplesInCluster(`<625> "unique elephant term" .`) @@ -483,7 +504,7 @@ func TestBM25DocumentDeletion(t *testing.T) { require.JSONEq(t, `{"data": {"me":[]}}`, js) // Baseline "fox" results should be unaffected. - js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid } }`) + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid description_bm25 } }`) require.Contains(t, js, "fox") } @@ -499,7 +520,7 @@ func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { } } ` - uid503 := "0x1f7" + uid503 := uidHex(t, 503) // Phase 1: baseline score. js1 := processQueryNoErr(t, scoreQuery) @@ -642,8 +663,8 @@ func TestBM25EdgeCaseLongDocument(t *testing.T) { js := processQueryNoErr(t, scoreQuery) scores := parseScoresFromJSON(t, js) - uid503 := "0x1f7" // "fox fox fox" (doclen=3) - uid645 := "0x285" // long doc (doclen~500) + uid503 := uidHex(t, 503) // "fox fox fox" (doclen=3) + uid645 := uidHex(t, 645) // long doc (doclen~500) s503, ok1 := scores[uid503] s645, ok2 := scores[uid645] require.True(t, ok1, "UID 503 must appear in fox results") @@ -667,17 +688,21 @@ func TestBM25EdgeCaseUnicode(t *testing.T) { `) }) + uid650 := uidHex(t, 650) + uid651 := uidHex(t, 651) + uid652 := uidHex(t, 652) + // Query German term. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`) - require.Contains(t, js, `"0x28a"`) // 650 + require.Contains(t, js, `"`+uid650+`"`) // Query French term. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`) - require.Contains(t, js, `"0x28b"`) // 651 + require.Contains(t, js, `"`+uid651+`"`) // Query Spanish term. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`) - require.Contains(t, js, `"0x28c"`) // 652 + require.Contains(t, js, `"`+uid652+`"`) } func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { @@ -686,9 +711,11 @@ func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { deleteTriplesInCluster(`<655> * .`) }) + uid655 := uidHex(t, 655) + // Query "the" — should return empty since "the" is a stopword. js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`) - require.NotContains(t, js, `"0x28f"`) // 655 should not appear + require.NotContains(t, js, `"`+uid655+`"`) // 655 should not appear // But the doc should exist via has(). js = processQueryNoErr(t, ` @@ -697,7 +724,7 @@ func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) { uid } }`) - require.Contains(t, js, `"0x28f"`) + require.Contains(t, js, `"`+uid655+`"`) } func TestBM25WithUidFilter(t *testing.T) { @@ -711,12 +738,16 @@ func TestBM25WithUidFilter(t *testing.T) { } ` js := processQueryNoErr(t, query) + uid501 := uidHex(t, 501) + uid502 := uidHex(t, 502) + uid503 := uidHex(t, 503) + uid506 := uidHex(t, 506) // Should contain only UIDs 501 and 503. - require.Contains(t, js, `"0x1f5"`) // 501 - require.Contains(t, js, `"0x1f7"`) // 503 - // Should NOT contain other fox docs like 502, 506, 507. - require.NotContains(t, js, `"0x1f6"`) // 502 - require.NotContains(t, js, `"0x1fa"`) // 506 + require.Contains(t, js, `"`+uid501+`"`) + require.Contains(t, js, `"`+uid503+`"`) + // Should NOT contain other fox docs like 502, 506. + require.NotContains(t, js, `"`+uid502+`"`) + require.NotContains(t, js, `"`+uid506+`"`) } func TestBM25ScoreValuesAreValidFloats(t *testing.T) { @@ -770,22 +801,23 @@ func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) { // Phase 1: add with "elephant". require.NoError(t, addTriplesToCluster(`<670> "elephant roams the savanna" .`)) + uid670 := uidHex(t, 670) js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.Contains(t, js, `"0x29e"`) // 670 + require.Contains(t, js, `"`+uid670+`"`) // Phase 2: delete. deleteTriplesInCluster(`<670> "elephant roams the savanna" .`) js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.NotContains(t, js, `"0x29e"`) + require.NotContains(t, js, `"`+uid670+`"`) // Phase 3: re-add with different content. require.NoError(t, addTriplesToCluster(`<670> "penguin waddles on the ice" .`)) js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`) - require.Contains(t, js, `"0x29e"`) + require.Contains(t, js, `"`+uid670+`"`) // "elephant" should still not match 670. js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) - require.NotContains(t, js, `"0x29e"`) + require.NotContains(t, js, `"`+uid670+`"`) } func TestBM25NonIndexedPredicateError(t *testing.T) { @@ -828,11 +860,11 @@ func TestBM25ConcurrentBatchAdd(t *testing.T) { // Spot-check a doc from each batch. for batch := 0; batch < 5; batch++ { - uid := 680 + batch*4 - hexUID := fmt.Sprintf(`"0x%x"`, uid) + decUID := 680 + batch*4 + hexUID := uidHex(t, decUID) term := fmt.Sprintf("batch%d", batch) js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term)) - require.Contains(t, js, hexUID, "doc %d from batch %d should be searchable", uid, batch) + require.Contains(t, js, `"`+hexUID+`"`, "doc %d from batch %d should be searchable", decUID, batch) } } @@ -895,10 +927,12 @@ func TestBM25ExactScoreValues(t *testing.T) { // Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf expected851 := idf * (k + 1) * 1.0 / (k + 1.0) - actual850, ok := scores["0x352"] // 850 - require.True(t, ok, "UID 850 (0x352) must be in results") - actual851, ok := scores["0x353"] // 851 - require.True(t, ok, "UID 851 (0x353) must be in results") + uid850 := uidHex(t, 850) + uid851 := uidHex(t, 851) + actual850, ok := scores[uid850] + require.True(t, ok, "UID 850 (%s) must be in results", uid850) + actual851, ok := scores[uid851] + require.True(t, ok, "UID 851 (%s) must be in results", uid851) require.InEpsilon(t, expected850, actual850, 1e-6, "Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)", @@ -940,8 +974,10 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { js := processQueryNoErr(t, scoreQuery) scores := parseScoresFromJSON(t, js) - score860, ok1 := scores["0x35c"] // 860 - score861, ok2 := scores["0x35d"] // 861 + uid860 := uidHex(t, 860) + uid861 := uidHex(t, 861) + score860, ok1 := scores[uid860] + score861, ok2 := scores[uid861] require.True(t, ok1, "UID 860 must be in results") require.True(t, ok2, "UID 861 must be in results") @@ -964,8 +1000,8 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { js = processQueryNoErr(t, scoreQueryDefault) scoresDefault := parseScoresFromJSON(t, js) - defScore860, ok1 := scoresDefault["0x35c"] - defScore861, ok2 := scoresDefault["0x35d"] + defScore860, ok1 := scoresDefault[uid860] + defScore861, ok2 := scoresDefault[uid861] require.True(t, ok1, "UID 860 must be in default results") require.True(t, ok2, "UID 861 must be in default results") require.Greater(t, defScore860, defScore861, @@ -999,8 +1035,9 @@ func TestBM25SingleMatchingDocument(t *testing.T) { require.Len(t, scores, 1, "exactly one document should match 'aardvark'") - actual, ok := scores["0x361"] // 865 - require.True(t, ok, "UID 865 (0x361) must be in results") + uid865 := uidHex(t, 865) + actual, ok := scores[uid865] + require.True(t, ok, "UID 865 (%s) must be in results", uid865) // With df=1, tf=1, b=0, k=1.2: // idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 4ae2569fa7a..07988c845df 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -447,11 +447,11 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, df += uint64(bm.Count) } } else { - // Legacy fallback: read the monolithic blob to get df. + // Legacy fallback: read just the count header to get df. + // Avoids decoding the full posting list (which could be huge for common terms). legacyKey := x.BM25IndexKey(attr, token) legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - legacyEntries := bm25enc.Decode(legacyBlob) - df = uint64(len(legacyEntries)) + df = uint64(bm25enc.DecodeCount(legacyBlob)) } if df == 0 { continue diff --git a/worker/task.go b/worker/task.go index 55697973275..0345e9e75f8 100644 --- a/worker/task.go +++ b/worker/task.go @@ -1318,6 +1318,8 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error // 7. Build output: UIDs sorted ascending (required by query pipeline) // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). + // We use a single pre-allocated buffer for all score encodings to reduce + // per-result heap allocations. sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) for i, r := range results { @@ -1325,12 +1327,18 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + // Encode scores into ValueMatrix. Each entry in ValueMatrix corresponds + // positionally to a UID in UidMatrix[0], enabling the bm25_score + // pseudo-predicate in query.go to map UIDs to scores. + scoreBuf := make([]byte, len(results)*8) scoreValues := make([]*pb.ValueList, len(results)) for i, r := range results { - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, math.Float64bits(r.score)) + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score)) + // Use three-index slice to cap capacity at 8, preventing any downstream + // append from corrupting adjacent scores in the shared backing array. scoreValues[i] = &pb.ValueList{ - Values: []*pb.TaskValue{{Val: buf, ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, } } args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) From 6fd041ea6a9e20a3b04a1775c69eef8903d89c11 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 21:02:39 -0400 Subject: [PATCH 13/19] feat(bm25): rework BM25 onto standard posting lists Replaces the parallel block-storage + retrieval stack (declined in review) with an implementation that rides Dgraph's standard posting-list machinery, addressing the maintainer's feedback (independently endorsed by GPT-5 and Gemini). Net ~1300 fewer lines. Storage / indexing: - BM25 term postings are standard index posting lists at IndexKey(attr, IdentBM25||term), written via the normal delta path, so they inherit MVCC, deltas, rollup, splits, backup and snapshot. Each posting is a REF posting whose value packs (term-frequency, doc-length) as two uvarints. - Fix the linchpin: List.encode() now retains REF postings that carry a value through rollup (otherwise the term frequency was silently stripped). Mirrors how faceted postings already coexist in Pack (uid) + Postings (payload). - Document length is packed into the posting value rather than a separate list, avoiding a write-hot doclen key and a per-candidate random read at query time. - Corpus stats (docCount, totalTerms) are sharded across 32 buckets keyed by uid%32 so concurrent writers rarely contend, while same-bucket updates still conflict-and-retry (mirrors the @count pattern). Term postings get the standard index conflict key (fingerprint(key)^uid), so two docs sharing a term commit concurrently without conflict -- resolving the concurrency regression that the block version only mitigated via Raft serialization. - Delete posting/bm25block and posting/bm25enc; remove LocalCache.bm25Writes, BitBM25Data, and the BM25 commit branch in mvcc.go. Query / scoring: - WAND / Block-Max WAND reworked over the standard posting-list iterator: per-term cursors are materialized from the in-memory List with per-128-posting maxTF/minDocLen bounds -- no parallel block format, no proto changes. - Surface the score via Dgraph's existing value-variable mechanism: the bm25 root function binds its per-doc score to its own variable (Uids + Vals), so `scores as var(func: bm25(...))` works with uid(scores), val(scores) and orderdesc: val(scores). Removes the bm25_score pseudo-predicate and the __bm25_scores__ ParentVars channel. - Skip the query-layer pagination pass for bm25 roots (the worker already paginates over score order), mirroring the existing `has` handling, to avoid double-applying first/offset. Tests: - Rollup TF/doc-length survival, bucketed-stats accumulation (incl. same-bucket in-transaction read-your-own-writes and deletes), value-codec round trip, and a 200-trial randomized WAND/Block-Max-WAND vs brute-force correctness check. - Convert the query integration tests to the value-variable syntax. Co-Authored-By: Claude Opus 4.8 (1M context) --- bm25-redesign-plan.md | 74 ++++ posting/bm25.go | 224 ++++++++++ posting/bm25_test.go | 140 +++++++ posting/bm25block/bm25block.go | 262 ------------ posting/bm25block/bm25block_test.go | 271 ------------ posting/bm25enc/bm25enc.go | 158 ------- posting/bm25enc/bm25enc_test.go | 163 -------- posting/index.go | 252 +----------- posting/list.go | 11 +- posting/lists.go | 68 --- posting/mvcc.go | 14 - query/query.go | 82 ++-- query/query_bm25_test.go | 54 +-- worker/bm25wand.go | 615 +++++++++------------------- worker/bm25wand_test.go | 104 +++++ worker/task.go | 46 +-- x/keys.go | 63 +-- 17 files changed, 859 insertions(+), 1742 deletions(-) create mode 100644 bm25-redesign-plan.md create mode 100644 posting/bm25.go create mode 100644 posting/bm25_test.go delete mode 100644 posting/bm25block/bm25block.go delete mode 100644 posting/bm25block/bm25block_test.go delete mode 100644 posting/bm25enc/bm25enc.go delete mode 100644 posting/bm25enc/bm25enc_test.go diff --git a/bm25-redesign-plan.md b/bm25-redesign-plan.md new file mode 100644 index 00000000000..4ba98f9573a --- /dev/null +++ b/bm25-redesign-plan.md @@ -0,0 +1,74 @@ +# BM25 Redesign — Implementation Spec + +Reworks the BM25 feature per the maintainer's review (decline of the block-storage +PR). Endorsed independently by GPT-5 and Gemini. Goal: BM25 rides Dgraph's standard +posting-list machinery (MVCC, deltas, rollup, splits, backup, snapshot) instead of a +parallel storage+retrieval stack. + +## What gets deleted +- `posting/bm25block/` and `posting/bm25enc/` (parallel block format). +- `LocalCache.bm25Writes`, `ReadBM25Blob`/`WriteBM25Blob` (second write path). +- `BitBM25Data` user-meta + the BM25 commit branch in `posting/mvcc.go`. +- `bm25_score` pseudo-predicate + `__bm25_scores__` `ParentVars` threading in `query/query.go`. +- Legacy-format fallback / block dir+block keys in `x/keys.go`. + +## Storage model (standard posting lists) +- **Term postings**: one standard index posting list per term at + `IndexKey(attr, IdentBM25 || term)`. Each posting: `Uid = docUID`, + `Value = encodeBM25(tf, docLen)`, `ValType = INT`. Written via `plist.addMutation` + (the normal delta path) → inherits rollup/splits/backup. + - **Rollup-survival fix (linchpin)**: `NewPosting` makes any edge with `ValueId != 0` + a `REF` posting, and `List.encode()` (rollup) keeps a posting's `Value` only when + `Facets != nil || PostingType != REF`. A plain valued REF index posting would have + its TF **stripped at rollup**. Fix: one-line change in `encode()` to also retain + postings that carry a non-empty `Value`. This is the faithful realization of the + maintainer's "TF as the value", and matches how faceted postings already coexist + in both `Pack` (uid) and `Postings` (value). Covered by a forced-rollup regression test. +- **Doc length**: packed into the posting value alongside TF (`encodeBM25(tf, docLen)`), + NOT a separate per-predicate doclen list. Rationale: a single doclen list is a write- + conflict hotspot (every doc mutation writes the same key) and forces a query-time random + read per candidate. Packing makes scoring read `(uid, tf, docLen)` in one shot, + contention-free. Cost: docLen duplicated across a doc's unique terms (acceptable; a doc's + postings are all rewritten together on update anyway). +- **Corpus stats** (`N` docs, `totalTerms` → `avgDL`): conflict-free **bucketed** stats. + `BM25StatsKey(attr, bucket)`, `bucket = docUID % numBuckets` (B=32). Each bucket holds + `(docCount_b, totalTerms_b)`. Mutations touch only their bucket → ~B-fold less contention + than a single hot key. Read path sums across buckets. BM25 tolerates the slight staleness. + +## Value codec `encodeBM25(tf, docLen)` +Two unsigned varints: `tf` then `docLen`. Decoded during scoring. Small file +`posting/bm25.go` (no new package) holds encode/decode + index-mutation logic. + +## Query path (no pseudo-predicate) +- `bm25(attr, "query", [k], [b])` parses to `bm25SearchFn` (unchanged keyword). +- `worker/task.go handleBM25Search`: tokenize query, read bucketed stats → `N`, `avgDL`, + load each term's standard posting `List` via the cache, run WAND, emit `UidMatrix` + (uids asc) + `ValueMatrix` (float64 scores aligned to uids). +- **Surfacing/ordering the score**: via Dgraph's existing **value-variable** (`val()`) + mechanism — the function's `ValueMatrix` populates a value var the user binds and orders + by. No `bm25_score` pseudo-predicate, no new `ParentVars` channel. + +## WAND on the standard iterator (no parallel block format) +Dgraph loads a whole posting list (or split-part) into memory on `Get`. So: +- For each query term, one `List.Iterate` pass materializes a sorted cursor of + `(uid, tf, docLen)`, plus `df`, term `maxTF`, and per-chunk (128) `maxTF`/`minDocLen` + for Block-Max upper bounds — all computed from the in-memory list, **no storage-format + change**. +- WAND / Block-Max WAND DAAT with a top-k min-heap (reuse scoring + heap from the existing + `worker/bm25wand.go`, swapping the block-reading cursor for the standard-list cursor). +- (Future optimization, out of scope now: persist per-block maxTF at rollup to avoid + recomputing for hot terms.) + +## Scoring +`idf = log1p((N - df + 0.5)/(df + 0.5))`; `score = Σ idf·(k+1)·tf / (k·(1-b+b·dl/avgDL) + tf)`. +Defaults `k=1.2`, `b=0.75`. + +## Implementation phases +1. Storage+index: `encode()` retention fix; `posting/bm25.go` (value codec + mutations); + bucketed stats; delete bm25block/bm25enc, bm25Writes, BitBM25Data, mvcc branch. +2. Keys: trim `x/keys.go` to `BM25IndexKey` + bucketed `BM25StatsKey`. +3. Tokenizer: keep `BM25Tokenizer` + query tokens (minor cleanup). +4. Query+WAND: rewrite `worker/bm25wand.go` over standard lists; rewrite `handleBM25Search`; + remove pseudo-predicate/ParentVars from `query/query.go`; wire value-var scoring. +5. Tests: forced-rollup TF-survival test; bucketed-stats test; WAND unit tests over standard + lists; adapt `query/query_bm25_test.go`; build + run. diff --git a/posting/bm25.go b/posting/bm25.go new file mode 100644 index 00000000000..a72fc7b72cd --- /dev/null +++ b/posting/bm25.go @@ -0,0 +1,224 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "context" + "encoding/binary" + + ostats "go.opencensus.io/stats" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/tok" + "github.com/dgraph-io/dgraph/v25/x" +) + +// BM25Posting is a single materialized entry of a BM25 term posting list: the +// document UID together with the term frequency and document length decoded from +// the posting value. +type BM25Posting struct { + Uid uint64 + TF uint32 + DocLen uint32 +} + +// ReadBM25TermPostings materializes the postings of a BM25 term's standard index +// list at readTs into a UID-ascending slice, decoding (tf, docLen) from each +// posting value. getList reads a posting list for a key (e.g. LocalCache.Get), +// keeping the value encoding encapsulated in this package. +func ReadBM25TermPostings(getList func(key []byte) (*List, error), attr, encodedTerm string, + readTs uint64) ([]BM25Posting, error) { + key := x.BM25IndexKey(attr, encodedTerm) + pl, err := getList(key) + if err != nil { + return nil, err + } + var out []BM25Posting + err = pl.Iterate(readTs, 0, func(p *pb.Posting) error { + tf, docLen, ok := decodeBM25Value(p.Value) + if !ok { + // Corrupt/truncated posting value: skip rather than inject a bogus + // zero-frequency match that would silently distort scoring. + return nil + } + out = append(out, BM25Posting{Uid: p.Uid, TF: tf, DocLen: docLen}) + return nil + }) + if err != nil { + return nil, err + } + return out, nil +} + +// numBM25StatsBuckets is the number of buckets the BM25 corpus statistics (document +// count and total term count) are sharded across, keyed by uid%numBM25StatsBuckets. +// Sharding spreads the read-modify-write contention of stats maintenance across +// independent posting lists so that concurrent mutations on different documents +// rarely conflict, while same-bucket updates still conflict (and retry) — avoiding +// lost updates. A single hot stats key would serialize all writes to the predicate. +const numBM25StatsBuckets = 32 + +// encodeBM25Value packs a posting's term frequency and document length into the +// posting Value as two unsigned varints. Storing the document length alongside the +// term frequency makes scoring read (tf, docLen) in a single posting access — no +// separate document-length list (which would be a write-hot key) and no random +// per-candidate lookup at query time. The document length is duplicated across a +// document's unique terms, but a document's postings are always rewritten together +// on update, so they stay consistent. +func encodeBM25Value(tf, docLen uint32) []byte { + buf := make([]byte, binary.MaxVarintLen32*2) + n := binary.PutUvarint(buf, uint64(tf)) + n += binary.PutUvarint(buf[n:], uint64(docLen)) + return buf[:n] +} + +// decodeBM25Value reverses encodeBM25Value. ok is false when the input does not +// hold two complete varints (e.g. a truncated or corrupt posting value), so the +// caller can skip it rather than silently scoring it as a zero-frequency match. +func decodeBM25Value(b []byte) (tf, docLen uint32, ok bool) { + tf64, n := binary.Uvarint(b) + if n <= 0 { + return 0, 0, false + } + docLen64, m := binary.Uvarint(b[n:]) + if m <= 0 { + return 0, 0, false + } + return uint32(tf64), uint32(docLen64), true +} + +// encodeBM25Stats encodes corpus statistics (document count, total term count) as +// two unsigned varints. +func encodeBM25Stats(docCount, totalTerms uint64) []byte { + buf := make([]byte, binary.MaxVarintLen64*2) + n := binary.PutUvarint(buf, docCount) + n += binary.PutUvarint(buf[n:], totalTerms) + return buf[:n] +} + +// decodeBM25Stats reverses encodeBM25Stats. It returns (0, 0) on malformed input. +func decodeBM25Stats(b []byte) (docCount, totalTerms uint64) { + docCount, n := binary.Uvarint(b) + if n <= 0 { + return 0, 0 + } + totalTerms, m := binary.Uvarint(b[n:]) + if m <= 0 { + return docCount, 0 + } + return docCount, totalTerms +} + +// addBM25TermPosting writes (op=SET) or removes (op=DEL) the posting for the given +// (term, uid) pair in the term's standard index posting list. On SET the posting's +// Value packs (tf, docLen); on DEL only the UID matters. The posting is a REF +// posting (ValueId set) that carries a Value — List.encode retains such postings +// through rollup (see the len(p.Value) > 0 clause there). +func (txn *Txn) addBM25TermPosting(ctx context.Context, attr, term string, uid uint64, + tf, docLen uint32, op pb.DirectedEdge_Op) error { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: op, + } + if op != pb.DirectedEdge_DEL { + edge.Value = encodeBM25Value(tf, docLen) + edge.ValueType = pb.Posting_BINARY + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + ostats.Record(ctx, x.NumEdges.M(1)) + return nil +} + +// updateBM25Stats applies (docCountDelta, totalTermsDelta) to the bucketed corpus +// statistics for attr. The bucket is selected by uid%numBM25StatsBuckets. The +// running totals are stored as a single value posting per bucket; the read at +// txn.StartTs sees this transaction's own earlier writes (read-your-own-writes), +// so multiple documents in the same transaction that land in the same bucket +// accumulate correctly. +func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, uid uint64, + docCountDelta, totalTermsDelta int64) error { + bucket := int(uid % numBM25StatsBuckets) + key := x.BM25StatsKey(attr, bucket) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + + var docCount, totalTerms uint64 + val, err := plist.Value(txn.StartTs) + switch { + case err == nil: + if data, ok := val.Value.([]byte); ok { + docCount, totalTerms = decodeBM25Stats(data) + } + case err == ErrNoValue: + // No stats yet for this bucket; start from zero. + default: + return err + } + + docCount = applyBM25Delta(docCount, docCountDelta) + totalTerms = applyBM25Delta(totalTerms, totalTermsDelta) + + edge := &pb.DirectedEdge{ + Attr: attr, + Value: encodeBM25Stats(docCount, totalTerms), + ValueType: pb.Posting_BINARY, + Op: pb.DirectedEdge_SET, + } + return plist.addMutation(ctx, txn, edge) +} + +// applyBM25Delta adds a signed delta to an unsigned counter, clamping at zero. +func applyBM25Delta(v uint64, delta int64) uint64 { + if delta >= 0 { + return v + uint64(delta) + } + dec := uint64(-delta) + if dec > v { + return 0 + } + return v - dec +} + +// ReadBM25Stats sums the bucketed corpus statistics for attr at readTs, returning +// the document count and total term count. avgDL = totalTerms / docCount. The +// getList closure reads a posting list for a key (e.g. LocalCache.Get) so the +// caller controls caching and the read timestamp. +func ReadBM25Stats(getList func(key []byte) (*List, error), attr string, + readTs uint64) (docCount, totalTerms uint64, err error) { + for b := 0; b < numBM25StatsBuckets; b++ { + key := x.BM25StatsKey(attr, b) + pl, perr := getList(key) + if perr != nil { + return 0, 0, perr + } + val, verr := pl.Value(readTs) + if verr == ErrNoValue { + continue + } + if verr != nil { + return 0, 0, verr + } + data, ok := val.Value.([]byte) + if !ok || len(data) == 0 { + continue + } + dc, tt := decodeBM25Stats(data) + docCount += dc + totalTerms += tt + } + return docCount, totalTerms, nil +} diff --git a/posting/bm25_test.go b/posting/bm25_test.go new file mode 100644 index 00000000000..82a4f712d60 --- /dev/null +++ b/posting/bm25_test.go @@ -0,0 +1,140 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "context" + "math" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" +) + +func TestBM25ValueCodecRoundTrip(t *testing.T) { + cases := [][2]uint32{{0, 0}, {1, 1}, {3, 12}, {7, 200}, {65535, 1 << 20}, {1 << 24, 1 << 24}} + for _, c := range cases { + tf, dl, ok := decodeBM25Value(encodeBM25Value(c[0], c[1])) + require.True(t, ok) + require.Equal(t, c[0], tf) + require.Equal(t, c[1], dl) + } + // Malformed/truncated input is reported as invalid so callers can skip it. + _, _, ok := decodeBM25Value(nil) + require.False(t, ok) + _, _, ok = decodeBM25Value([]byte{0x80}) // varint continuation byte with no terminator + require.False(t, ok) +} + +// TestBM25ValueSurvivesRollup verifies the linchpin of the BM25 redesign: a REF +// index posting that carries a packed (tf, docLen) value is retained — value and +// all — through rollup, instead of being collapsed to a UID-only Pack entry (the +// default behavior for REF postings before the len(p.Value) > 0 retention clause +// in List.encode). +func TestBM25ValueSurvivesRollup(t *testing.T) { + attr := x.AttrInRootNamespace("bm25rollup") + encodedTerm := string([]byte{0x10}) + "fox" // IdentBM25 || term + key := x.BM25IndexKey(attr, encodedTerm) + + docs := []struct { + uid uint64 + tf uint32 + docLen uint32 + }{ + {uid: 5, tf: 3, docLen: 12}, + {uid: 9, tf: 1, docLen: 40}, + {uid: 100, tf: 7, docLen: 200}, + } + + ts := uint64(1) + for _, d := range docs { + l, err := GetNoStore(key, ts) + require.NoError(t, err) + edge := &pb.DirectedEdge{ + ValueId: d.uid, + Attr: attr, + Value: encodeBM25Value(d.tf, d.docLen), + ValueType: pb.Posting_BINARY, + } + addMutation(t, l, edge, Set, ts, ts+1, false) + ts += 2 + } + + // Force a rollup and decode the resulting posting list directly. + l, err := getNew(key, pstore, math.MaxUint64, false) + require.NoError(t, err) + kvs, err := l.Rollup(nil, math.MaxUint64) + require.NoError(t, err) + require.NotEmpty(t, kvs) + + var plist pb.PostingList + require.NoError(t, proto.Unmarshal(kvs[0].Value, &plist)) + + got := make(map[uint64][2]uint32) + for _, p := range plist.Postings { + tf, docLen, ok := decodeBM25Value(p.Value) + require.True(t, ok) + got[p.Uid] = [2]uint32{tf, docLen} + } + for _, d := range docs { + v, ok := got[d.uid] + require.Truef(t, ok, "uid %d posting missing after rollup (value stripped?)", d.uid) + require.Equal(t, d.tf, v[0], "tf for uid %d", d.uid) + require.Equal(t, d.docLen, v[1], "docLen for uid %d", d.uid) + } + + // Reading the list back materializes the same (uid, tf, docLen) triples. + posts, err := ReadBM25TermPostings(func(k []byte) (*List, error) { + return getNew(k, pstore, math.MaxUint64, false) + }, attr, encodedTerm, math.MaxUint64) + require.NoError(t, err) + require.Len(t, posts, len(docs)) + for _, p := range posts { + require.Equal(t, got[p.Uid][0], p.TF) + require.Equal(t, got[p.Uid][1], p.DocLen) + } +} + +// TestBM25StatsBucketed verifies that bucketed corpus statistics accumulate +// correctly across documents (including two documents that hash to the same +// bucket, exercising in-transaction read-your-own-writes) and that deletes +// subtract correctly. +func TestBM25StatsBucketed(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25stats") + ts := uint64(101) + txn := Oracle().RegisterStartTs(ts) + + // uid 1 and uid 33 both fall in bucket 1 (mod 32), exercising same-bucket + // accumulation within a single transaction. + docs := []struct { + uid uint64 + dl int64 + }{{1, 10}, {2, 20}, {33, 5}, {64, 7}, {100, 8}} + + var wantCount, wantTerms int64 + for _, d := range docs { + require.NoError(t, txn.updateBM25Stats(ctx, attr, d.uid, 1, d.dl)) + wantCount++ + wantTerms += d.dl + } + + get := func(k []byte) (*List, error) { return txn.cache.GetFromDelta(k) } + dc, tt, err := ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount), dc) + require.Equal(t, uint64(wantTerms), tt) + + // Delete uid 2: docCount and totalTerms drop accordingly. + require.NoError(t, txn.updateBM25Stats(ctx, attr, 2, -1, -20)) + dc, tt, err = ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount-1), dc) + require.Equal(t, uint64(wantTerms-20), tt) +} diff --git a/posting/bm25block/bm25block.go b/posting/bm25block/bm25block.go deleted file mode 100644 index e9c4fa1776e..00000000000 --- a/posting/bm25block/bm25block.go +++ /dev/null @@ -1,262 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -// Package bm25block provides block-based storage for BM25 index data. -// -// Instead of storing all postings for a term in a single blob, this package -// splits them into fixed-size blocks (~128 entries). Each block is stored as -// a separate Badger KV entry, and a lightweight directory indexes the blocks. -// -// This enables: -// - Selective I/O: queries only read blocks they need -// - WAND/Block-Max WAND: per-block upper bounds enable early termination -// - Efficient mutations: only the affected block is rewritten -package bm25block - -import ( - "encoding/binary" - "math" - "sort" - - "github.com/dgraph-io/dgraph/v25/posting/bm25enc" -) - -const ( - // TargetBlockSize is the ideal number of entries per block. - TargetBlockSize = 128 - // MaxBlockSize is the threshold at which a block is split. - MaxBlockSize = 256 - // DocLenBlockSize is the target entries per document-length block. - DocLenBlockSize = 512 - - // dirHeaderSize is 4 (blockCount) + 4 (nextID). - dirHeaderSize = 8 - // dirEntrySize is 8 (firstUID) + 4 (blockID) + 4 (count) + 4 (maxTF). - dirEntrySize = 20 -) - -// BlockMeta stores metadata for a single block in a directory. -type BlockMeta struct { - FirstUID uint64 - BlockID uint32 - Count uint32 - MaxTF uint32 -} - -// Dir is a block directory for a term's posting list or document-length list. -type Dir struct { - Blocks []BlockMeta - NextID uint32 // next available block ID -} - -// EncodeDir encodes a directory to bytes. Returns nil for an empty directory. -func EncodeDir(d *Dir) []byte { - if d == nil || len(d.Blocks) == 0 { - return nil - } - buf := make([]byte, dirHeaderSize+len(d.Blocks)*dirEntrySize) - binary.BigEndian.PutUint32(buf[0:4], uint32(len(d.Blocks))) - binary.BigEndian.PutUint32(buf[4:8], d.NextID) - off := dirHeaderSize - for _, b := range d.Blocks { - binary.BigEndian.PutUint64(buf[off:off+8], b.FirstUID) - binary.BigEndian.PutUint32(buf[off+8:off+12], b.BlockID) - binary.BigEndian.PutUint32(buf[off+12:off+16], b.Count) - binary.BigEndian.PutUint32(buf[off+16:off+20], b.MaxTF) - off += dirEntrySize - } - return buf -} - -// DecodeDir decodes a directory from bytes. Returns an empty Dir for nil/invalid input. -func DecodeDir(data []byte) *Dir { - if len(data) < dirHeaderSize { - return &Dir{} - } - count := binary.BigEndian.Uint32(data[0:4]) - nextID := binary.BigEndian.Uint32(data[4:8]) - // Use uint64 arithmetic to prevent integer overflow on corrupted data. - if uint64(count)*dirEntrySize+dirHeaderSize > uint64(len(data)) { - return &Dir{NextID: nextID} - } - blocks := make([]BlockMeta, count) - off := dirHeaderSize - for i := uint32(0); i < count; i++ { - blocks[i] = BlockMeta{ - FirstUID: binary.BigEndian.Uint64(data[off : off+8]), - BlockID: binary.BigEndian.Uint32(data[off+8 : off+12]), - Count: binary.BigEndian.Uint32(data[off+12 : off+16]), - MaxTF: binary.BigEndian.Uint32(data[off+16 : off+20]), - } - off += dirEntrySize - } - return &Dir{Blocks: blocks, NextID: nextID} -} - -// FindBlock returns the index of the block that should contain uid. -// Returns 0 if the directory is empty (caller should create first block). -func (d *Dir) FindBlock(uid uint64) int { - if len(d.Blocks) == 0 { - return 0 - } - // Binary search: find the last block where FirstUID <= uid. - i := sort.Search(len(d.Blocks), func(i int) bool { - return d.Blocks[i].FirstUID > uid - }) - if i > 0 { - return i - 1 - } - return 0 -} - -// AllocBlockID returns the next available block ID and increments the counter. -func (d *Dir) AllocBlockID() uint32 { - id := d.NextID - d.NextID++ - return id -} - -// UpdateBlockMeta recomputes metadata for the block at index idx from entries. -func (d *Dir) UpdateBlockMeta(idx int, entries []bm25enc.Entry) { - if idx < 0 || idx >= len(d.Blocks) || len(entries) == 0 { - return - } - d.Blocks[idx].FirstUID = entries[0].UID - d.Blocks[idx].Count = uint32(len(entries)) - var maxTF uint32 - for _, e := range entries { - if e.Value > maxTF { - maxTF = e.Value - } - } - d.Blocks[idx].MaxTF = maxTF -} - -// InsertBlockMeta inserts a new block at position idx. -func (d *Dir) InsertBlockMeta(idx int, meta BlockMeta) { - d.Blocks = append(d.Blocks, BlockMeta{}) - copy(d.Blocks[idx+1:], d.Blocks[idx:]) - d.Blocks[idx] = meta -} - -// RemoveBlockMeta removes the block at position idx. -func (d *Dir) RemoveBlockMeta(idx int) { - if idx < 0 || idx >= len(d.Blocks) { - return - } - d.Blocks = append(d.Blocks[:idx], d.Blocks[idx+1:]...) -} - -// SplitIntoBlocks splits a sorted entry slice into blocks of TargetBlockSize. -// Returns a new Dir and a map of blockID -> entries. -func SplitIntoBlocks(entries []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { - if len(entries) == 0 { - return &Dir{}, nil - } - dir := &Dir{} - blockMap := make(map[uint32][]bm25enc.Entry) - - for i := 0; i < len(entries); i += TargetBlockSize { - end := i + TargetBlockSize - if end > len(entries) { - end = len(entries) - } - block := entries[i:end] - blockID := dir.AllocBlockID() - - var maxTF uint32 - for _, e := range block { - if e.Value > maxTF { - maxTF = e.Value - } - } - - dir.Blocks = append(dir.Blocks, BlockMeta{ - FirstUID: block[0].UID, - BlockID: blockID, - Count: uint32(len(block)), - MaxTF: maxTF, - }) - // Make a copy so the caller owns the slice. - cp := make([]bm25enc.Entry, len(block)) - copy(cp, block) - blockMap[blockID] = cp - } - return dir, blockMap -} - -// MergeAllBlocks reads all block entries from a map (keyed by blockID), -// merges them into a single sorted slice, then re-splits into clean blocks. -func MergeAllBlocks(dir *Dir, readBlock func(blockID uint32) []bm25enc.Entry) (*Dir, map[uint32][]bm25enc.Entry) { - var all []bm25enc.Entry - for _, bm := range dir.Blocks { - entries := readBlock(bm.BlockID) - all = append(all, entries...) - } - // Sort by UID and deduplicate (keep last occurrence for same UID). - sort.Slice(all, func(i, j int) bool { return all[i].UID < all[j].UID }) - deduped := make([]bm25enc.Entry, 0, len(all)) - for i, e := range all { - if i > 0 && e.UID == all[i-1].UID { - deduped[len(deduped)-1] = e // overwrite with latest - continue - } - deduped = append(deduped, e) - } - // Remove tombstones (Value == 0). - live := deduped[:0] - for _, e := range deduped { - if e.Value > 0 { - live = append(live, e) - } - } - return SplitIntoBlocks(live) -} - -// ComputeUBPre computes the upper-bound pre-IDF BM25 contribution for a block -// given its maxTF and query parameters k and b. -// With dl=0 (best case for scoring): score = (maxTF*(k+1)) / (maxTF + k*(1-b)) -func ComputeUBPre(maxTF uint32, k, b float64) float64 { - if maxTF == 0 { - return 0 - } - tf := float64(maxTF) - return tf * (k + 1) / (tf + k*(1-b)) -} - -// SuffixMaxUBPre computes suffix maxima of UBPre values for WAND. -// suffixMax[i] = max(ubPre[i], ubPre[i+1], ..., ubPre[n-1]) -func SuffixMaxUBPre(dir *Dir, k, b float64) []float64 { - n := len(dir.Blocks) - if n == 0 { - return nil - } - suf := make([]float64, n) - suf[n-1] = ComputeUBPre(dir.Blocks[n-1].MaxTF, k, b) - for i := n - 2; i >= 0; i-- { - ub := ComputeUBPre(dir.Blocks[i].MaxTF, k, b) - suf[i] = math.Max(ub, suf[i+1]) - } - return suf -} - -// BlockMetaFromEntries computes a BlockMeta from entries. -func BlockMetaFromEntries(blockID uint32, entries []bm25enc.Entry) BlockMeta { - if len(entries) == 0 { - return BlockMeta{BlockID: blockID} - } - var maxTF uint32 - for _, e := range entries { - if e.Value > maxTF { - maxTF = e.Value - } - } - return BlockMeta{ - FirstUID: entries[0].UID, - BlockID: blockID, - Count: uint32(len(entries)), - MaxTF: maxTF, - } -} diff --git a/posting/bm25block/bm25block_test.go b/posting/bm25block/bm25block_test.go deleted file mode 100644 index f9cccb7554a..00000000000 --- a/posting/bm25block/bm25block_test.go +++ /dev/null @@ -1,271 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package bm25block - -import ( - "encoding/binary" - "math" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/dgraph-io/dgraph/v25/posting/bm25enc" -) - -func TestDirRoundtrip(t *testing.T) { - dir := &Dir{ - NextID: 5, - Blocks: []BlockMeta{ - {FirstUID: 100, BlockID: 0, Count: 128, MaxTF: 10}, - {FirstUID: 500, BlockID: 1, Count: 128, MaxTF: 5}, - {FirstUID: 900, BlockID: 2, Count: 64, MaxTF: 20}, - }, - } - data := EncodeDir(dir) - got := DecodeDir(data) - require.Equal(t, dir.NextID, got.NextID) - require.Equal(t, dir.Blocks, got.Blocks) -} - -func TestDirRoundtripEmpty(t *testing.T) { - require.Nil(t, EncodeDir(nil)) - require.Nil(t, EncodeDir(&Dir{})) - - got := DecodeDir(nil) - require.Empty(t, got.Blocks) - got = DecodeDir([]byte{}) - require.Empty(t, got.Blocks) -} - -func TestDecodeDirCorruptedLargeCount(t *testing.T) { - // A corrupted blob with a massive count should not panic due to integer overflow. - // count = MaxUint32, nextID = 0, followed by only 8 bytes of data. - data := make([]byte, 16) - binary.BigEndian.PutUint32(data[0:4], 0xFFFFFFFF) // count = MaxUint32 - binary.BigEndian.PutUint32(data[4:8], 0) // nextID = 0 - got := DecodeDir(data) - // Should return an empty Dir (with nextID preserved) rather than panicking. - require.Empty(t, got.Blocks) - require.Equal(t, uint32(0), got.NextID) -} - -func TestDirRoundtripSingle(t *testing.T) { - dir := &Dir{ - NextID: 1, - Blocks: []BlockMeta{{FirstUID: 42, BlockID: 0, Count: 1, MaxTF: 3}}, - } - got := DecodeDir(EncodeDir(dir)) - require.Equal(t, dir.Blocks, got.Blocks) -} - -func TestFindBlock(t *testing.T) { - dir := &Dir{ - Blocks: []BlockMeta{ - {FirstUID: 100}, - {FirstUID: 500}, - {FirstUID: 900}, - }, - } - require.Equal(t, 0, dir.FindBlock(50)) // before first block - require.Equal(t, 0, dir.FindBlock(100)) // exact first - require.Equal(t, 0, dir.FindBlock(200)) // within first block - require.Equal(t, 1, dir.FindBlock(500)) // exact second - require.Equal(t, 1, dir.FindBlock(700)) // within second block - require.Equal(t, 2, dir.FindBlock(900)) // exact third - require.Equal(t, 2, dir.FindBlock(9999)) // beyond last block -} - -func TestFindBlockEmpty(t *testing.T) { - dir := &Dir{} - require.Equal(t, 0, dir.FindBlock(100)) -} - -func TestAllocBlockID(t *testing.T) { - dir := &Dir{NextID: 3} - require.Equal(t, uint32(3), dir.AllocBlockID()) - require.Equal(t, uint32(4), dir.AllocBlockID()) - require.Equal(t, uint32(5), dir.NextID) -} - -func TestSplitIntoBlocks(t *testing.T) { - // Create 300 entries. - entries := make([]bm25enc.Entry, 300) - for i := range entries { - entries[i] = bm25enc.Entry{UID: uint64(i + 1), Value: uint32(i%10 + 1)} - } - dir, blockMap := SplitIntoBlocks(entries) - - // Should split into ceil(300/128) = 3 blocks. - require.Len(t, dir.Blocks, 3) - require.Len(t, blockMap, 3) - - // First block: 128 entries. - require.Equal(t, uint32(128), dir.Blocks[0].Count) - require.Equal(t, uint64(1), dir.Blocks[0].FirstUID) - require.Len(t, blockMap[dir.Blocks[0].BlockID], 128) - - // Second block: 128 entries. - require.Equal(t, uint32(128), dir.Blocks[1].Count) - require.Equal(t, uint64(129), dir.Blocks[1].FirstUID) - - // Third block: 44 entries. - require.Equal(t, uint32(44), dir.Blocks[2].Count) - require.Equal(t, uint64(257), dir.Blocks[2].FirstUID) - - // NextID should be 3. - require.Equal(t, uint32(3), dir.NextID) -} - -func TestSplitIntoBlocksEmpty(t *testing.T) { - dir, blockMap := SplitIntoBlocks(nil) - require.Empty(t, dir.Blocks) - require.Nil(t, blockMap) -} - -func TestSplitIntoBlocksSmall(t *testing.T) { - entries := []bm25enc.Entry{{UID: 1, Value: 5}, {UID: 2, Value: 3}} - dir, blockMap := SplitIntoBlocks(entries) - require.Len(t, dir.Blocks, 1) - require.Equal(t, uint32(2), dir.Blocks[0].Count) - require.Equal(t, uint32(5), dir.Blocks[0].MaxTF) - require.Equal(t, entries, blockMap[0]) -} - -func TestUpdateBlockMeta(t *testing.T) { - dir := &Dir{ - Blocks: []BlockMeta{{FirstUID: 100, BlockID: 0, Count: 3, MaxTF: 5}}, - } - entries := []bm25enc.Entry{ - {UID: 50, Value: 2}, - {UID: 100, Value: 8}, - {UID: 200, Value: 3}, - {UID: 300, Value: 1}, - } - dir.UpdateBlockMeta(0, entries) - require.Equal(t, uint64(50), dir.Blocks[0].FirstUID) - require.Equal(t, uint32(4), dir.Blocks[0].Count) - require.Equal(t, uint32(8), dir.Blocks[0].MaxTF) -} - -func TestInsertRemoveBlockMeta(t *testing.T) { - dir := &Dir{ - Blocks: []BlockMeta{ - {FirstUID: 100, BlockID: 0}, - {FirstUID: 500, BlockID: 1}, - }, - } - dir.InsertBlockMeta(1, BlockMeta{FirstUID: 300, BlockID: 2}) - require.Len(t, dir.Blocks, 3) - require.Equal(t, uint64(300), dir.Blocks[1].FirstUID) - require.Equal(t, uint64(500), dir.Blocks[2].FirstUID) - - dir.RemoveBlockMeta(1) - require.Len(t, dir.Blocks, 2) - require.Equal(t, uint64(500), dir.Blocks[1].FirstUID) -} - -func TestComputeUBPre(t *testing.T) { - k, b := 1.2, 0.75 - - // maxTF=0 -> 0 - require.Equal(t, 0.0, ComputeUBPre(0, k, b)) - - // maxTF=1: 1 * 2.2 / (1 + 1.2*0.25) = 2.2 / 1.3 - expected := 2.2 / 1.3 - require.InEpsilon(t, expected, ComputeUBPre(1, k, b), 1e-9) - - // maxTF=10: 10 * 2.2 / (10 + 1.2*0.25) = 22 / 10.3 - expected = 22.0 / 10.3 - require.InEpsilon(t, expected, ComputeUBPre(10, k, b), 1e-9) - - // With b=0: score = tf*(k+1)/(tf+k) — no length normalization. - expected = 5.0 * 2.2 / (5.0 + 1.2) - require.InEpsilon(t, expected, ComputeUBPre(5, k, 0), 1e-9) -} - -func TestSuffixMaxUBPre(t *testing.T) { - dir := &Dir{ - Blocks: []BlockMeta{ - {MaxTF: 1}, - {MaxTF: 10}, - {MaxTF: 3}, - }, - } - k, b := 1.2, 0.75 - suf := SuffixMaxUBPre(dir, k, b) - require.Len(t, suf, 3) - - ub0 := ComputeUBPre(1, k, b) - ub1 := ComputeUBPre(10, k, b) - ub2 := ComputeUBPre(3, k, b) - - require.InEpsilon(t, math.Max(ub0, math.Max(ub1, ub2)), suf[0], 1e-9) - require.InEpsilon(t, math.Max(ub1, ub2), suf[1], 1e-9) - require.InEpsilon(t, ub2, suf[2], 1e-9) -} - -func TestSuffixMaxUBPreEmpty(t *testing.T) { - require.Nil(t, SuffixMaxUBPre(&Dir{}, 1.2, 0.75)) -} - -func TestMergeAllBlocks(t *testing.T) { - // Simulate overlapping blocks with a tombstone. - blocks := map[uint32][]bm25enc.Entry{ - 0: {{UID: 1, Value: 3}, {UID: 5, Value: 1}}, - 1: {{UID: 5, Value: 7}, {UID: 10, Value: 2}}, // UID 5 overrides - 2: {{UID: 15, Value: 0}, {UID: 20, Value: 4}}, // UID 15 is tombstone - } - dir := &Dir{ - Blocks: []BlockMeta{ - {FirstUID: 1, BlockID: 0, Count: 2}, - {FirstUID: 5, BlockID: 1, Count: 2}, - {FirstUID: 15, BlockID: 2, Count: 2}, - }, - NextID: 3, - } - newDir, newBlocks := MergeAllBlocks(dir, func(id uint32) []bm25enc.Entry { - return blocks[id] - }) - // After merge: UID 1(3), 5(7), 10(2), 20(4) — UID 15 removed (tombstone). - require.Len(t, newDir.Blocks, 1) // 4 entries fits in one block - require.Len(t, newBlocks, 1) - entries := newBlocks[newDir.Blocks[0].BlockID] - require.Len(t, entries, 4) - require.Equal(t, uint64(1), entries[0].UID) - require.Equal(t, uint32(3), entries[0].Value) - require.Equal(t, uint64(5), entries[1].UID) - require.Equal(t, uint32(7), entries[1].Value) - require.Equal(t, uint64(20), entries[3].UID) -} - -func TestBlockMetaFromEntries(t *testing.T) { - entries := []bm25enc.Entry{ - {UID: 10, Value: 2}, - {UID: 20, Value: 8}, - {UID: 30, Value: 1}, - } - meta := BlockMetaFromEntries(5, entries) - require.Equal(t, uint32(5), meta.BlockID) - require.Equal(t, uint64(10), meta.FirstUID) - require.Equal(t, uint32(3), meta.Count) - require.Equal(t, uint32(8), meta.MaxTF) -} - -func TestBlockMetaFromEntriesEmpty(t *testing.T) { - meta := BlockMetaFromEntries(0, nil) - require.Equal(t, uint32(0), meta.Count) -} - -func BenchmarkSplitIntoBlocks(b *testing.B) { - entries := make([]bm25enc.Entry, 100000) - for i := range entries { - entries[i] = bm25enc.Entry{UID: uint64(i*3 + 1), Value: uint32(i%100 + 1)} - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - SplitIntoBlocks(entries) - } -} diff --git a/posting/bm25enc/bm25enc.go b/posting/bm25enc/bm25enc.go deleted file mode 100644 index 86bfe5f5bb1..00000000000 --- a/posting/bm25enc/bm25enc.go +++ /dev/null @@ -1,158 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -// Package bm25enc provides compact binary encoding for BM25 index data. -// -// Two types of lists share the same format: -// - Term posting lists: (UID, term-frequency) pairs -// - Document length lists: (UID, doc-length) pairs -// -// Binary format: -// -// Header: -// [4 bytes] uint32 big-endian: entry count -// Entries (sorted ascending by UID): -// [varint] UID delta from previous (first entry is absolute) -// [varint] value (TF or doclen) -package bm25enc - -import ( - "encoding/binary" - "sort" -) - -// Entry represents a single (UID, Value) pair in a BM25 posting list. -type Entry struct { - UID uint64 - Value uint32 -} - -// Encode encodes a sorted slice of entries into the compact binary format. -// Entries must be sorted by UID ascending. Returns nil for empty input. -func Encode(entries []Entry) []byte { - if len(entries) == 0 { - return nil - } - - // Pre-allocate: 4 header + ~6 bytes per entry is a reasonable estimate. - buf := make([]byte, 4, 4+len(entries)*6) - binary.BigEndian.PutUint32(buf, uint32(len(entries))) - - var tmp [binary.MaxVarintLen64]byte - var prevUID uint64 - for _, e := range entries { - delta := e.UID - prevUID - n := binary.PutUvarint(tmp[:], delta) - buf = append(buf, tmp[:n]...) - n = binary.PutUvarint(tmp[:], uint64(e.Value)) - buf = append(buf, tmp[:n]...) - prevUID = e.UID - } - return buf -} - -// Decode decodes the binary format into a sorted slice of entries. -// Returns nil for nil/empty input. -func Decode(data []byte) []Entry { - if len(data) < 4 { - return nil - } - count := binary.BigEndian.Uint32(data[:4]) - if count == 0 { - return nil - } - - entries := make([]Entry, 0, count) - pos := 4 - var prevUID uint64 - for i := uint32(0); i < count; i++ { - delta, n := binary.Uvarint(data[pos:]) - if n <= 0 { - break - } - pos += n - - val, n := binary.Uvarint(data[pos:]) - if n <= 0 { - break - } - pos += n - - uid := prevUID + delta - entries = append(entries, Entry{UID: uid, Value: uint32(val)}) - prevUID = uid - } - return entries -} - -// Upsert inserts or updates the entry for uid in a sorted entries slice. -// Returns the new sorted slice. -func Upsert(entries []Entry, uid uint64, value uint32) []Entry { - i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) - if i < len(entries) && entries[i].UID == uid { - entries[i].Value = value - return entries - } - // Insert at position i. - entries = append(entries, Entry{}) - copy(entries[i+1:], entries[i:]) - entries[i] = Entry{UID: uid, Value: value} - return entries -} - -// Remove removes the entry for uid from a sorted entries slice. -// Returns the new slice (may be shorter). -func Remove(entries []Entry, uid uint64) []Entry { - i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) - if i < len(entries) && entries[i].UID == uid { - return append(entries[:i], entries[i+1:]...) - } - return entries -} - -// Search returns the value for uid using binary search, and whether it was found. -func Search(entries []Entry, uid uint64) (uint32, bool) { - i := sort.Search(len(entries), func(i int) bool { return entries[i].UID >= uid }) - if i < len(entries) && entries[i].UID == uid { - return entries[i].Value, true - } - return 0, false -} - -// UIDs extracts just the UIDs from entries as a uint64 slice. -func UIDs(entries []Entry) []uint64 { - uids := make([]uint64, len(entries)) - for i, e := range entries { - uids[i] = e.UID - } - return uids -} - -// DecodeCount reads just the entry count from the header of an encoded blob -// without decoding any entries. This is O(1) and avoids allocating a full -// []Entry slice, which matters for large posting lists (e.g., common terms -// during legacy format migration). -func DecodeCount(data []byte) uint32 { - if len(data) < 4 { - return 0 - } - return binary.BigEndian.Uint32(data[:4]) -} - -// EncodeStats encodes BM25 corpus statistics (docCount, totalTerms) as 16 bytes. -func EncodeStats(docCount, totalTerms uint64) []byte { - buf := make([]byte, 16) - binary.BigEndian.PutUint64(buf[0:8], docCount) - binary.BigEndian.PutUint64(buf[8:16], totalTerms) - return buf -} - -// DecodeStats decodes BM25 corpus statistics. Returns (0,0) for invalid input. -func DecodeStats(data []byte) (docCount, totalTerms uint64) { - if len(data) != 16 { - return 0, 0 - } - return binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) -} diff --git a/posting/bm25enc/bm25enc_test.go b/posting/bm25enc/bm25enc_test.go deleted file mode 100644 index f4cfec6bf62..00000000000 --- a/posting/bm25enc/bm25enc_test.go +++ /dev/null @@ -1,163 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package bm25enc - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestRoundtrip(t *testing.T) { - entries := []Entry{ - {UID: 1, Value: 3}, - {UID: 5, Value: 1}, - {UID: 100, Value: 7}, - {UID: 200, Value: 2}, - } - data := Encode(entries) - got := Decode(data) - require.Equal(t, entries, got) -} - -func TestRoundtripEmpty(t *testing.T) { - require.Nil(t, Encode(nil)) - require.Nil(t, Encode([]Entry{})) - require.Nil(t, Decode(nil)) - require.Nil(t, Decode([]byte{})) - require.Nil(t, Decode([]byte{0, 0, 0, 0})) // count=0 -} - -func TestRoundtripSingle(t *testing.T) { - entries := []Entry{{UID: 42, Value: 10}} - got := Decode(Encode(entries)) - require.Equal(t, entries, got) -} - -func TestRoundtripLargeUIDs(t *testing.T) { - entries := []Entry{ - {UID: 1<<40 + 1, Value: 1}, - {UID: 1<<40 + 1000, Value: 5}, - {UID: 1<<50 + 999, Value: 99}, - } - got := Decode(Encode(entries)) - require.Equal(t, entries, got) -} - -func TestUpsertNew(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} - entries = Upsert(entries, 3, 7) - require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 3, Value: 7}, {UID: 5, Value: 1}}, entries) -} - -func TestUpsertExisting(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} - entries = Upsert(entries, 5, 99) - require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 99}}, entries) -} - -func TestUpsertEmpty(t *testing.T) { - var entries []Entry - entries = Upsert(entries, 10, 5) - require.Equal(t, []Entry{{UID: 10, Value: 5}}, entries) -} - -func TestRemove(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 10, Value: 2}} - entries = Remove(entries, 5) - require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 10, Value: 2}}, entries) -} - -func TestRemoveNotFound(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}} - entries = Remove(entries, 99) - require.Equal(t, []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}}, entries) -} - -func TestSearch(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} - v, ok := Search(entries, 5) - require.True(t, ok) - require.Equal(t, uint32(1), v) - - _, ok = Search(entries, 50) - require.False(t, ok) -} - -func TestUIDs(t *testing.T) { - entries := []Entry{{UID: 1, Value: 3}, {UID: 5, Value: 1}, {UID: 100, Value: 7}} - require.Equal(t, []uint64{1, 5, 100}, UIDs(entries)) -} - -func TestDecodeCount(t *testing.T) { - // Normal case: count matches actual entries. - entries := []Entry{ - {UID: 1, Value: 3}, - {UID: 5, Value: 1}, - {UID: 100, Value: 7}, - } - data := Encode(entries) - require.Equal(t, uint32(3), DecodeCount(data)) - - // Empty/nil input. - require.Equal(t, uint32(0), DecodeCount(nil)) - require.Equal(t, uint32(0), DecodeCount([]byte{})) - require.Equal(t, uint32(0), DecodeCount([]byte{1, 2, 3})) - - // Zero count. - require.Equal(t, uint32(0), DecodeCount([]byte{0, 0, 0, 0})) - - // Single entry. - single := Encode([]Entry{{UID: 42, Value: 10}}) - require.Equal(t, uint32(1), DecodeCount(single)) - - // Large count. - large := make([]Entry, 10000) - for i := range large { - large[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} - } - data = Encode(large) - require.Equal(t, uint32(10000), DecodeCount(data)) -} - -func TestStatsRoundtrip(t *testing.T) { - data := EncodeStats(12345, 98765) - dc, tt := DecodeStats(data) - require.Equal(t, uint64(12345), dc) - require.Equal(t, uint64(98765), tt) -} - -func TestStatsInvalid(t *testing.T) { - dc, tt := DecodeStats(nil) - require.Zero(t, dc) - require.Zero(t, tt) - dc, tt = DecodeStats([]byte{1, 2, 3}) - require.Zero(t, dc) - require.Zero(t, tt) -} - -func BenchmarkEncode(b *testing.B) { - entries := make([]Entry, 10000) - for i := range entries { - entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} - } - b.ResetTimer() - for i := 0; i < b.N; i++ { - Encode(entries) - } -} - -func BenchmarkDecode(b *testing.B) { - entries := make([]Entry, 10000) - for i := range entries { - entries[i] = Entry{UID: uint64(i*3 + 1), Value: uint32(i % 100)} - } - data := Encode(entries) - b.ResetTimer() - for i := 0; i < b.N; i++ { - Decode(data) - } -} diff --git a/posting/index.go b/posting/index.go index e19deae8d73..edb997cc4b6 100644 --- a/posting/index.go +++ b/posting/index.go @@ -28,8 +28,6 @@ import ( "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4/options" bpb "github.com/dgraph-io/badger/v4/pb" - "github.com/dgraph-io/dgraph/v25/posting/bm25block" - "github.com/dgraph-io/dgraph/v25/posting/bm25enc" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/tok" @@ -232,10 +230,20 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return nil } -// addBM25IndexMutations handles index mutations for the BM25 tokenizer. -// It stores term frequencies, document lengths, and corpus statistics using -// block-based storage: each term's postings and the doclen list are split into -// fixed-size blocks (~128 entries) with a lightweight directory for navigation. +// addBM25IndexMutations handles index mutations for the BM25 tokenizer. Unlike +// other tokenizers, each BM25 index posting carries a value that packs the term +// frequency together with the document length (see encodeBM25Value). The postings +// are written through the standard delta path (plist.addMutation), so BM25 rides +// Dgraph's normal posting-list machinery — MVCC, deltas, rollup, splits, backup — +// with no separate storage path. Corpus statistics (document count and total term +// count, from which the average document length is derived) are kept in bucketed +// stats posting lists keyed by uid%numBM25StatsBuckets to avoid a single write-hot +// key while preserving conflict detection per bucket. +// +// Updates are driven entirely by the caller (AddMutationWithIndex), which issues a +// DEL for the previous value followed by a SET for the new one. The DEL re-tokenizes +// the old value and removes its postings and stats contribution; the SET adds the new +// ones. We therefore never need to detect updates here. func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { attr := info.edge.Attr uid := info.edge.Entity @@ -262,236 +270,16 @@ func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationIn return nil } - if info.op == pb.DirectedEdge_DEL { - // For DELETE: remove uid from all term blocks and doclen blocks. - for term := range termFreqs { - encodedTerm := string([]byte{tok.IdentBM25}) + term - txn.bm25BlockRemove(attr, encodedTerm, uid) - } - txn.bm25DocLenBlockRemove(attr, uid) - return txn.updateBM25Stats(attr, -1, -int64(docLen)) - } - - // For SET: check if this UID already has a doclen entry (i.e., this is an update). - // If so, subtract old stats to avoid double-counting. - oldDocLen, isUpdate := txn.bm25DocLenBlockLookup(attr, uid) - for term, tf := range termFreqs { - encodedTerm := string([]byte{tok.IdentBM25}) + term - txn.bm25BlockUpsert(attr, encodedTerm, uid, tf) - } - txn.bm25DocLenBlockUpsert(attr, uid, docLen) - - var docCountDelta int64 - var totalTermsDelta int64 - if isUpdate { - // Document already existed: don't increment docCount, adjust totalTerms by diff. - totalTermsDelta = int64(docLen) - int64(oldDocLen) - } else { - docCountDelta = 1 - totalTermsDelta = int64(docLen) - } - return txn.updateBM25Stats(attr, docCountDelta, totalTermsDelta) -} - -// bm25BlockUpsert inserts or updates a (uid, value) entry in the block-based -// posting list for the given term. Handles block creation and splitting. -func (txn *Txn) bm25BlockUpsert(attr, encodedTerm string, uid uint64, value uint32) { - dirKey := x.BM25TermDirKey(attr, encodedTerm) - dirBlob := txn.cache.ReadBM25Blob(dirKey) - dir := bm25block.DecodeDir(dirBlob) - - if len(dir.Blocks) == 0 { - // First entry for this term: create a single block. - blockID := dir.AllocBlockID() - entries := []bm25enc.Entry{{UID: uid, Value: value}} - blockKey := x.BM25TermBlockKey(attr, encodedTerm, blockID) - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) - return - } - - // Find the target block, read it, upsert, and handle splits. - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) - blob := txn.cache.ReadBM25Blob(blockKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Upsert(entries, uid, value) - - if len(entries) > bm25block.MaxBlockSize { - // Split the block. - mid := len(entries) / 2 - left := entries[:mid] - right := entries[mid:] - - // Write left block (reuse existing blockID). - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) - dir.UpdateBlockMeta(blockIdx, left) - - // Write right block (new blockID). - newBlockID := dir.AllocBlockID() - newBlockKey := x.BM25TermBlockKey(attr, encodedTerm, newBlockID) - txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) - dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) - } else { - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.UpdateBlockMeta(blockIdx, entries) - } - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) -} - -// bm25BlockRemove removes a uid from the block-based posting list for the given term. -func (txn *Txn) bm25BlockRemove(attr, encodedTerm string, uid uint64) { - dirKey := x.BM25TermDirKey(attr, encodedTerm) - dirBlob := txn.cache.ReadBM25Blob(dirKey) - dir := bm25block.DecodeDir(dirBlob) - - if len(dir.Blocks) == 0 { - return - } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25TermBlockKey(attr, encodedTerm, bm.BlockID) - blob := txn.cache.ReadBM25Blob(blockKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - - if len(entries) == 0 { - // Block is empty; remove it from the directory. - txn.cache.WriteBM25Blob(blockKey, nil) - dir.RemoveBlockMeta(blockIdx) - } else { - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.UpdateBlockMeta(blockIdx, entries) - } - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) -} - -// bm25DocLenBlockUpsert inserts or updates a doc-length entry in the block-based -// document-length list. -func (txn *Txn) bm25DocLenBlockUpsert(attr string, uid uint64, docLen uint32) { - dirKey := x.BM25DocLenDirKey(attr) - dirBlob := txn.cache.ReadBM25Blob(dirKey) - dir := bm25block.DecodeDir(dirBlob) - - if len(dir.Blocks) == 0 { - blockID := dir.AllocBlockID() - entries := []bm25enc.Entry{{UID: uid, Value: docLen}} - blockKey := x.BM25DocLenBlockKey(attr, blockID) - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.Blocks = append(dir.Blocks, bm25block.BlockMetaFromEntries(blockID, entries)) - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) - return - } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) - blob := txn.cache.ReadBM25Blob(blockKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Upsert(entries, uid, docLen) - - if len(entries) > bm25block.MaxBlockSize { - mid := len(entries) / 2 - left := entries[:mid] - right := entries[mid:] - - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(left)) - dir.UpdateBlockMeta(blockIdx, left) - - newBlockID := dir.AllocBlockID() - newBlockKey := x.BM25DocLenBlockKey(attr, newBlockID) - txn.cache.WriteBM25Blob(newBlockKey, bm25enc.Encode(right)) - dir.InsertBlockMeta(blockIdx+1, bm25block.BlockMetaFromEntries(newBlockID, right)) - } else { - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.UpdateBlockMeta(blockIdx, entries) - } - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) -} - -// bm25DocLenBlockLookup checks if a uid exists in the doclen blocks and returns its value. -func (txn *Txn) bm25DocLenBlockLookup(attr string, uid uint64) (uint32, bool) { - dirKey := x.BM25DocLenDirKey(attr) - dirBlob := txn.cache.ReadBM25Blob(dirKey) - dir := bm25block.DecodeDir(dirBlob) - - if len(dir.Blocks) == 0 { - return 0, false - } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) - blob := txn.cache.ReadBM25Blob(blockKey) - entries := bm25enc.Decode(blob) - if v, ok := bm25enc.Search(entries, uid); ok { - return v, true - } - return 0, false -} - -// bm25DocLenBlockRemove removes a uid from the block-based document-length list. -func (txn *Txn) bm25DocLenBlockRemove(attr string, uid uint64) { - dirKey := x.BM25DocLenDirKey(attr) - dirBlob := txn.cache.ReadBM25Blob(dirKey) - dir := bm25block.DecodeDir(dirBlob) - - if len(dir.Blocks) == 0 { - return - } - - blockIdx := dir.FindBlock(uid) - bm := dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(attr, bm.BlockID) - blob := txn.cache.ReadBM25Blob(blockKey) - entries := bm25enc.Decode(blob) - entries = bm25enc.Remove(entries, uid) - - if len(entries) == 0 { - txn.cache.WriteBM25Blob(blockKey, nil) - dir.RemoveBlockMeta(blockIdx) - } else { - txn.cache.WriteBM25Blob(blockKey, bm25enc.Encode(entries)) - dir.UpdateBlockMeta(blockIdx, entries) - } - txn.cache.WriteBM25Blob(dirKey, bm25block.EncodeDir(dir)) -} - -// updateBM25Stats reads the current corpus statistics for a BM25-indexed attribute, -// applies the given deltas, and writes back as a direct Badger KV entry. -func (txn *Txn) updateBM25Stats(attr string, docCountDelta int64, totalTermsDelta int64) error { - statsKey := x.BM25StatsKey(attr) - blob := txn.cache.ReadBM25Blob(statsKey) - docCount, totalTerms := bm25enc.DecodeStats(blob) - - // Apply deltas. - if docCountDelta >= 0 { - docCount += uint64(docCountDelta) - } else { - dec := uint64(-docCountDelta) - if dec > docCount { - docCount = 0 - } else { - docCount -= dec - } - } - if totalTermsDelta >= 0 { - totalTerms += uint64(totalTermsDelta) - } else { - dec := uint64(-totalTermsDelta) - if dec > totalTerms { - totalTerms = 0 - } else { - totalTerms -= dec + if err := txn.addBM25TermPosting(ctx, attr, term, uid, tf, docLen, info.op); err != nil { + return err } } - txn.cache.WriteBM25Blob(statsKey, bm25enc.EncodeStats(docCount, totalTerms)) - return nil + if info.op == pb.DirectedEdge_DEL { + return txn.updateBM25Stats(ctx, attr, uid, -1, -int64(docLen)) + } + return txn.updateBM25Stats(ctx, attr, uid, 1, int64(docLen)) } // countParams is sent to updateCount function. It is used to update the count index. diff --git a/posting/list.go b/posting/list.go index 5420a69a157..610eaf5b5c2 100644 --- a/posting/list.go +++ b/posting/list.go @@ -60,8 +60,6 @@ const ( BitCompletePosting byte = 0x08 // BitEmptyPosting signals that the value stores an empty posting list. BitEmptyPosting byte = 0x10 - // BitBM25Data signals that the value stores BM25 index data (direct KV, not a posting list). - BitBM25Data byte = 0x20 ) // List stores the in-memory representation of a posting list. @@ -1629,7 +1627,14 @@ func (l *List) encode(out *rollupOutput, readTs uint64, split bool) error { } enc.Add(p.Uid) - if p.Facets != nil || p.PostingType != pb.Posting_REF { + // Retain the full posting (not just its UID in the Pack) whenever it + // carries facets, is not a plain UID reference, or carries a value. + // BM25 index postings are REF postings that pack (term-frequency, + // doc-length) into Value; without the len(p.Value) > 0 clause that + // value would be stripped at rollup, silently losing all term + // frequencies. This mirrors how faceted postings already coexist in + // both Pack (UID) and Postings (payload). + if p.Facets != nil || p.PostingType != pb.Posting_REF || len(p.Value) > 0 { plist.Postings = append(plist.Postings, p) } return nil diff --git a/posting/lists.go b/posting/lists.go index 22d20a53973..a4bc4fb355b 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -76,22 +76,6 @@ type LocalCache struct { // plists are posting lists in memory. They can be discarded to reclaim space. plists map[string]*List - - // bm25Writes buffers BM25 direct KV writes (key → encoded blob). - // These bypass the posting list infrastructure entirely. - // - // CONCURRENCY NOTE: BM25 blocks use full-value overwrites rather than - // posting list deltas. Within a single Dgraph transaction this is safe - // (each Txn has its own LocalCache). Across concurrent transactions, - // Dgraph's Raft-based mutation serialization prevents lost updates for - // the same predicate+UID pair. However, two transactions updating - // different UIDs that share a common term could theoretically race on - // the same term block. In practice this is mitigated by: - // 1. Dgraph serializes mutations through Raft proposals - // 2. Block splits keep contention surface small - // If higher write concurrency is needed, blocks should be integrated - // into the posting list delta mechanism. - bm25Writes map[string][]byte } // struct to implement LocalCache interface from vector-indexer @@ -151,7 +135,6 @@ func NewLocalCache(startTs uint64) *LocalCache { deltas: make(map[string][]byte), plists: make(map[string]*List), maxVersions: make(map[string]uint64), - bm25Writes: make(map[string][]byte), } } @@ -161,57 +144,6 @@ func NoCache(startTs uint64) *LocalCache { return &LocalCache{startTs: startTs} } -// ReadBM25Blob returns the BM25 blob for the given key. -// It checks the in-memory buffer first (read-your-own-writes), -// then falls back to reading from pstore at startTs. -func (lc *LocalCache) ReadBM25Blob(key []byte) []byte { - lc.RLock() - if blob, ok := lc.bm25Writes[string(key)]; ok { - lc.RUnlock() - return blob - } - lc.RUnlock() - - // Fall back to Badger. - txn := pstore.NewTransactionAt(lc.startTs, false) - defer txn.Discard() - item, err := txn.Get(key) - if err != nil { - return nil - } - val, err := item.ValueCopy(nil) - if err != nil { - return nil - } - return val -} - -// WriteBM25Blob buffers a BM25 blob write for the given key. -func (lc *LocalCache) WriteBM25Blob(key []byte, blob []byte) { - lc.Lock() - defer lc.Unlock() - if lc.bm25Writes == nil { - lc.bm25Writes = make(map[string][]byte) - } - lc.bm25Writes[string(key)] = blob -} - -// ReadBM25BlobAt reads a BM25 blob from pstore at the given read timestamp. -// This is used by the query read path (worker/task.go). -func ReadBM25BlobAt(key []byte, readTs uint64) []byte { - txn := pstore.NewTransactionAt(readTs, false) - defer txn.Discard() - item, err := txn.Get(key) - if err != nil { - return nil - } - val, err := item.ValueCopy(nil) - if err != nil { - return nil - } - return val -} - func (lc *LocalCache) UpdateCommitTs(commitTs uint64) { lc.Lock() defer lc.Unlock() diff --git a/posting/mvcc.go b/posting/mvcc.go index 3b9510ef6bb..108cdfc3b3e 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -319,20 +319,6 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { } } - // Flush BM25 direct KV writes. These are complete blobs (not deltas) - // and don't need rollup. - for key, blob := range cache.bm25Writes { - if err := writer.update(commitTs, func(btxn *badger.Txn) error { - return btxn.SetEntry(&badger.Entry{ - Key: []byte(key), - Value: blob, - UserMeta: BitBM25Data, - }) - }); err != nil { - return err - } - } - return nil } diff --git a/query/query.go b/query/query.go index e241d18946b..80ca66b185d 100644 --- a/query/query.go +++ b/query/query.go @@ -1383,9 +1383,6 @@ func (sg *SubGraph) valueVarAggregation(doneVars map[string]varValue, path []*Su case sg.Attr == "uid" && sg.Params.DoCount: // This is the count(uid) case. // We will do the computation later while constructing the result. - case sg.Attr == "bm25_score": - // bm25_score is a pseudo-predicate handled inline during children processing. - // Its valueMatrix is already populated. Nothing to aggregate. default: return errors.Errorf("Unhandled pb.node <%v> with parent <%v>", sg.Attr, parent.Attr) } @@ -1608,6 +1605,31 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su Value: int64(len(sg.SrcUIDs.Uids)), } doneVars[sg.Params.Var].Vals.Set(math.MaxUint64, val) + case sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(sg.uidMatrix) > 0 && + len(sg.valueMatrix) > 0: + // A query-side ranker (BM25) binds its per-document relevance score as a + // value variable. We populate BOTH the matched uid set and the uid->score + // map so the variable works with uid(var), val(var) and orderdesc: val(var) + // — surfacing and ordering by score without a pseudo-predicate or a + // ParentVars channel. The valueMatrix is positionally aligned with the + // function's returned uidMatrix[0]. + if v, ok = doneVars[sg.Params.Var]; !ok { + v = varValue{Vals: types.NewShardedMap(), path: sgPath, strList: sg.valueMatrix} + } + v.Uids = sg.DestUIDs + uids := sg.uidMatrix[0].GetUids() + for idx, uid := range uids { + if idx >= len(sg.valueMatrix) || len(sg.valueMatrix[idx].Values) == 0 { + continue + } + tv := sg.valueMatrix[idx].Values[0] + if len(tv.Val) != 8 { + continue + } + score := math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + v.Vals.Set(uid, types.Val{Tid: types.FloatID, Value: score}) + } + doneVars[sg.Params.Var] = v case len(sg.DestUIDs.Uids) != 0 || (sg.Attr == "uid" && sg.SrcUIDs != nil): // 3. A uid variable. The variable could be defined in one of two places. // a) Either on the actual predicate. @@ -2293,30 +2315,6 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics - // If this is a BM25 root function, extract scores from ValueMatrix - // and store them in ParentVars for bm25_score pseudo-predicate children. - if sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(result.UidMatrix) > 0 && - len(result.ValueMatrix) > 0 { - bm25Scores := types.NewShardedMap() - uids := result.UidMatrix[0].GetUids() - for i, uid := range uids { - if i < len(result.ValueMatrix) && len(result.ValueMatrix[i].Values) > 0 { - tv := result.ValueMatrix[i].Values[0] - if len(tv.Val) == 8 { - score := math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) - bm25Scores.Set(uid, types.Val{ - Tid: types.FloatID, - Value: score, - }) - } - } - } - if sg.Params.ParentVars == nil { - sg.Params.ParentVars = make(map[string]varValue) - } - sg.Params.ParentVars["__bm25_scores__"] = varValue{Vals: bm25Scores} - } - if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2415,9 +2413,12 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } if len(sg.Params.Order) == 0 && len(sg.Params.FacetsOrder) == 0 { - // for `has` function when there is no filtering and ordering, we fetch - // correct paginated results so no need to apply pagination here. - if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && sg.SrcFunc.Name == "has") { + // For `has` and `bm25`, the worker already returns correctly paginated + // results (bm25 paginates over score order, which the uid-sorted query-layer + // pagination cannot reproduce), so applying pagination again here would + // double-apply first/offset. Skip it when there is no filtering/ordering. + if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && + (sg.SrcFunc.Name == "has" || sg.SrcFunc.Name == "bm25")) { // There is no ordering. Just apply pagination and return. if err = sg.applyPagination(ctx); err != nil { rch <- err @@ -2495,27 +2496,6 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { child.SrcUIDs = sg.DestUIDs // Make the connection. - // Handle bm25_score pseudo-predicate: populate valueMatrix from parent's - // BM25 scores. Mark IsInternal so populateUidValVar case 4 (value variable) - // fires instead of case 3 (UID variable). - if child.Attr == "bm25_score" { - if bm25Var, ok := child.Params.ParentVars["__bm25_scores__"]; ok && bm25Var.Vals != nil { - child.valueMatrix = make([]*pb.ValueList, len(child.SrcUIDs.GetUids())) - for j, uid := range child.SrcUIDs.GetUids() { - if val, okv := bm25Var.Vals.Get(uid); okv { - child.valueMatrix[j] = &pb.ValueList{ - Values: []*pb.TaskValue{valToTaskValue(val)}, - } - } else { - child.valueMatrix[j] = &pb.ValueList{} - } - } - } - child.DestUIDs = &pb.List{} - child.Params.IsInternal = true - continue - } - if child.IsInternal() { // We dont have to execute these nodes. continue diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go index 1411ad3916e..8a16c31f1a3 100644 --- a/query/query_bm25_test.go +++ b/query/query_bm25_test.go @@ -244,12 +244,10 @@ func TestBM25Pagination(t *testing.T) { } func TestBM25ScoreOrdering(t *testing.T) { - // Use the bm25_score pseudo-predicate with var block to order results by score. + // Bind the bm25 score to a value variable and order results by it via val(). query := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score), first: 1) { uid description_bm25 @@ -267,9 +265,7 @@ func TestBM25ScoreOrderingMultiTerm(t *testing.T) { // since it contains both terms. query := ` { - var(func: bm25(description_bm25, "quick lazy")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "quick lazy")) me(func: uid(score), orderdesc: val(score), first: 1) { uid description_bm25 @@ -285,9 +281,7 @@ func TestBM25ScoreOrderingAllResults(t *testing.T) { // Verify all results are returned in score-descending order via val(score). query := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score)) { uid description_bm25 @@ -307,9 +301,7 @@ func TestBM25ScoreWithPagination(t *testing.T) { // Use offset with score ordering. query := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score), first: 1, offset: 1) { uid description_bm25 @@ -402,9 +394,7 @@ func TestBM25CorpusStatsAffectIDF(t *testing.T) { // Capture baseline score for "fox" query. scoreQuery := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -511,9 +501,7 @@ func TestBM25DocumentDeletion(t *testing.T) { func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { scoreQuery := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -601,9 +589,7 @@ func TestBM25LargeCorpus(t *testing.T) { // Pagination: first:10, offset:40 for alpha should return 10 results. js = processQueryNoErr(t, ` { - var(func: bm25(description_bm25, "alpha")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "alpha")) me(func: uid(score), orderdesc: val(score), first: 10, offset: 40) { uid } @@ -651,9 +637,7 @@ func TestBM25EdgeCaseLongDocument(t *testing.T) { // Get scores for "fox" query. scoreQuery := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -753,9 +737,7 @@ func TestBM25WithUidFilter(t *testing.T) { func TestBM25ScoreValuesAreValidFloats(t *testing.T) { scoreQuery := ` { - var(func: bm25(description_bm25, "fox")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "fox")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -907,9 +889,7 @@ func TestBM25ExactScoreValues(t *testing.T) { // Query "quasar" with b=0 so score depends only on tf, k, and IDF (not avgDL). scoreQuery := ` { - var(func: bm25(description_bm25, "quasar", "1.2", "0")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "quasar", "1.2", "0")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -963,9 +943,7 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { // Query with b=0: length normalization disabled. scoreQuery := ` { - var(func: bm25(description_bm25, "vortex", "1.2", "0")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "vortex", "1.2", "0")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -989,9 +967,7 @@ func TestBM25BM15NoLengthNormalization(t *testing.T) { // Now verify that with default b=0.75, the shorter doc scores higher. scoreQueryDefault := ` { - var(func: bm25(description_bm25, "vortex")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "vortex")) me(func: uid(score), orderdesc: val(score)) { uid val(score) @@ -1022,9 +998,7 @@ func TestBM25SingleMatchingDocument(t *testing.T) { // Query with b=0 for exact verification. scoreQuery := ` { - var(func: bm25(description_bm25, "aardvark", "1.2", "0")) { - score as bm25_score - } + score as var(func: bm25(description_bm25, "aardvark", "1.2", "0")) me(func: uid(score), orderdesc: val(score)) { uid val(score) diff --git a/worker/bm25wand.go b/worker/bm25wand.go index 07988c845df..c950ffbe3e8 100644 --- a/worker/bm25wand.go +++ b/worker/bm25wand.go @@ -11,291 +11,158 @@ import ( "sort" "github.com/dgraph-io/dgraph/v25/posting" - "github.com/dgraph-io/dgraph/v25/posting/bm25block" - "github.com/dgraph-io/dgraph/v25/posting/bm25enc" - "github.com/dgraph-io/dgraph/v25/x" ) -// listIter iterates over a term's block-based posting list for WAND scoring. -type listIter struct { - attr string - encodedTerm string - readTs uint64 - idf float64 - k, b float64 - - dir *bm25block.Dir - ubPreSuf []float64 // suffix max of UBPre values - blockIdx int // current block index in dir.Blocks - block []bm25enc.Entry // decoded current block - inBlockPos int // position within current block - - exhausted bool - legacy bool // true if using legacy monolithic blob (migration fallback) +// wandBlockSize is the number of postings grouped into one logical block for +// Block-Max WAND upper bounds. The postings come from a standard Dgraph posting +// list (already resident in memory once loaded); these blocks exist only to give +// WAND per-block score bounds for pruning — they are not a storage format. +const wandBlockSize = 128 + +// termCursor is an in-memory cursor over one query term's posting list, +// materialized from the standard posting List as UID-ascending (uid, tf, docLen) +// entries. Document length travels with each posting, so scoring needs no separate +// lookup. Per-block max bounds drive Block-Max WAND pruning. +type termCursor struct { + postings []posting.BM25Posting + idf float64 + pos int + + // blockUBPre[i] is the pre-IDF BM25 upper bound for block i (max term + // frequency, min document length in the block). suffixUBPre[i] = max over + // j >= i of blockUBPre[j], for the remaining-list upper bound. + blockUBPre []float64 + suffixUBPre []float64 } -// newListIter creates a new iterator for a term's block-based posting list. -// Falls back to the legacy monolithic blob format if no block directory exists. -// If dir is non-nil, it is used directly (avoids re-reading from Badger). -func newListIter(attr, encodedTerm string, readTs uint64, idf, k, b float64, dir *bm25block.Dir) *listIter { - if dir == nil { - dirKey := x.BM25TermDirKey(attr, encodedTerm) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir = bm25block.DecodeDir(dirBlob) +// ubPre computes the pre-IDF BM25 contribution upper bound for a block, using the +// block's maximum term frequency and minimum document length (the score is +// increasing in tf and decreasing in dl, so this is a safe upper bound). +func ubPre(maxTF, minDL uint32, k, b, avgDL float64) float64 { + if avgDL <= 0 { + avgDL = 1 + } + tf := float64(maxTF) + dl := float64(minDL) + denom := k*(1-b+b*dl/avgDL) + tf + if denom <= 0 { + return 0 } + return (k + 1) * tf / denom +} - if len(dir.Blocks) == 0 { - // Fallback: try reading the legacy monolithic blob and wrap it as a single block. - legacyKey := x.BM25IndexKey(attr, encodedTerm) - legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - legacyEntries := bm25enc.Decode(legacyBlob) - if len(legacyEntries) == 0 { - return &listIter{exhausted: true} +// newTermCursor builds a cursor and precomputes its per-block upper bounds. +func newTermCursor(postings []posting.BM25Posting, idf, k, b, avgDL float64) *termCursor { + c := &termCursor{postings: postings, idf: idf} + numBlocks := (len(postings) + wandBlockSize - 1) / wandBlockSize + c.blockUBPre = make([]float64, numBlocks) + for blk := 0; blk < numBlocks; blk++ { + start := blk * wandBlockSize + end := start + wandBlockSize + if end > len(postings) { + end = len(postings) } - // Build a synthetic single-block directory from the legacy data. var maxTF uint32 - for _, e := range legacyEntries { - if e.Value > maxTF { - maxTF = e.Value + minDL := uint32(math.MaxUint32) + for i := start; i < end; i++ { + if postings[i].TF > maxTF { + maxTF = postings[i].TF + } + dl := postings[i].DocLen + if dl == 0 { + dl = 1 + } + if dl < minDL { + minDL = dl } } - dir = &bm25block.Dir{ - NextID: 1, - Blocks: []bm25block.BlockMeta{{ - FirstUID: legacyEntries[0].UID, - BlockID: 0, - Count: uint32(len(legacyEntries)), - MaxTF: maxTF, - }}, - } - it := &listIter{ - attr: attr, - encodedTerm: encodedTerm, - readTs: readTs, - idf: idf, - k: k, - b: b, - dir: dir, - ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), - blockIdx: 0, - block: legacyEntries, // pre-loaded - inBlockPos: -1, // will advance on first next() - legacy: true, - } - return it + c.blockUBPre[blk] = ubPre(maxTF, minDL, k, b, avgDL) } - - it := &listIter{ - attr: attr, - encodedTerm: encodedTerm, - readTs: readTs, - idf: idf, - k: k, - b: b, - dir: dir, - ubPreSuf: bm25block.SuffixMaxUBPre(dir, k, b), - blockIdx: -1, // will be advanced on first Next() + c.suffixUBPre = make([]float64, numBlocks) + var running float64 + for blk := numBlocks - 1; blk >= 0; blk-- { + if c.blockUBPre[blk] > running { + running = c.blockUBPre[blk] + } + c.suffixUBPre[blk] = running } - return it + return c } -// currentDoc returns the UID at the current position. -func (it *listIter) currentDoc() uint64 { - if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { +func (c *termCursor) exhausted() bool { return c.pos >= len(c.postings) } + +func (c *termCursor) currentDoc() uint64 { + if c.exhausted() { return math.MaxUint64 } - return it.block[it.inBlockPos].UID + return c.postings[c.pos].Uid } -// currentTF returns the term frequency at the current position. -func (it *listIter) currentTF() uint32 { - if it.exhausted || it.block == nil || it.inBlockPos < 0 || it.inBlockPos >= len(it.block) { +func (c *termCursor) currentTF() uint32 { + if c.exhausted() { return 0 } - return it.block[it.inBlockPos].Value + return c.postings[c.pos].TF } -// remainingUB returns the IDF-weighted upper-bound score for the remaining postings. -func (it *listIter) remainingUB() float64 { - if it.exhausted || len(it.ubPreSuf) == 0 { - return 0 - } - idx := it.blockIdx - if idx < 0 { - idx = 0 - } - if idx >= len(it.ubPreSuf) { +func (c *termCursor) currentDocLen() uint32 { + if c.exhausted() { return 0 } - return it.idf * it.ubPreSuf[idx] + return c.postings[c.pos].DocLen } -// blockUB returns the IDF-weighted upper-bound for the current block only. -func (it *listIter) blockUB() float64 { - if it.exhausted || it.blockIdx < 0 || it.blockIdx >= len(it.dir.Blocks) { +// remainingUB returns the IDF-weighted upper-bound score over the remainder of the +// list from the current position. +func (c *termCursor) remainingUB() float64 { + if c.exhausted() || len(c.suffixUBPre) == 0 { return 0 } - return it.idf * bm25block.ComputeUBPre(it.dir.Blocks[it.blockIdx].MaxTF, it.k, it.b) -} - -// next advances to the next posting. Returns false if exhausted. -func (it *listIter) next() bool { - if it.exhausted { - return false - } - - // Try advancing within the current block. - if it.block != nil { - it.inBlockPos++ - if it.inBlockPos >= 0 && it.inBlockPos < len(it.block) { - return true - } + blk := c.pos / wandBlockSize + if blk >= len(c.suffixUBPre) { + return 0 } + return c.idf * c.suffixUBPre[blk] +} - // Move to the next block. - for { - it.blockIdx++ - if it.blockIdx >= len(it.dir.Blocks) { - it.exhausted = true - return false - } - it.loadBlock(it.blockIdx) - if len(it.block) > 0 { - return true - } - // Empty block (corruption/race): skip it. - } +// next advances by one posting. +func (c *termCursor) next() bool { + c.pos++ + return !c.exhausted() } // skipTo advances to the first posting with UID >= target. -// Returns false if exhausted. -func (it *listIter) skipTo(target uint64) bool { - if it.exhausted { +func (c *termCursor) skipTo(target uint64) bool { + if c.exhausted() { return false } - - // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && - it.block[it.inBlockPos].UID >= target { + if c.postings[c.pos].Uid >= target { return true } - - // Check if target might be in the current block. - if it.block != nil && len(it.block) > 0 && it.blockIdx >= 0 && - it.blockIdx < len(it.dir.Blocks) { - lastInBlock := it.block[len(it.block)-1].UID - if target <= lastInBlock { - startPos := it.inBlockPos - if startPos < 0 { - startPos = 0 - } else if startPos > len(it.block) { - startPos = len(it.block) - } - // Binary search within current block from startPos. - pos := sort.Search(len(it.block)-startPos, func(i int) bool { - return it.block[startPos+i].UID >= target - }) - it.inBlockPos = startPos + pos - if it.inBlockPos < len(it.block) { - return true - } - } - } - - // Find the right block using the directory. - blockIdx := it.findBlockForTarget(target) - if blockIdx >= len(it.dir.Blocks) { - it.exhausted = true - return false - } - - it.blockIdx = blockIdx - it.loadBlock(blockIdx) - if len(it.block) == 0 { - return it.next() // skip empty block - } - - // Binary search within the block. - pos := sort.Search(len(it.block), func(i int) bool { - return it.block[i].UID >= target + rel := sort.Search(len(c.postings)-c.pos, func(i int) bool { + return c.postings[c.pos+i].Uid >= target }) - it.inBlockPos = pos - if pos >= len(it.block) { - // Target is beyond this block; try the next. - return it.next() - } - return true + c.pos += rel + return !c.exhausted() } -// skipToWithBMW is like skipTo but uses Block-Max WAND to skip entire blocks -// whose upper bounds can't beat the given threshold. -func (it *listIter) skipToWithBMW(target uint64, theta float64, otherUB float64) bool { - if it.exhausted { +// skipToWithBMW is skipTo with Block-Max WAND pruning: blocks whose upper bound +// combined with otherUB cannot beat theta are skipped wholesale. +func (c *termCursor) skipToWithBMW(target uint64, theta, otherUB float64) bool { + if !c.skipTo(target) { return false } - - // If current doc is already >= target, no-op. - if it.block != nil && it.inBlockPos >= 0 && it.inBlockPos < len(it.block) && - it.block[it.inBlockPos].UID >= target { - return true - } - - blockIdx := it.findBlockForTarget(target) - for blockIdx < len(it.dir.Blocks) { - // Check if this block's UB combined with other terms can beat theta. - blockUB := it.idf * bm25block.ComputeUBPre(it.dir.Blocks[blockIdx].MaxTF, it.k, it.b) - if blockUB+otherUB > theta { - // This block might have a winner; load and search it. - it.blockIdx = blockIdx - it.loadBlock(blockIdx) - if len(it.block) == 0 { - blockIdx++ - continue // skip empty block - } - pos := sort.Search(len(it.block), func(i int) bool { - return it.block[i].UID >= target - }) - it.inBlockPos = pos - if pos < len(it.block) { - return true - } - // Fall through to next block. - } - blockIdx++ - // Update target to the next block's firstUID. - if blockIdx < len(it.dir.Blocks) { - target = it.dir.Blocks[blockIdx].FirstUID + for !c.exhausted() { + blk := c.pos / wandBlockSize + if c.idf*c.blockUBPre[blk]+otherUB > theta { + return true } + // This block can't produce a winner; jump to the start of the next block. + c.pos = (blk + 1) * wandBlockSize } - it.exhausted = true return false } -// findBlockForTarget returns the block index that should contain target. -func (it *listIter) findBlockForTarget(target uint64) int { - blocks := it.dir.Blocks - idx := sort.Search(len(blocks), func(i int) bool { - return blocks[i].FirstUID > target - }) - if idx > 0 { - return idx - 1 - } - return 0 -} - -// loadBlock decodes the block at the given directory index. -func (it *listIter) loadBlock(idx int) { - if it.legacy { - // Legacy mode: single pre-loaded block; don't reset position. - return - } - bm := it.dir.Blocks[idx] - blockKey := x.BM25TermBlockKey(it.attr, it.encodedTerm, bm.BlockID) - blob := posting.ReadBM25BlobAt(blockKey, it.readTs) - it.block = bm25enc.Decode(blob) - it.inBlockPos = 0 -} - // scoredDoc holds a UID and its BM25 score for the min-heap. type scoredDoc struct { uid uint64 @@ -328,19 +195,16 @@ func (h *topKHeap) threshold() float64 { return h.docs[0].score } -// tryPush adds a doc if it beats the current threshold. Returns true if the -// threshold changed. -func (h *topKHeap) tryPush(uid uint64, score float64) bool { +// tryPush adds a doc if it beats the current threshold. +func (h *topKHeap) tryPush(uid uint64, score float64) { if len(h.docs) < h.k { heap.Push(h, scoredDoc{uid: uid, score: score}) - return len(h.docs) == h.k // threshold only meaningful once heap is full + return } if score > h.docs[0].score { h.docs[0] = scoredDoc{uid: uid, score: score} heap.Fix(h, 0) - return true } - return false } // sorted returns all docs sorted by score descending, then UID ascending. @@ -356,219 +220,149 @@ func (h *topKHeap) sorted() []scoredDoc { return result } -// bm25Score computes the BM25 score for a single term occurrence. +// bm25Score computes the BM25 contribution of a single term occurrence. func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { - return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) -} - -// docLenCache caches document length lookups within a single query to avoid -// repeated Badger reads for the same doclen block directory and blocks. -type docLenCache struct { - attr string - readTs uint64 - dir *bm25block.Dir - loaded bool - legacy bool - // Per-block cache: blockIdx -> decoded entries. - blocks map[int][]bm25enc.Entry - // Legacy entries (when using monolithic blob). - legacyEntries []bm25enc.Entry -} - -func newDocLenCache(attr string, readTs uint64) *docLenCache { - return &docLenCache{ - attr: attr, - readTs: readTs, - blocks: make(map[int][]bm25enc.Entry), - } -} - -func (c *docLenCache) ensureLoaded() { - if c.loaded { - return + if avgDL <= 0 { + avgDL = 1 } - c.loaded = true - dirKey := x.BM25DocLenDirKey(c.attr) - dirBlob := posting.ReadBM25BlobAt(dirKey, c.readTs) - c.dir = bm25block.DecodeDir(dirBlob) - if len(c.dir.Blocks) == 0 { - // Try legacy. - legacyKey := x.BM25DocLenKey(c.attr) - legacyBlob := posting.ReadBM25BlobAt(legacyKey, c.readTs) - c.legacyEntries = bm25enc.Decode(legacyBlob) - c.legacy = true + if dl <= 0 { + dl = 1 } + return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) } -func (c *docLenCache) lookup(uid uint64) float64 { - c.ensureLoaded() - if c.legacy { - if v, ok := bm25enc.Search(c.legacyEntries, uid); ok { - return float64(v) - } - return 1.0 - } - if len(c.dir.Blocks) == 0 { - return 1.0 - } - blockIdx := c.dir.FindBlock(uid) - entries, ok := c.blocks[blockIdx] - if !ok { - bm := c.dir.Blocks[blockIdx] - blockKey := x.BM25DocLenBlockKey(c.attr, bm.BlockID) - blob := posting.ReadBM25BlobAt(blockKey, c.readTs) - entries = bm25enc.Decode(blob) - c.blocks[blockIdx] = entries - } - if v, ok := bm25enc.Search(entries, uid); ok { - return float64(v) - } - return 1.0 -} - -// wandSearch performs a WAND top-k search over block-based posting lists. -// If topK <= 0, it scores all matching documents (no early termination). -func wandSearch(attr string, readTs uint64, queryTokens []string, - k, b, avgDL, N float64, topK int, filterSet map[uint64]struct{}, - useBMW bool) []scoredDoc { - - dlCache := newDocLenCache(attr, readTs) +// wandSearch performs a WAND / Block-Max WAND top-k BM25 search over standard +// posting lists. queryTokens must already carry the BM25 tokenizer identifier +// byte. getList reads a posting list for a key. If topK <= 0, every matching +// document is scored (no early termination). +func wandSearch(getList func(key []byte) (*posting.List, error), attr string, readTs uint64, + queryTokens []string, k, b, avgDL, N float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) ([]scoredDoc, error) { - // Build iterators for each query term. - var iters []*listIter + var cursors []*termCursor for _, token := range queryTokens { - // Compute df: try block directory first, then fall back to legacy blob. - var df uint64 - dirKey := x.BM25TermDirKey(attr, token) - dirBlob := posting.ReadBM25BlobAt(dirKey, readTs) - dir := bm25block.DecodeDir(dirBlob) - if len(dir.Blocks) > 0 { - for _, bm := range dir.Blocks { - df += uint64(bm.Count) - } - } else { - // Legacy fallback: read just the count header to get df. - // Avoids decoding the full posting list (which could be huge for common terms). - legacyKey := x.BM25IndexKey(attr, token) - legacyBlob := posting.ReadBM25BlobAt(legacyKey, readTs) - df = uint64(bm25enc.DecodeCount(legacyBlob)) + postings, err := posting.ReadBM25TermPostings(getList, attr, token, readTs) + if err != nil { + return nil, err } + df := uint64(len(postings)) if df == 0 { continue } - idf := math.Log1p((N - float64(df) + 0.5) / (float64(df) + 0.5)) - - it := newListIter(attr, token, readTs, idf, k, b, dir) - if !it.exhausted { - it.next() // prime the iterator - if !it.exhausted { - iters = append(iters, it) - } + // N comes from bucketed stats and df from the term's posting list; if stats + // ever lag the postings, clamp N >= df for this term so the smoothed IDF + // stays non-negative and finite instead of producing a negative/NaN score. + dfN := float64(df) + nDocs := N + if nDocs < dfN { + nDocs = dfN } + idf := math.Log1p((nDocs - dfN + 0.5) / (dfN + 0.5)) + cursors = append(cursors, newTermCursor(postings, idf, k, b, avgDL)) } - if len(iters) == 0 { - return nil + if len(cursors) == 0 { + return nil, nil } - // If no top-k limit, score all matching documents. if topK <= 0 { - return scoreAllDocs(iters, dlCache, k, b, avgDL, filterSet) + return scoreAllDocs(cursors, k, b, avgDL, filterSet), nil } + return wandTopK(cursors, k, b, avgDL, topK, filterSet, useBMW), nil +} + +// wandTopK runs the WAND / Block-Max WAND main loop over prepared cursors and +// returns the top-k documents sorted by score descending. It is the core scoring +// loop, separated from posting-list I/O so it can be exercised directly. +func wandTopK(cursors []*termCursor, k, b, avgDL float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) []scoredDoc { - // WAND algorithm with top-k heap. h := &topKHeap{k: topK} heap.Init(h) for { - // Remove exhausted iterators. - active := iters[:0] - for _, it := range iters { - if !it.exhausted { - active = append(active, it) + // Drop exhausted cursors. + active := cursors[:0] + for _, c := range cursors { + if !c.exhausted() { + active = append(active, c) } } - iters = active - if len(iters) == 0 { + cursors = active + if len(cursors) == 0 { break } - // Sort iterators by currentDoc ascending. - sort.Slice(iters, func(i, j int) bool { - return iters[i].currentDoc() < iters[j].currentDoc() + // Sort cursors by current document ascending. + sort.Slice(cursors, func(i, j int) bool { + return cursors[i].currentDoc() < cursors[j].currentDoc() }) theta := h.threshold() - // Find pivot: accumulate UBs until they exceed theta. + // Find pivot: accumulate upper bounds until they exceed theta. var sumUB float64 pivot := -1 var pivotDoc uint64 - for i, it := range iters { - sumUB += it.remainingUB() + for i, c := range cursors { + sumUB += c.remainingUB() if sumUB > theta && pivot == -1 { pivot = i - pivotDoc = it.currentDoc() + pivotDoc = c.currentDoc() } } - // sumUB now contains the total UB across ALL iterators (needed for BMW). if pivot == -1 { - break // sum of all UBs can't beat theta + break // sum of all upper bounds can't beat theta } - // Advance all iterators before pivot to pivotDoc. + // Advance all cursors before the pivot up to pivotDoc. allAtPivot := true for i := 0; i < pivot; i++ { - if iters[i].currentDoc() < pivotDoc { + if cursors[i].currentDoc() < pivotDoc { var ok bool if useBMW { - // Compute otherUB = total UB - this iter's UB (O(1) instead of O(q)). - otherUB := sumUB - iters[i].remainingUB() - ok = iters[i].skipToWithBMW(pivotDoc, theta, otherUB) + otherUB := sumUB - cursors[i].remainingUB() + ok = cursors[i].skipToWithBMW(pivotDoc, theta, otherUB) } else { - ok = iters[i].skipTo(pivotDoc) + ok = cursors[i].skipTo(pivotDoc) } if !ok { allAtPivot = false break } - if iters[i].currentDoc() != pivotDoc { + if cursors[i].currentDoc() != pivotDoc { allAtPivot = false } } } - if !allAtPivot { - continue // re-evaluate after advances + continue } - // All iterators up to pivot are at pivotDoc. Score the candidate. + // Score the pivot document. if filterSet != nil { if _, ok := filterSet[pivotDoc]; !ok { - // Skip this doc (filtered out). Advance all iters at pivotDoc. - for _, it := range iters { - if it.currentDoc() == pivotDoc { - it.next() + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() } } continue } } - dl := dlCache.lookup(pivotDoc) var score float64 - for _, it := range iters { - if it.currentDoc() == pivotDoc { - tf := float64(it.currentTF()) - score += bm25Score(it.idf, tf, dl, avgDL, k, b) + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + dl := float64(c.currentDocLen()) + score += bm25Score(c.idf, float64(c.currentTF()), dl, avgDL, k, b) } } h.tryPush(pivotDoc, score) - // Advance all iterators at pivotDoc. - for _, it := range iters { - if it.currentDoc() == pivotDoc { - it.next() + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() } } } @@ -576,43 +370,32 @@ func wandSearch(attr string, readTs uint64, queryTokens []string, return h.sorted() } -// scoreAllDocs scores every matching document without early termination. -// Used when no top-k limit is specified (the original behavior). -func scoreAllDocs(iters []*listIter, dlCache *docLenCache, - k, b, avgDL float64, filterSet map[uint64]struct{}) []scoredDoc { +// scoreAllDocs scores every matching document without early termination. Used when +// no top-k limit is specified. +func scoreAllDocs(cursors []*termCursor, k, b, avgDL float64, + filterSet map[uint64]struct{}) []scoredDoc { - // Collect all (uid, term) matches. - type termMatch struct { - idf float64 - tf uint32 - } - matches := make(map[uint64][]termMatch) - - for _, it := range iters { - for !it.exhausted { - uid := it.currentDoc() - tf := it.currentTF() - if filterSet == nil { - matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) - } else if _, ok := filterSet[uid]; ok { - matches[uid] = append(matches[uid], termMatch{idf: it.idf, tf: tf}) + scores := make(map[uint64]float64) + + for _, c := range cursors { + for !c.exhausted() { + uid := c.currentDoc() + if filterSet != nil { + if _, ok := filterSet[uid]; !ok { + c.next() + continue + } } - it.next() + scores[uid] += bm25Score(c.idf, float64(c.currentTF()), float64(c.currentDocLen()), + avgDL, k, b) + c.next() } } - // Score all matching documents. - results := make([]scoredDoc, 0, len(matches)) - for uid, terms := range matches { - dl := dlCache.lookup(uid) - var score float64 - for _, tm := range terms { - score += bm25Score(tm.idf, float64(tm.tf), dl, avgDL, k, b) - } - results = append(results, scoredDoc{uid: uid, score: score}) + results := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + results = append(results, scoredDoc{uid: uid, score: s}) } - - // Sort by score descending, then UID ascending. sort.Slice(results, func(i, j int) bool { if results[i].score != results[j].score { return results[i].score > results[j].score diff --git a/worker/bm25wand_test.go b/worker/bm25wand_test.go index 5982f94d0b8..8f133a6ec30 100644 --- a/worker/bm25wand_test.go +++ b/worker/bm25wand_test.go @@ -8,9 +8,13 @@ package worker import ( "container/heap" "math" + "math/rand" + "sort" "testing" "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting" ) func TestTopKHeapBasic(t *testing.T) { @@ -94,3 +98,103 @@ func TestBm25ScoreNaN(t *testing.T) { require.False(t, math.IsInf(score, 0)) require.Greater(t, score, 0.0) } + +// brute force scores every doc across all cursors (ground truth for WAND). +func bruteForceTopK(termPostings [][]posting.BM25Posting, idfs []float64, + k, b, avgDL float64, topK int) []scoredDoc { + scores := map[uint64]float64{} + dls := map[uint64]uint32{} + for ti, ps := range termPostings { + for _, p := range ps { + scores[p.Uid] += bm25Score(idfs[ti], float64(p.TF), float64(p.DocLen), avgDL, k, b) + dls[p.Uid] = p.DocLen + } + } + out := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + out = append(out, scoredDoc{uid: uid, score: s}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].score != out[j].score { + return out[i].score > out[j].score + } + return out[i].uid < out[j].uid + }) + if topK > 0 && len(out) > topK { + out = out[:topK] + } + return out +} + +// TestWandMatchesBruteForce checks that WAND and Block-Max WAND return exactly the +// same top-k documents and scores as exhaustive scoring, across many randomized +// posting lists. This is the core correctness guarantee: pruning must never change +// the result, only the work done. +func TestWandMatchesBruteForce(t *testing.T) { + rng := rand.New(rand.NewSource(42)) + k, b, avgDL := 1.2, 0.75, 12.0 + + for trial := 0; trial < 200; trial++ { + numTerms := 1 + rng.Intn(4) + termPostings := make([][]posting.BM25Posting, numTerms) + idfs := make([]float64, numTerms) + for ti := 0; ti < numTerms; ti++ { + n := rng.Intn(400) // spans multiple wandBlockSize blocks + seen := map[uint64]bool{} + var ps []posting.BM25Posting + for j := 0; j < n; j++ { + uid := uint64(1 + rng.Intn(500)) + if seen[uid] { + continue + } + seen[uid] = true + ps = append(ps, posting.BM25Posting{ + Uid: uid, + TF: uint32(1 + rng.Intn(10)), + DocLen: uint32(1 + rng.Intn(30)), + }) + } + sort.Slice(ps, func(i, j int) bool { return ps[i].Uid < ps[j].Uid }) + termPostings[ti] = ps + // Vary IDF per term so different terms carry different weight. + idfs[ti] = 0.5 + rng.Float64()*2 + } + + topK := 1 + rng.Intn(10) + want := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK) + // One extra result lets us detect a tie between the cutoff rank and the + // first excluded document (a boundary tie outside the top-k window). + wantPlus := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK+1) + + build := func() []*termCursor { + cs := make([]*termCursor, 0, numTerms) + for ti, ps := range termPostings { + if len(ps) == 0 { + continue + } + cs = append(cs, newTermCursor(ps, idfs[ti], k, b, avgDL)) + } + return cs + } + + for _, useBMW := range []bool{false, true} { + got := wandTopK(build(), k, b, avgDL, topK, nil, useBMW) + require.Lenf(t, got, len(want), "trial %d bmw=%v len", trial, useBMW) + for i := range want { + // The score at each rank must match exactly: WAND/BMW pruning must + // never change which scores make the top-k, only the work done. + require.InEpsilonf(t, want[i].score, got[i].score, 1e-9, + "trial %d bmw=%v rank %d score", trial, useBMW, i) + // The uid is only guaranteed when this rank's score is not tied with + // a neighbor (including the first excluded doc); tied-boundary docs + // are interchangeable in the ranking. + tied := (i > 0 && wantPlus[i].score == wantPlus[i-1].score) || + (i+1 < len(wantPlus) && wantPlus[i].score == wantPlus[i+1].score) + if !tied { + require.Equalf(t, want[i].uid, got[i].uid, + "trial %d bmw=%v rank %d uid", trial, useBMW, i) + } + } + } + } +} diff --git a/worker/task.go b/worker/task.go index 0345e9e75f8..bcb41b903d2 100644 --- a/worker/task.go +++ b/worker/task.go @@ -30,8 +30,6 @@ import ( "github.com/dgraph-io/dgraph/v25/algo" "github.com/dgraph-io/dgraph/v25/conn" "github.com/dgraph-io/dgraph/v25/posting" - "github.com/dgraph-io/dgraph/v25/posting/bm25enc" - // bm25block and bm25wand are used via bm25wand.go in this package. "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" ctask "github.com/dgraph-io/dgraph/v25/task" @@ -1263,7 +1261,8 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return errors.Errorf("bm25: b must be between 0 and 1, got %v", b) } - // 2. Tokenize query (deduplicated) using fulltext pipeline. + // 2. Tokenize query (deduplicated) using the fulltext pipeline. The returned + // tokens already carry the BM25 tokenizer identifier byte. lang := langForFunc(q.Langs) queryTokens, err := tok.GetBM25QueryTokens([]string{queryText}, lang) if err != nil { @@ -1274,10 +1273,11 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error return nil } - // 3. Read corpus stats. - statsKey := x.BM25StatsKey(attr) - statsBlob := posting.ReadBM25BlobAt(statsKey, q.ReadTs) - docCount, totalTerms := bm25enc.DecodeStats(statsBlob) + // 3. Read bucketed corpus stats and derive N and the average document length. + docCount, totalTerms, err := posting.ReadBM25Stats(qs.cache.Get, attr, q.ReadTs) + if err != nil { + return err + } if docCount == 0 || totalTerms == 0 { args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) return nil @@ -1285,7 +1285,7 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error avgDL := float64(totalTerms) / float64(docCount) N := float64(docCount) - // Build filter set if used as a filter. + // Build a filter set if bm25 is used as a filter (@filter(bm25(...))). var filterSet map[uint64]struct{} if q.UidList != nil && len(q.UidList.Uids) > 0 { filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) @@ -1294,17 +1294,21 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 4. Determine top-k: use WAND when first is set and no offset. - // When offset is set or first is unset, score all documents. + // 4. Use WAND top-k early termination only when first is set without an offset; + // otherwise score all matching documents and paginate afterwards. topK := 0 if q.First > 0 && q.Offset == 0 { topK = int(q.First) } - // 5. Run WAND search over block-based posting lists (with Block-Max skipping). - results := wandSearch(attr, q.ReadTs, queryTokens, k, b, avgDL, N, topK, filterSet, true) + // 5. Run WAND / Block-Max WAND over the standard posting lists. + results, err := wandSearch(qs.cache.Get, attr, q.ReadTs, queryTokens, k, b, avgDL, N, + topK, filterSet, true) + if err != nil { + return err + } - // 6. Apply first/offset pagination on score-sorted results. + // 6. Paginate score-sorted results when WAND did not already top-k them. if topK <= 0 && (q.First > 0 || q.Offset > 0) { offset := int(q.Offset) if offset > len(results) { @@ -1316,10 +1320,9 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } } - // 7. Build output: UIDs sorted ascending (required by query pipeline) - // and ValueMatrix with aligned scores (for bm25_score pseudo-predicate). - // We use a single pre-allocated buffer for all score encodings to reduce - // per-result heap allocations. + // 7. Emit UIDs ascending (required by the query pipeline) with positionally- + // aligned scores in ValueMatrix; the query layer binds these to a value + // variable so callers order by and project the score via val(). sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) uids := make([]uint64, len(results)) for i, r := range results { @@ -1327,18 +1330,15 @@ func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error } args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) - // Encode scores into ValueMatrix. Each entry in ValueMatrix corresponds - // positionally to a UID in UidMatrix[0], enabling the bm25_score - // pseudo-predicate in query.go to map UIDs to scores. scoreBuf := make([]byte, len(results)*8) scoreValues := make([]*pb.ValueList, len(results)) for i, r := range results { off := i * 8 binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score)) - // Use three-index slice to cap capacity at 8, preventing any downstream - // append from corrupting adjacent scores in the shared backing array. + // Three-index slice caps capacity at 8 so a downstream append can't corrupt + // adjacent scores in the shared backing array. scoreValues[i] = &pb.ValueList{ - Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_ValType(pb.Posting_FLOAT)}}, + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_FLOAT}}, } } args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) diff --git a/x/keys.go b/x/keys.go index 0a23ba19c6a..a88dc4640a1 100644 --- a/x/keys.go +++ b/x/keys.go @@ -291,47 +291,28 @@ func CountKey(attr string, count uint32, reverse bool) []byte { return buf } -// BM25Prefix is the prefix used for BM25 index keys to prevent collision -// with regular fulltext index tokens. -const BM25Prefix = "\x00_bm25_" - -// BM25IndexKey generates an index key for a BM25 term posting list. -func BM25IndexKey(attr string, token string) []byte { - return IndexKey(attr, BM25Prefix+token) -} - -// BM25DocLenKey generates the key for the BM25 document length posting list. -func BM25DocLenKey(attr string) []byte { - return IndexKey(attr, BM25Prefix+"__doclen__") -} - -// BM25StatsKey generates the key for BM25 corpus statistics. -func BM25StatsKey(attr string) []byte { - return IndexKey(attr, BM25Prefix+"__stats__") -} - -// BM25TermDirKey generates the key for a BM25 term's block directory. -func BM25TermDirKey(attr, term string) []byte { - return IndexKey(attr, BM25Prefix+"__dir__"+term) -} - -// BM25TermBlockKey generates the key for an individual BM25 term posting block. -func BM25TermBlockKey(attr, term string, blockID uint32) []byte { - var buf [4]byte - binary.BigEndian.PutUint32(buf[:], blockID) - return IndexKey(attr, BM25Prefix+"__blk__"+term+string(buf[:])) -} - -// BM25DocLenDirKey generates the key for the BM25 document-length block directory. -func BM25DocLenDirKey(attr string) []byte { - return IndexKey(attr, BM25Prefix+"__dldir__") -} - -// BM25DocLenBlockKey generates the key for an individual BM25 document-length block. -func BM25DocLenBlockKey(attr string, segID uint32) []byte { - var buf [4]byte - binary.BigEndian.PutUint32(buf[:], segID) - return IndexKey(attr, BM25Prefix+"__dlblk__"+string(buf[:])) +// BM25IndexKey generates the index key for a BM25 term posting list. The +// encodedToken already carries the BM25 tokenizer identifier byte, so BM25 term +// postings live at the same standard index key as every other tokenizer — +// IndexKey(attr, identifier || term) — and inherit rollup, splits, backup, and +// index-rebuild handling for free. This is a thin alias of IndexKey so the index +// write path and the query read path share one definition. +func BM25IndexKey(attr string, encodedToken string) []byte { + return IndexKey(attr, encodedToken) +} + +// bm25StatsPrefix namespaces the BM25 corpus-statistics keys. These hold the +// document count and total term count (used to derive the average document +// length); they are auxiliary metadata, not term postings, so they use a reserved +// token that cannot collide with any stemmed BM25 term. +const bm25StatsPrefix = "\x00_bm25stats_" + +// BM25StatsKey generates the key for one bucket of BM25 corpus statistics. Stats +// are sharded across buckets (keyed by uid%numBuckets) to spread write contention. +func BM25StatsKey(attr string, bucket int) []byte { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], uint16(bucket)) + return IndexKey(attr, bm25StatsPrefix+string(buf[:])) } // ParsedKey represents a key that has been parsed into its multiple attributes. From 1ec5cc1496099b117c188542e8aeb5cd0c09a02a Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 21:26:25 -0400 Subject: [PATCH 14/19] fix(bm25): accumulate corpus stats across transactions updateBM25Stats maintains the bucketed (docCount, totalTerms) counters via read-modify-write, but read them with LocalCache.GetFromDelta, which skips disk and returns only the current transaction's in-memory delta. Across separately committed mutations each transaction therefore started from zero and overwrote its stats bucket instead of accumulating, collapsing the corpus document count (e.g. to the per-term df). Since avgDL = totalTerms/docCount and idf depends on N = docCount, every length- and idf-weighted BM25 score was wrong, while result ordering (a constant idf factor for a single-term query) still looked correct. Read the committed stats with LocalCache.Get instead. Term postings are unaffected: they are additive deltas that merge through the normal posting-list mechanism and never read-modify-write. Found by the BM25 integration tests (exact-score and uid-filter cases). The prior unit test only exercised a single transaction, where read-your-own-writes masked the bug; add TestBM25StatsAccumulateAcrossTxns covering the multi- transaction path. Co-Authored-By: Claude Opus 4.8 (1M context) --- posting/bm25.go | 7 ++++++- posting/bm25_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/posting/bm25.go b/posting/bm25.go index a72fc7b72cd..9dcedc91aff 100644 --- a/posting/bm25.go +++ b/posting/bm25.go @@ -151,7 +151,12 @@ func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, uid uint64, docCountDelta, totalTermsDelta int64) error { bucket := int(uid % numBM25StatsBuckets) key := x.BM25StatsKey(attr, bucket) - plist, err := txn.cache.GetFromDelta(key) + // Stats are maintained by read-modify-write: we must read the committed total + // from disk (and merge this transaction's own writes), not just the in-memory + // delta. GetFromDelta skips disk and is only safe for write-only index mutations, + // so each transaction would otherwise overwrite the bucket instead of + // accumulating across transactions. Get reads committed state. + plist, err := txn.cache.Get(key) if err != nil { return err } diff --git a/posting/bm25_test.go b/posting/bm25_test.go index 82a4f712d60..169c3be437e 100644 --- a/posting/bm25_test.go +++ b/posting/bm25_test.go @@ -138,3 +138,37 @@ func TestBM25StatsBucketed(t *testing.T) { require.Equal(t, uint64(wantCount-1), dc) require.Equal(t, uint64(wantTerms-20), tt) } + +// TestBM25StatsAccumulateAcrossTxns verifies that stats accumulate across +// separately-committed transactions (not just within one). This guards against +// the read-modify-write reading only the in-memory delta instead of committed +// disk state, which would make each transaction overwrite its bucket and collapse +// the corpus document count. +func TestBM25StatsAccumulateAcrossTxns(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25statsxtxn") + + // Two documents in the SAME bucket (uid 5 and uid 37 → bucket 5), committed in + // two separate transactions. + commitDoc := func(startTs, commitTs, uid uint64, docLen int64) { + txn := Oracle().RegisterStartTs(startTs) + txn.cache = NewLocalCache(startTs) + require.NoError(t, txn.updateBM25Stats(ctx, attr, uid, 1, docLen)) + txn.Update() + txn.UpdateCachedKeys(commitTs) + writer := NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + } + + commitDoc(201, 202, 5, 10) + commitDoc(203, 204, 37, 6) + + // A fresh reader at a later ts must see BOTH documents (count 2, terms 16), + // not just the most recently committed one. + get := func(k []byte) (*List, error) { return GetNoStore(k, 205) } + dc, tt, err := ReadBM25Stats(get, attr, 205) + require.NoError(t, err) + require.Equal(t, uint64(2), dc, "doc count must accumulate across transactions") + require.Equal(t, uint64(16), tt, "total terms must accumulate across transactions") +} From 235c5eb38b4ed763521d819f16cf9e7467d3f86e Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 21:39:53 -0400 Subject: [PATCH 15/19] fix(bm25): reject @index(bm25) on list predicates BM25 scores a single document (one value) per UID, so per-document length and corpus statistics are ill-defined for a list predicate. The bucketed stats also rely on conflict detection that a list predicate's value-dependent conflict key would not provide (a code-review concern about stats integrity on list/ @noconflict predicates). Reject the combination in checkSchema rather than silently mis-scoring. Co-Authored-By: Claude Opus 4.8 (1M context) --- worker/mutation.go | 13 +++++++++++++ worker/mutation_integration_test.go | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/worker/mutation.go b/worker/mutation.go index fdac2a41c1b..076beed185f 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -410,6 +410,19 @@ func checkSchema(s *pb.SchemaUpdate) error { x.ParseAttr(s.Predicate)) } + // BM25 scores a single document (one value) per UID: per-document length and + // corpus statistics are not well-defined for a list predicate, and the bucketed + // stats maintenance relies on conflict detection that a list predicate's + // value-dependent conflict key would not provide. Reject @index(bm25) on lists. + if s.List { + for _, tokenizer := range s.Tokenizer { + if tokenizer == "bm25" { + return errors.Errorf("Tokenizer 'bm25' cannot be applied to list predicate: %s", + x.ParseAttr(s.Predicate)) + } + } + } + // If schema update has upsert directive, it should have index directive. if s.Upsert && len(s.Tokenizer) == 0 && !s.Unique { return errors.Errorf("Index tokenizer is mandatory for: [%s] when specifying @upsert directive", diff --git a/worker/mutation_integration_test.go b/worker/mutation_integration_test.go index 99a2a1eed01..f1f4b81695b 100644 --- a/worker/mutation_integration_test.go +++ b/worker/mutation_integration_test.go @@ -93,6 +93,25 @@ func TestCheckSchema(t *testing.T) { } require.NoError(t, checkSchema(s1)) + // bm25 on a scalar string predicate is allowed. + s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("bio"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + } + require.NoError(t, checkSchema(s1)) + + // bm25 on a list predicate is rejected. + s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("tags"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + List: true, + } + require.Error(t, checkSchema(s1)) + s1 = &pb.SchemaUpdate{ Predicate: x.AttrInRootNamespace("friend"), ValueType: pb.Posting_UID, From 7b063f0c672f13b0421007de17f691cfd5dbffb6 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 21:47:50 -0400 Subject: [PATCH 16/19] chore(bm25): drop working design notes The redesign rationale (including the doc-length storage decision) lives in code comments and the PR description; the planning doc does not belong in the tree. Co-Authored-By: Claude Opus 4.8 (1M context) --- bm25-redesign-plan.md | 74 ------------------------------------------- 1 file changed, 74 deletions(-) delete mode 100644 bm25-redesign-plan.md diff --git a/bm25-redesign-plan.md b/bm25-redesign-plan.md deleted file mode 100644 index 4ba98f9573a..00000000000 --- a/bm25-redesign-plan.md +++ /dev/null @@ -1,74 +0,0 @@ -# BM25 Redesign — Implementation Spec - -Reworks the BM25 feature per the maintainer's review (decline of the block-storage -PR). Endorsed independently by GPT-5 and Gemini. Goal: BM25 rides Dgraph's standard -posting-list machinery (MVCC, deltas, rollup, splits, backup, snapshot) instead of a -parallel storage+retrieval stack. - -## What gets deleted -- `posting/bm25block/` and `posting/bm25enc/` (parallel block format). -- `LocalCache.bm25Writes`, `ReadBM25Blob`/`WriteBM25Blob` (second write path). -- `BitBM25Data` user-meta + the BM25 commit branch in `posting/mvcc.go`. -- `bm25_score` pseudo-predicate + `__bm25_scores__` `ParentVars` threading in `query/query.go`. -- Legacy-format fallback / block dir+block keys in `x/keys.go`. - -## Storage model (standard posting lists) -- **Term postings**: one standard index posting list per term at - `IndexKey(attr, IdentBM25 || term)`. Each posting: `Uid = docUID`, - `Value = encodeBM25(tf, docLen)`, `ValType = INT`. Written via `plist.addMutation` - (the normal delta path) → inherits rollup/splits/backup. - - **Rollup-survival fix (linchpin)**: `NewPosting` makes any edge with `ValueId != 0` - a `REF` posting, and `List.encode()` (rollup) keeps a posting's `Value` only when - `Facets != nil || PostingType != REF`. A plain valued REF index posting would have - its TF **stripped at rollup**. Fix: one-line change in `encode()` to also retain - postings that carry a non-empty `Value`. This is the faithful realization of the - maintainer's "TF as the value", and matches how faceted postings already coexist - in both `Pack` (uid) and `Postings` (value). Covered by a forced-rollup regression test. -- **Doc length**: packed into the posting value alongside TF (`encodeBM25(tf, docLen)`), - NOT a separate per-predicate doclen list. Rationale: a single doclen list is a write- - conflict hotspot (every doc mutation writes the same key) and forces a query-time random - read per candidate. Packing makes scoring read `(uid, tf, docLen)` in one shot, - contention-free. Cost: docLen duplicated across a doc's unique terms (acceptable; a doc's - postings are all rewritten together on update anyway). -- **Corpus stats** (`N` docs, `totalTerms` → `avgDL`): conflict-free **bucketed** stats. - `BM25StatsKey(attr, bucket)`, `bucket = docUID % numBuckets` (B=32). Each bucket holds - `(docCount_b, totalTerms_b)`. Mutations touch only their bucket → ~B-fold less contention - than a single hot key. Read path sums across buckets. BM25 tolerates the slight staleness. - -## Value codec `encodeBM25(tf, docLen)` -Two unsigned varints: `tf` then `docLen`. Decoded during scoring. Small file -`posting/bm25.go` (no new package) holds encode/decode + index-mutation logic. - -## Query path (no pseudo-predicate) -- `bm25(attr, "query", [k], [b])` parses to `bm25SearchFn` (unchanged keyword). -- `worker/task.go handleBM25Search`: tokenize query, read bucketed stats → `N`, `avgDL`, - load each term's standard posting `List` via the cache, run WAND, emit `UidMatrix` - (uids asc) + `ValueMatrix` (float64 scores aligned to uids). -- **Surfacing/ordering the score**: via Dgraph's existing **value-variable** (`val()`) - mechanism — the function's `ValueMatrix` populates a value var the user binds and orders - by. No `bm25_score` pseudo-predicate, no new `ParentVars` channel. - -## WAND on the standard iterator (no parallel block format) -Dgraph loads a whole posting list (or split-part) into memory on `Get`. So: -- For each query term, one `List.Iterate` pass materializes a sorted cursor of - `(uid, tf, docLen)`, plus `df`, term `maxTF`, and per-chunk (128) `maxTF`/`minDocLen` - for Block-Max upper bounds — all computed from the in-memory list, **no storage-format - change**. -- WAND / Block-Max WAND DAAT with a top-k min-heap (reuse scoring + heap from the existing - `worker/bm25wand.go`, swapping the block-reading cursor for the standard-list cursor). -- (Future optimization, out of scope now: persist per-block maxTF at rollup to avoid - recomputing for hot terms.) - -## Scoring -`idf = log1p((N - df + 0.5)/(df + 0.5))`; `score = Σ idf·(k+1)·tf / (k·(1-b+b·dl/avgDL) + tf)`. -Defaults `k=1.2`, `b=0.75`. - -## Implementation phases -1. Storage+index: `encode()` retention fix; `posting/bm25.go` (value codec + mutations); - bucketed stats; delete bm25block/bm25enc, bm25Writes, BitBM25Data, mvcc branch. -2. Keys: trim `x/keys.go` to `BM25IndexKey` + bucketed `BM25StatsKey`. -3. Tokenizer: keep `BM25Tokenizer` + query tokens (minor cleanup). -4. Query+WAND: rewrite `worker/bm25wand.go` over standard lists; rewrite `handleBM25Search`; - remove pseudo-predicate/ParentVars from `query/query.go`; wire value-var scoring. -5. Tests: forced-rollup TF-survival test; bucketed-stats test; WAND unit tests over standard - lists; adapt `query/query_bm25_test.go`; build + run. From ae42ebaaeccbf5d5dc564bb72f8e1f0fd1677186 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 23:05:40 -0400 Subject: [PATCH 17/19] feat(search): native hybrid search with score fusion (fuse + hybrid) Add native multi-signal search fusion so a single DQL query can combine BM25 text relevance with vector similarity (and any other scored value variable), instead of issuing separate queries and fusing in application code. New surface: - fuse(v1, v2, ..., method:"rrf"|"linear", k:60, weights:"0.3,0.7", normalize:"max"|"none", topk:N): an N-way combinator over already-scored DQL value variables. RRF = sum 1/(k+rank); linear = weighted sum of (optionally max-normalized) scores. Outer-join/union semantics: a uid missing from a channel contributes nothing (RRF) or 0 (linear), never dropped. Computed coordinator-side over resolved variables; the existing dependency scheduler orders it after its channel blocks. - hybrid(textPred, "query", vecPred, $vec, topk:N, method:..., k:...): convenience sugar rewritten at parse time into bm25 + similar_to channel blocks plus a fuse() block (no distinct execution path). Vector scores surfaced: - similar_to now binds a higher-is-better similarity score (cosine/dot as-is; euclidean as 1/(1+d^2)) to a value variable, so vector results can be a fusion channel alongside BM25. New SearchScored / SearchWithUidScored mirror Search / SearchWithUid exactly, so scoring a plain vector query does not change which neighbors it returns; the *Options* scored variants apply only when ef/distance-threshold is supplied. Robustness (post adversarial review by GPT-5 + Gemini): - non-finite scores are dropped before fusion so they cannot break the sort comparator or poison linear sums; - fused value variable follows the bm25 value-variable contract (ascending uid set + uid->score map; ranked order via orderdesc: val(var); topk selected before the ascending sort); - undefined fuse channel vars are rejected at parse time; - the __hybrid prefix is reserved to avoid synthetic/user var collisions. Tests: fusion-core unit tests (RRF/linear/union/ties/NaN/topk/determinism), parser tests (fuse + hybrid + error paths), score-orientation tests, and end-to-end integration tests combining BM25 + vector (fuse and hybrid). Design and review notes in docs/superpowers/specs. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../2026-06-03-hybrid-search-fusion-design.md | 265 +++++++++++++ dql/fuse_parser_test.go | 216 +++++++++++ dql/hybrid.go | 146 +++++++ dql/parser.go | 77 +++- query/common_test.go | 19 + query/fuse.go | 363 ++++++++++++++++++ query/fuse_test.go | 208 ++++++++++ query/query.go | 37 +- query/query_hybrid_test.go | 276 +++++++++++++ tok/hnsw/helper.go | 18 + tok/hnsw/persistent_hnsw.go | 99 ++++- tok/hnsw/search_layer.go | 5 +- tok/hnsw/similarity_score_test.go | 36 ++ tok/index/index.go | 23 ++ tok/index/search_path.go | 5 + worker/task.go | 65 +++- 16 files changed, 1819 insertions(+), 39 deletions(-) create mode 100644 docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md create mode 100644 dql/fuse_parser_test.go create mode 100644 dql/hybrid.go create mode 100644 query/fuse.go create mode 100644 query/fuse_test.go create mode 100644 query/query_hybrid_test.go create mode 100644 tok/hnsw/similarity_score_test.go diff --git a/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md b/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md new file mode 100644 index 00000000000..3be9c207bf0 --- /dev/null +++ b/docs/superpowers/specs/2026-06-03-hybrid-search-fusion-design.md @@ -0,0 +1,265 @@ +# Native Hybrid Search with Score Fusion — Design Spec + +**Branch:** `sp/hybrid-search` (off `sp/bm25`) +**Date:** 2026-06-03 +**Status:** Approved (Option C) + +## 1. Problem + +Dgraph can rank text (BM25, on `sp/bm25`) and find nearest vectors (HNSW `similar_to`), +but cannot **combine** them. A consumer wanting hybrid retrieval (the standard RAG +pattern: dense vector + sparse/keyword, fused into one ranked list) must issue +separate DQL queries and fuse the results in application code. + +Concretely, the reference consumer (modelhub, a GraphRAG product running on Dgraph) +issues **three** separate DQL queries per search — (1) `similar_to` vector, (2) +BM25-ish term, (3) entity-anchored term — and fuses them in Python with Reciprocal +Rank Fusion (`score = Σ 1/(k+rank)`, k=60), sometimes with a linear combination +(`α·vec + (1-α)·text`, scores max-normalized first). + +This spec brings that fusion **natively into a single DQL query**. + +## 2. Approach (Option C) + +Two pieces, agreed after consulting GPT-5 and Gemini (both recommended the N-way +primitive as the core; Option C adds a thin convenience wrapper): + +1. **`fuse()`** — an N-way fusion combinator over already-scored DQL **value + variables**. The general primitive. Handles any number of channels and any + scored signal, not just bm25/vector. +2. **`hybrid()`** — convenience sugar for the common 2-channel bm25+vector case, + expanded at query-rewrite time into two channel blocks + a `fuse()`. No new + executor path. + +A prerequisite makes `fuse()` useful with vectors: + +3. **Surface `similar_to` similarity scores as a value variable** (today it returns + only uids). Required so vector results can be a fusion channel. + +### Why this fits Dgraph's architecture + +- DQL already has **value variables** carrying uid→score maps (`var(func: bm25(...))` + binds per-doc scores into a `varValue{Uids, Vals}` — see `query/query.go` + `populateUidValVar`, the `bm25` case). +- `ProcessQuery` already has a **variable-dependency scheduler** (`canExecute` over + `QueryVars.Needs`/`.Defines`): a block runs only once the variables it needs are + populated. A `fuse()` block that *needs* its channel vars and *defines* the fused + var is scheduled automatically after its inputs — no new sequencing logic. +- Fusion is a **coordinator-side** operation over resolved variables (like `math()` + and `uid(a,b)`), which is correct under predicate sharding: each channel may be + computed on a different Raft group; their value variables are already merged to + the coordinator before fusion runs. + +## 3. DQL Surface + +### 3.1 `fuse()` — N-way primitive + +``` +v as var(func: bm25(text, "quick brown fox")) +e as var(func: similar_to(emb, 100, $queryVec)) + +f as var(func: fuse(v, e, method: "rrf", k: 60)) + +{ + result(func: uid(f), orderdesc: val(f), first: 10) { + uid + val(f) + } +} +``` + +- **Positional args**: two or more value-variable references (the channels). These + become the block's `NeedsVar`. +- **Named options** (parsed like `similar_to`'s `ef:`/`distance_threshold:`): + - `method`: `"rrf"` (default) or `"linear"`. + - `k`: RRF rank constant (default `60`). Ignored for linear. + - `weights`: comma-separated per-channel weights for linear, aligned positionally + with the channel args (e.g. `weights: "0.3,0.7"`). Default: `1.0` each. Ignored + for RRF. + - `normalize`: linear-only score normalization, `"max"` (default) or `"none"`. + - `topk`: optional cap on the number of fused results emitted (default: unbounded; + downstream `first/offset` still applies). Bounds coordinator work. +- **Output**: binds `f` as a value variable — both the **union** uid set + (`uid(f)`) and the uid→fusedScore map (`val(f)`, `orderdesc: val(f)`), exactly + like the bm25 ranker var. + +### 3.2 `hybrid()` — 2-channel sugar + +``` +f as var(func: hybrid(text, "quick brown fox", emb, $queryVec, 100, method: "rrf", k: 60)) + +{ result(func: uid(f), orderdesc: val(f), first: 10) { uid val(f) } } +``` + +Positional: `textPredicate, "queryText", vectorPredicate, $queryVec, topk`. Named +options as for `fuse()`. Rewritten at query-build time to: + +``` +__hybrid_0_ as var(func: bm25(text, "quick brown fox")) +__hybrid_1_ as var(func: similar_to(emb, 100, $queryVec)) +f as var(func: fuse(__hybrid_0_, __hybrid_1_, method: "rrf", k: 60)) +``` + +The synthetic var names are unique per hybrid block. After rewrite there is no +distinct `hybrid` execution path — it is purely a parser/builder transformation. + +## 4. Fusion Semantics (the core, `query/fuse.go`) + +Pure function over channels, fully unit-testable, no I/O: + +```go +type fuseChannel struct { + scores map[uint64]float64 // uid -> raw channel score (higher = better) + weight float64 // linear weight (default 1.0) +} + +func fuseRRF(channels []fuseChannel, k float64, topk int) []scoredUid +func fuseLinear(channels []fuseChannel, normalize string, topk int) []scoredUid +``` + +### Outer-join / union semantics (the key correctness point) + +Both models independently flagged this. Fusion is a **set union** of candidate uids +across channels, NOT an intersection. For a uid present in some channels but not +others: + +- **RRF**: only ranked channels contribute. `fused[u] = Σ_{c: u∈c} 1/(k + rank_c(u))`. + A channel that doesn't contain `u` adds nothing (equivalent to rank = ∞). +- **Linear**: missing channel contributes `0`. `fused[u] = Σ_c weight_c · norm_c(score_c(u))`, + where `norm_c(missing) = 0`. + +Standard DQL `math()` across var blocks aligns/intersects on uid and would drop or +NaN-poison such uids — `fuse()` must not. + +### RRF ranks + +Each channel is independently sorted by raw score **descending**, tie-broken by uid +**ascending**, to assign 1-based ranks. (Deterministic; matches the bm25/HNSW +`sorted()` tie-break already in the codebase.) + +### Linear normalization + +`"max"` (default): divide each channel's scores by that channel's max +(`|max|`, guard against 0 → channel contributes 0). Brings heterogeneous score +scales (BM25 ∈ [0,∞), cosine ∈ [-1,1]) onto a comparable range before weighting. +`"none"`: use raw scores (caller asserts they're comparable). + +### Output ordering + +Fused results sorted by fused score descending, tie-broken by uid ascending. If +`topk > 0`, truncate to `topk`. The output is emitted to the value variable as the +union uid set (ascending, per the query pipeline contract) + positionally-aligned +fused scores — identical shape to the bm25 var binding. + +## 5. Vector score surfacing (prerequisite) + +`similar_to` currently emits only `UidMatrix` (worker/task.go ~L417). Change: + +- Extend `index.SearchPathResult` with `Distances []float64` parallel to `Neighbors`, + populated from the search-layer heap (which already holds metric-domain distances, + `n.value`) in `addFinalNeighbors`. +- In the `similarToFn` worker path, when the function is bound to a value variable, + use the path-returning search to obtain distances and emit a `ValueMatrix` of + **similarity** scores (higher = better): + - **cosine**: cosine similarity in [-1,1], as-is. + - **dotproduct**: dot product, as-is. + - **euclidean**: `1/(1 + d)` where `d` is the (non-squared) Euclidean distance, + mapping to (0,1] so higher = better and linear normalization is well-behaved. +- Bind in `populateUidValVar` exactly like the `bm25` case (reuse/generalize that + branch). When `similar_to` is **not** bound to a var, behavior is unchanged + (uids only) — zero overhead for existing queries. + +Orientation contract: **all rankers surface higher-is-better scores**, so RRF and +linear fusion compose without per-channel sign handling. + +## 6. Execution flow + +1. Parse: `fuse`/`hybrid` recognized as functions in `dql/parser.go`; channel args → + `NeedsVar`, named options stored on the function. `hybrid` expands to channel + blocks + `fuse`. The block `Defines` its output var, `Needs` its channel vars. +2. Schedule: `ProcessQuery`'s `canExecute` runs channel blocks first; the `fuse` + block becomes runnable once all channel vars are in `req.Vars`. +3. Compute: the `fuse` block is **coordinator-only** — like the existing + `similar_to`-empty/`IsEmpty` cases, it is **not** dispatched to a worker. Fusion + is computed in `populateVarMap` (new `fuse` case) reading channel `varValue`s + from `doneVars`, producing the fused `varValue`. +4. Consume: downstream `uid(f)` / `orderdesc: val(f)` / `first`/`offset` work via the + existing value-variable machinery. + +## 7. Validation & errors + +- ≥1 channel var required (2+ to be meaningful; 1 is allowed and passes through). +- Unknown `method` → error. `k <= 0` → error. `weights` count must match channel + count when provided → error. Malformed `weights` floats → error. +- Empty channels (no matches) are valid and contribute nothing. +- A channel var that is a **uid variable without scores** (no `Vals`): for RRF, rank + by the var's intrinsic order if any, else treat as unscored → error with a clear + message ("fuse channel %q has no scores; use a ranker like bm25/similar_to"). MVP: + require scored channels. + +## 8. Testing + +**Unit (`query/fuse_test.go`)** — pure fusion core: +- RRF: known ranks → known `Σ 1/(k+rank)`; default k=60; custom k. +- Linear: max-normalize, weights, `normalize:none`. +- Union semantics: uid in 1 of N channels; disjoint channels; full overlap. +- Ties (equal scores → uid-ascending rank); empty channels; single channel. +- `topk` truncation; determinism. + +**Worker (`worker/`)** — vector score surfacing: +- `similar_to` bound to a var emits similarity scores with correct orientation per + metric; unbound `similar_to` unchanged. + +**Integration (systest/DQL)** — end-to-end: +- `fuse()` RRF over bm25 + vector; ordering matches hand-computed RRF. +- `fuse()` linear with weights. +- 3-channel fusion (modelhub's shape). +- Pagination (`first`/`offset`) on the consuming block. +- Missing-uid union correctness. +- `hybrid()` produces results identical to the equivalent explicit `fuse()`. +- Error paths (unknown method, bad weights, unscored channel). + +## 9. Out of scope (future) + +- Filter pushdown into HNSW (pre-filtered ANN) — separate gap. +- Worker-side fusion pushdown / `topk` propagation into channel funcs. +- Additional methods (weighted-RRF, ISR, distribution-based fusion). +- Reranking primitives. + +## 9a. Adversarial review outcomes (GPT-5 + Gemini) + +Both models deep-reviewed the diff. Findings triaged and resolved: + +- **Behavior preservation for `similar_to` (High, both):** the worker no longer + routes plain (no-option) vector queries through the options path. New + `SearchScored` / `SearchWithUidScored` mirror `Search` / `SearchWithUid` exactly + and just also return scores, so existing queries' neighbor selection is unchanged; + the `*Options*` scored variants are used only when ef/distance-threshold is given. +- **NaN/Inf safety (both):** `scoresFromVar` drops non-finite scores and + `channelMaxAbs` ignores them, so a pathological score can never break the sort + comparator's strict-weak-ordering or poison a linear sum. +- **"Ascending Uids destroys ranking" (both, flagged Critical):** not a bug — this + follows the same value-variable contract as bm25 (Uids is the unordered set for + `uid(var)`; ranked order is recovered via `orderdesc: val(var)`; `topk` selection + happens before the ascending sort). Documented in `computeFuse`. +- **Undefined channel var (GPT-5 Critical "stall"):** not a stall — `checkDependency` + rejects it at parse time ("variables used but not defined"). Regression test added. +- **Synthetic var collision (both, Low):** `__hybrid` is now a reserved prefix; + user vars using it are rejected with a clear message. +- **`@filter(similar_to(...))` + ValueMatrix (both, Med):** follows the proven bm25 + precedent (bm25 emits ValueMatrix and is used in filters); to be confirmed by the + vector integration suite in CI. +- **Coordinator GC churn in `channelRanks` (Gemini, perf):** acknowledged; acceptable + at expected channel sizes, noted as future optimization (buffer pooling) if needed. + +## 10. Files touched + +| Area | File(s) | +|---|---| +| Fusion core (new) | `query/fuse.go`, `query/fuse_test.go` | +| Parser | `dql/parser.go` (recognize `fuse`/`hybrid`, args+opts), `dql/state.go` if needed | +| hybrid rewrite | `query/query.go` (ToSubGraph/build) or `dql` transform | +| Var binding | `query/query.go` `populateUidValVar` (fuse case; generalize bm25 case) | +| Scheduler skip-worker | `query/query.go` `ProcessQuery` (fuse → no worker dispatch) | +| Vector scores | `tok/index/search_path.go` (`Distances`), `tok/hnsw/persistent_hnsw.go` (populate), `worker/task.go` (`similar_to` ValueMatrix) | +| Docs | DQL docs for `fuse`/`hybrid` | diff --git a/dql/fuse_parser_test.go b/dql/fuse_parser_test.go new file mode 100644 index 00000000000..a445d1cc1d5 --- /dev/null +++ b/dql/fuse_parser_test.go @@ -0,0 +1,216 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package dql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// findBlock returns the query block whose result variable is varName. +func findVarBlock(t *testing.T, res Result, varName string) *GraphQuery { + for _, q := range res.Query { + if q.Var == varName { + return q + } + } + t.Fatalf("no block defines var %q", varName) + return nil +} + +func TestParseFuse_ChannelsAndOptions(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "quick brown fox")) + e as var(func: similar_to(emb, 100, "[0.1, 0.2]")) + f as var(func: fuse(v, e, method: "rrf", k: 60)) + result(func: uid(f), orderdesc: val(f)) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + + fb := findVarBlock(t, res, "f") + require.NotNil(t, fb.Func) + require.Equal(t, "fuse", fb.Func.Name) + + // Channels are captured as value-variable NeedsVar in order. + require.Len(t, fb.Func.NeedsVar, 2) + require.Equal(t, "v", fb.Func.NeedsVar[0].Name) + require.Equal(t, ValueVar, fb.Func.NeedsVar[0].Typ) + require.Equal(t, "e", fb.Func.NeedsVar[1].Name) + + // Named options are captured as [key, value] arg pairs. + args := map[string]string{} + for i := 0; i+1 < len(fb.Func.Args); i += 2 { + args[fb.Func.Args[i].Value] = fb.Func.Args[i+1].Value + } + require.Equal(t, "rrf", args["method"]) + require.Equal(t, "60", args["k"]) +} + +func TestParseFuse_LinearWeights(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + e as var(func: bm25(title, "fox")) + f as var(func: fuse(v, e, method: "linear", weights: "0.3,0.7", normalize: "max")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + fb := findVarBlock(t, res, "f") + require.Len(t, fb.Func.NeedsVar, 2) + args := map[string]string{} + for i := 0; i+1 < len(fb.Func.Args); i += 2 { + args[fb.Func.Args[i].Value] = fb.Func.Args[i+1].Value + } + require.Equal(t, "linear", args["method"]) + require.Equal(t, "0.3,0.7", args["weights"]) + require.Equal(t, "max", args["normalize"]) +} + +func TestParseFuse_ThreeChannels(t *testing.T) { + query := ` + { + a as var(func: bm25(text, "fox")) + b as var(func: bm25(title, "fox")) + c as var(func: bm25(body, "fox")) + f as var(func: fuse(a, b, c)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + fb := findVarBlock(t, res, "f") + require.Len(t, fb.Func.NeedsVar, 3) +} + +func TestParseFuse_UnknownOption(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, bogus: "x")) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unknown option") +} + +func TestParseFuse_DuplicateOption(t *testing.T) { + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, k: 10, k: 20)) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Duplicate key") +} + +func TestParseFuse_NoChannels(t *testing.T) { + query := ` + { + f as var(func: fuse(method: "rrf")) + result(func: uid(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "at least one value variable") +} + +func TestParseHybrid_ExpandsToThreeBlocks(t *testing.T) { + query := ` + { + f as var(func: hybrid(description, "quick brown fox", emb, "[0.1, 0.2]", topk: 50, method: "rrf", k: 60)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + + // The hybrid block is replaced by bm25 + similar_to + fuse (plus the result block). + var bm25Block, simBlock, fuseBlock *GraphQuery + for _, q := range res.Query { + if q.Func == nil { + continue + } + switch q.Func.Name { + case "bm25": + bm25Block = q + case "similar_to": + simBlock = q + case "fuse": + fuseBlock = q + case "hybrid": + t.Fatal("hybrid block should have been rewritten away") + } + } + require.NotNil(t, bm25Block, "bm25 channel block must exist") + require.NotNil(t, simBlock, "similar_to channel block must exist") + require.NotNil(t, fuseBlock, "fuse block must exist") + + // bm25 channel: predicate + query text. + require.Equal(t, "description", bm25Block.Func.Attr) + require.Equal(t, "quick brown fox", bm25Block.Func.Args[0].Value) + + // similar_to channel: predicate + topk + vector. + require.Equal(t, "emb", simBlock.Func.Attr) + require.Equal(t, "50", simBlock.Func.Args[0].Value) + + // fuse block keeps the original variable name and the fuse options. + require.Equal(t, "f", fuseBlock.Var) + require.Len(t, fuseBlock.Func.NeedsVar, 2) + args := map[string]string{} + for i := 0; i+1 < len(fuseBlock.Func.Args); i += 2 { + args[fuseBlock.Func.Args[i].Value] = fuseBlock.Func.Args[i+1].Value + } + require.Equal(t, "rrf", args["method"]) + require.Equal(t, "60", args["k"]) + // topk is consumed by similar_to, not forwarded to fuse. + require.NotContains(t, args, "topk") +} + +func TestParseHybrid_DefaultTopK(t *testing.T) { + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + for _, q := range res.Query { + if q.Func != nil && q.Func.Name == "similar_to" { + require.Equal(t, "100", q.Func.Args[0].Value, "default topk should be 100") + } + } +} + +func TestParseFuse_UndefinedChannelVarErrors(t *testing.T) { + // A fuse channel referencing a variable that no block defines must be rejected + // at parse time (not silently stall the scheduler). + query := ` + { + v as var(func: bm25(text, "fox")) + f as var(func: fuse(v, ghost, method: "rrf")) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "not defined") +} + +func TestParseHybrid_RequiresVar(t *testing.T) { + query := ` + { + result(func: hybrid(description, "fox", emb, "[0.1]")) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "must be assigned to a variable") +} diff --git a/dql/hybrid.go b/dql/hybrid.go new file mode 100644 index 00000000000..8506847a46b --- /dev/null +++ b/dql/hybrid.go @@ -0,0 +1,146 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package dql + +import ( + "fmt" + "strings" +) + +// hybrid() is convenience sugar for the common two-channel hybrid-search case: +// BM25 text relevance fused with vector similarity. It is rewritten, before +// variable collection and dependency checking, into the equivalent explicit form: +// +// x as var(func: hybrid(textPred, "query", vecPred, $vec, topk: 100, method: "rrf", k: 60)) +// +// becomes +// +// __hybrid0_bm25 as var(func: bm25(textPred, "query")) +// __hybrid0_vec as var(func: similar_to(vecPred, 100, $vec)) +// x as var(func: fuse(__hybrid0_bm25, __hybrid0_vec, method: "rrf", k: 60)) +// +// There is therefore no distinct hybrid execution path: it desugars entirely to +// the fuse() primitive. Positional args are textPred, "query", vecPred and the +// query vector ($var or literal); named options are topk (vector neighbor count, +// default 100) plus the fuse options method/k/weights/normalize. +const ( + hybridTopKOption = "topk" + hybridDefaultTopK = "100" + // hybridVarPrefix namespaces the synthetic channel variables a hybrid() block + // expands into. It is reserved: user variables may not start with it. + hybridVarPrefix = "__hybrid" +) + +// rewriteHybridBlocks expands every hybrid() query block in res into its three +// constituent blocks. It runs before fragment expansion, variable substitution, +// and dependency checking so the generated blocks participate normally. +func rewriteHybridBlocks(res *Result) error { + hasHybrid := false + for _, qu := range res.Query { + if qu != nil && qu.Func != nil && qu.Func.Name == hybridFunc { + hasHybrid = true + break + } + } + if !hasHybrid { + return nil + } + + // Guard against the (extremely unlikely) case of a user variable colliding with + // the synthetic channel names we generate, which would otherwise produce a + // confusing "defined multiple times" error the user can't act on. + for _, qu := range res.Query { + if qu != nil && strings.HasPrefix(qu.Var, hybridVarPrefix) { + return fmt.Errorf("variable %q uses the reserved prefix %q (used internally by hybrid)", + qu.Var, hybridVarPrefix) + } + } + + expanded := make([]*GraphQuery, 0, len(res.Query)+2) + hybridIdx := 0 + for _, qu := range res.Query { + if qu == nil || qu.Func == nil || qu.Func.Name != hybridFunc { + expanded = append(expanded, qu) + continue + } + blocks, err := expandHybridBlock(qu, hybridIdx) + if err != nil { + return err + } + hybridIdx++ + expanded = append(expanded, blocks...) + } + res.Query = expanded + return nil +} + +// expandHybridBlock turns a single hybrid() block into [bm25, similar_to, fuse]. +func expandHybridBlock(qu *GraphQuery, idx int) ([]*GraphQuery, error) { + if qu.Var == "" { + return nil, fmt.Errorf("hybrid must be assigned to a variable, e.g. " + + "`x as var(func: hybrid(textPred, \"query\", vecPred, $vec))`") + } + fn := qu.Func + textPred := fn.Attr + if textPred == "" { + return nil, fmt.Errorf("hybrid: missing text predicate (first argument)") + } + + // hybrid has exactly three positional args in Args (the text predicate is the + // function Attr): queryText, vecPred and the query vector. Any further args are + // key/value option pairs appended by the parser. + const numPositional = 3 + if len(fn.Args) < numPositional { + return nil, fmt.Errorf("hybrid requires textPred, \"query text\", vecPred and a "+ + "query vector; got %d positional arguments", len(fn.Args)+1) + } + queryText := fn.Args[0].Value + vecPred := fn.Args[1].Value + vecArg := fn.Args[2] + + // Parse options: topk feeds similar_to's neighbor count; the rest feed fuse. + topk := hybridDefaultTopK + var fuseArgs []Arg + for i := numPositional; i+1 < len(fn.Args); i += 2 { + key := strings.ToLower(fn.Args[i].Value) + val := fn.Args[i+1] + if key == hybridTopKOption { + topk = val.Value + continue + } + fuseArgs = append(fuseArgs, Arg{Value: key}, val) + } + + chanBM25 := fmt.Sprintf("%s%d_bm25", hybridVarPrefix, idx) + chanVec := fmt.Sprintf("%s%d_vec", hybridVarPrefix, idx) + + bm25Block := &GraphQuery{ + Alias: "var", + Var: chanBM25, + Func: &Function{Name: "bm25", Attr: textPred, Args: []Arg{{Value: queryText}}}, + Args: map[string]string{}, + } + simBlock := &GraphQuery{ + Alias: "var", + Var: chanVec, + Func: &Function{Name: similarToFn, Attr: vecPred, Args: []Arg{{Value: topk}, vecArg}}, + Args: map[string]string{}, + } + + channels := []VarContext{ + {Name: chanBM25, Typ: ValueVar}, + {Name: chanVec, Typ: ValueVar}, + } + fuseBlock := &GraphQuery{ + Alias: "var", + Var: qu.Var, + Func: &Function{Name: fuseFunc, NeedsVar: channels, Args: fuseArgs}, + NeedsVar: channels, + Args: map[string]string{}, + } + + return []*GraphQuery{bm25Block, simBlock, fuseBlock}, nil +} diff --git a/dql/parser.go b/dql/parser.go index 666c3eacaab..930ba2d58d5 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -29,8 +29,15 @@ const ( countFunc = "count" uidInFunc = "uid_in" similarToFn = "similar_to" + fuseFunc = "fuse" + hybridFunc = "hybrid" ) +// fuseOptionKeys is the set of named options accepted by the fuse() function. +var fuseOptionKeys = map[string]struct{}{ + "method": {}, "k": {}, "weights": {}, "normalize": {}, "topk": {}, +} + var ( errExpandType = "expand is only compatible with type filters" ) @@ -699,6 +706,12 @@ func ParseWithNeedVars(r Request, needVars []string) (res Result, rerr error) { } } + // Expand any hybrid() sugar blocks into explicit bm25 + similar_to channel + // blocks plus a fuse() block before variable collection and dependency checks. + if err := rewriteHybridBlocks(&res); err != nil { + return res, err + } + if len(res.Query) != 0 { res.QueryVars = make([]*Vars, 0, len(res.Query)) for i := range res.Query { @@ -1701,7 +1714,8 @@ func validFuncName(name string) bool { switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25", + fuseFunc, hybridFunc: return true } return false @@ -1749,6 +1763,10 @@ L: if function.Name == similarToFn { similarToOptSeen = make(map[string]struct{}) } + var fuseOptSeen map[string]struct{} + if function.Name == fuseFunc || function.Name == hybridFunc { + fuseOptSeen = make(map[string]struct{}) + } if _, ok := tryParseItemType(it, itemLeftRound); !ok { return nil, it.Errorf("Expected ( after func name [%s]", function.Name) } @@ -1980,6 +1998,46 @@ L: // Disallow extra positional args after (k, vec). Options must be named. return nil, itemInFunc.Errorf("Expected named parameter in similar_to options (e.g. ef: 64)") } + + // fuse(v1, v2, ..., method: "rrf", k: 60) collects bare value-variable + // names as channels (handled by the fuseFunc case in the NeedsVar switch + // below) and key:value pairs as named options. An itemName followed by a + // colon is an option; otherwise it falls through to bare-name handling. + // hybrid() shares the same named options (its positional args are handled + // by the generic bare-name path). + if itemInFunc.Typ == itemName && + (function.Name == fuseFunc || function.Name == hybridFunc) { + next, ok := it.PeekOne() + if ok && next.Typ == itemColon { + key := strings.ToLower(collectName(it, itemInFunc.Val)) + if _, valid := fuseOptionKeys[key]; !valid { + return nil, itemInFunc.Errorf("Unknown option %q in fuse", key) + } + if _, exists := fuseOptSeen[key]; exists { + return nil, itemInFunc.Errorf("Duplicate key %q in fuse options", key) + } + fuseOptSeen[key] = struct{}{} + if ok := trySkipItemTyp(it, itemColon); !ok { + return nil, it.Errorf("Expected colon(:) after %s", key) + } + if !it.Next() { + return nil, it.Errorf("Expected value for %s", key) + } + valItem := it.Item() + if valItem.Typ != itemName { + return nil, valItem.Errorf("Expected value for %s", key) + } + v := strings.Trim(collectName(it, valItem.Val), " \t") + uq, err := unquoteIfQuoted(v) + if err != nil { + return nil, err + } + function.Args = append(function.Args, Arg{Value: key}, Arg{Value: uq}) + expectArg = false + continue + } + } + if itemInFunc.Typ != itemName { return nil, itemInFunc.Errorf("Expected arg after func [%s], but got item %v", function.Name, itemInFunc) @@ -2029,7 +2087,7 @@ L: // Unlike other functions, uid function has no attribute, everything is args. switch { case len(function.Attr) == 0 && function.Name != uidFunc && - function.Name != typFunc: + function.Name != typFunc && function.Name != fuseFunc: if strings.ContainsRune(itemInFunc.Val, '"') { return nil, itemInFunc.Errorf("Attribute in function"+ @@ -2047,7 +2105,7 @@ L: } function.Lang = val expectLang = false - case function.Name != uidFunc: + case function.Name != uidFunc && function.Name != fuseFunc: // For UID function. we set g.UID function.Args = append(function.Args, Arg{Value: val}) } @@ -2058,6 +2116,12 @@ L: expectArg = false switch function.Name { + case fuseFunc: + // fuse(v1, v2, ...) takes scored value variables as channels. + function.NeedsVar = append(function.NeedsVar, VarContext{ + Name: val, + Typ: ValueVar, + }) case valueFunc: // E.g. @filter(gt(val(a), 10)) function.NeedsVar = append(function.NeedsVar, VarContext{ @@ -2099,10 +2163,15 @@ L: } } - if function.Name != uidFunc && function.Name != typFunc && len(function.Attr) == 0 { + if function.Name != uidFunc && function.Name != typFunc && function.Name != fuseFunc && + len(function.Attr) == 0 { return nil, it.Errorf("Got empty attr for function: [%s]", function.Name) } + if function.Name == fuseFunc && len(function.NeedsVar) == 0 { + return nil, it.Errorf("fuse function requires at least one value variable channel") + } + if function.Name == typFunc && len(function.Args) != 1 { return nil, it.Errorf("type function only supports one argument. Got: %v", function.Args) } diff --git a/query/common_test.go b/query/common_test.go index 32a3e65a81b..64a88bd865e 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -393,6 +393,10 @@ func populateCluster(dc dgraphapi.Cluster) { // BM25 indexing - uses same version gate as ngram for now if ngramSupport { testSchema += "\ndescription_bm25: string @index(bm25) ." + // A parallel vector predicate on the same documents enables hybrid-search + // (fuse) tests that combine BM25 with vector similarity. Gated together with + // BM25 so both channels are always present for those tests. + testSchema += "\ndescription_vec: float32vector @index(hnsw(metric:\"euclidean\")) ." } setSchema(testSchema) @@ -1024,4 +1028,19 @@ func populateCluster(dc dgraphapi.Cluster) { <507> "Brown foxes are quick and agile animals in the forest" . `) x.Panic(err) + + // Vector embeddings on the same documents (dims = [fox, dog, quick, brown]). + // Chosen so a "pure fox" query vector [3,0,0,0] ranks 503 first, then the + // fox/quick docs (506,507,501,502), and the dog-only docs (504,505) last — + // letting hybrid (fuse) tests observe both channels and union semantics. + err = addTriplesToCluster(` + <501> "[1.0, 1.0, 1.0, 1.0]" . + <502> "[1.0, 1.0, 1.0, 1.0]" . + <503> "[3.0, 0.0, 0.0, 0.0]" . + <504> "[0.0, 2.0, 0.0, 0.0]" . + <505> "[0.0, 2.0, 0.0, 0.0]" . + <506> "[1.0, 0.0, 1.0, 0.0]" . + <507> "[1.0, 0.0, 1.0, 1.0]" . + `) + x.Panic(err) } diff --git a/query/fuse.go b/query/fuse.go new file mode 100644 index 00000000000..23fa62ea80c --- /dev/null +++ b/query/fuse.go @@ -0,0 +1,363 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package query + +import ( + "math" + "sort" + "strconv" + "strings" + + "github.com/dgraph-io/dgraph/v25/dql" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/types" + "github.com/pkg/errors" +) + +// This file implements the pure, I/O-free core of native hybrid search: combining +// several already-scored result sets ("channels") into a single ranked list. +// +// A channel is a uid->score map produced by an upstream ranker bound to a DQL value +// variable (e.g. bm25(...) or similar_to(...)). Fusion is a coordinator-side +// operation over resolved value variables; everything here is deterministic and +// independent of storage, sharding, and the query pipeline so it can be tested in +// isolation. The query-layer adapter (see populateUidValVar) converts value +// variables into fuseChannels and the result back into a value variable. + +// fusionMethod selects how channel scores are combined. +type fusionMethod int + +const ( + // fusionRRF is Reciprocal Rank Fusion: each channel contributes 1/(k+rank), + // where rank is the 1-based position of the uid within that channel. Robust to + // heterogeneous score scales because it uses ranks, not raw scores. + fusionRRF fusionMethod = iota + // fusionLinear is a weighted sum of (optionally normalized) raw scores. + fusionLinear +) + +// linearNormalize selects score normalization for fusionLinear. +type linearNormalize int + +const ( + // normalizeMax divides each channel's scores by that channel's maximum absolute + // score, bringing heterogeneous scales onto a comparable [-1,1]-ish range. + normalizeMax linearNormalize = iota + // normalizeNone uses raw scores as-is (the caller asserts comparability). + normalizeNone +) + +// defaultRRFK is the conventional RRF rank constant. Larger k flattens the +// contribution of top ranks; 60 is the widely used default. +const defaultRRFK = 60.0 + +// fuseChannel is one scored input to fusion. scores maps uid -> raw score with the +// convention that higher is always better (all Dgraph rankers surface +// higher-is-better scores). weight applies only to fusionLinear. +type fuseChannel struct { + scores map[uint64]float64 + weight float64 +} + +// fuseOpts configures a fusion run. +type fuseOpts struct { + method fusionMethod + k float64 // RRF rank constant; <=0 falls back to defaultRRFK. + normalize linearNormalize // linear only. + topk int // if >0, truncate output to the top topk results. +} + +// scoredUid is a uid paired with its fused score. +type scoredUid struct { + uid uint64 + score float64 +} + +// fuseChannels combines channels into a single ranked list, sorted by fused score +// descending and tie-broken by uid ascending. The candidate set is the UNION of all +// channels' uids (outer join): a uid missing from a channel simply receives no +// contribution from it (RRF: as if rank = infinity; linear: as if score = 0). It is +// never dropped and never produces NaN. +func fuseChannels(channels []fuseChannel, opts fuseOpts) []scoredUid { + var fused map[uint64]float64 + switch opts.method { + case fusionLinear: + fused = fuseLinear(channels, opts.normalize) + default: + fused = fuseRRF(channels, opts.k) + } + + out := make([]scoredUid, 0, len(fused)) + for uid, s := range fused { + out = append(out, scoredUid{uid: uid, score: s}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].score != out[j].score { + return out[i].score > out[j].score + } + return out[i].uid < out[j].uid + }) + if opts.topk > 0 && len(out) > opts.topk { + out = out[:opts.topk] + } + return out +} + +// channelRanks returns the 1-based rank of every uid in a channel, computed by +// sorting on score descending and tie-breaking by uid ascending. The tie-break +// matches the deterministic ordering used elsewhere in the codebase (bm25/HNSW +// sorted()), so equal-scored uids rank stably by uid. +func channelRanks(c fuseChannel) map[uint64]int { + uids := make([]uint64, 0, len(c.scores)) + for uid := range c.scores { + uids = append(uids, uid) + } + sort.Slice(uids, func(i, j int) bool { + si, sj := c.scores[uids[i]], c.scores[uids[j]] + if si != sj { + return si > sj + } + return uids[i] < uids[j] + }) + ranks := make(map[uint64]int, len(uids)) + for i, uid := range uids { + ranks[uid] = i + 1 + } + return ranks +} + +// fuseRRF computes Reciprocal Rank Fusion over the channels. +func fuseRRF(channels []fuseChannel, k float64) map[uint64]float64 { + if k <= 0 || math.IsNaN(k) || math.IsInf(k, 0) { + k = defaultRRFK + } + fused := make(map[uint64]float64) + for _, c := range channels { + for uid, rank := range channelRanks(c) { + fused[uid] += 1.0 / (k + float64(rank)) + } + } + return fused +} + +// channelMaxAbs returns the maximum finite absolute score in a channel, used to +// max-normalize heterogeneous score scales. Non-finite scores are ignored; +// returns 0 for an empty channel (callers treat a 0 denominator as "contribute 0"). +func channelMaxAbs(c fuseChannel) float64 { + var maxAbs float64 + for _, s := range c.scores { + if math.IsNaN(s) || math.IsInf(s, 0) { + continue + } + if a := math.Abs(s); a > maxAbs { + maxAbs = a + } + } + return maxAbs +} + +// fuseLinear computes a weighted sum of (optionally max-normalized) raw scores. A +// uid missing from a channel contributes 0 from that channel. A channel whose +// maximum absolute score is 0 (all zeros / empty) contributes 0 rather than +// dividing by zero. +func fuseLinear(channels []fuseChannel, normalize linearNormalize) map[uint64]float64 { + denoms := make([]float64, len(channels)) + for i, c := range channels { + if normalize == normalizeMax { + denoms[i] = channelMaxAbs(c) + } else { + denoms[i] = 1.0 + } + } + + fused := make(map[uint64]float64) + for i, c := range channels { + denom := denoms[i] + for uid, s := range c.scores { + norm := s + if normalize == normalizeMax { + if denom == 0 { + norm = 0 + } else { + norm = s / denom + } + } + fused[uid] += c.weight * norm + } + } + // Ensure uids that appear only in zero-contribution channels are still present + // in the union (e.g. an all-zero max-normalized channel). The loop above already + // inserts them with a running sum (possibly 0), so the union is complete. + return fused +} + +// --- Query-layer adapter ----------------------------------------------------- +// +// The functions below bridge the pure fusion core to DQL value variables. They +// parse the fuse() options, read each channel's scores from the already-resolved +// variable map, run fusion, and return a varValue carrying both the union uid set +// and the fused uid->score map (the same shape the bm25 ranker binds). + +// parseFuseOpts extracts fuse() options from the function's key/value arg pairs and +// returns the resolved fuseOpts plus the optional per-channel linear weights +// (nil when unspecified). numChannels is used to validate the weights count. +func parseFuseOpts(args []dql.Arg, numChannels int) (fuseOpts, []float64, error) { + opts := fuseOpts{method: fusionRRF, k: defaultRRFK, normalize: normalizeMax} + var weights []float64 + + for i := 0; i+1 < len(args); i += 2 { + key := strings.ToLower(args[i].Value) + val := args[i+1].Value + switch key { + case "method": + switch strings.ToLower(val) { + case "rrf": + opts.method = fusionRRF + case "linear": + opts.method = fusionLinear + default: + return opts, nil, errors.Errorf("fuse: unknown method %q (want rrf or linear)", val) + } + case "k": + k, err := strconv.ParseFloat(val, 64) + if err != nil || k <= 0 || math.IsNaN(k) || math.IsInf(k, 0) { + return opts, nil, errors.Errorf("fuse: k must be a positive finite number, got %q", val) + } + opts.k = k + case "normalize": + switch strings.ToLower(val) { + case "max": + opts.normalize = normalizeMax + case "none": + opts.normalize = normalizeNone + default: + return opts, nil, errors.Errorf("fuse: unknown normalize %q (want max or none)", val) + } + case "topk": + tk, err := strconv.Atoi(val) + if err != nil || tk < 0 { + return opts, nil, errors.Errorf("fuse: topk must be a non-negative integer, got %q", val) + } + opts.topk = tk + case "weights": + parts := strings.Split(val, ",") + weights = make([]float64, 0, len(parts)) + for _, p := range parts { + w, err := strconv.ParseFloat(strings.TrimSpace(p), 64) + if err != nil || math.IsNaN(w) || math.IsInf(w, 0) { + return opts, nil, errors.Errorf("fuse: invalid weight %q", p) + } + weights = append(weights, w) + } + default: + return opts, nil, errors.Errorf("fuse: unknown option %q", key) + } + } + + if weights != nil && len(weights) != numChannels { + return opts, nil, errors.Errorf("fuse: weights count (%d) must match channel count (%d)", + len(weights), numChannels) + } + return opts, weights, nil +} + +// computeFuse reads the channel value variables named in the fuse function's +// NeedsVar from doneVars, runs fusion, and returns a varValue with the union uid +// set and the fused uid->score map. +func computeFuse(args []dql.Arg, needsVar []dql.VarContext, + doneVars map[string]varValue, sgPath []*SubGraph) (varValue, error) { + + if len(needsVar) == 0 { + return varValue{}, errors.Errorf("fuse: requires at least one value variable channel") + } + + opts, weights, err := parseFuseOpts(args, len(needsVar)) + if err != nil { + return varValue{}, err + } + + channels := make([]fuseChannel, len(needsVar)) + for i, nv := range needsVar { + v, ok := doneVars[nv.Name] + if !ok || v.Vals == nil || v.Vals.IsEmpty() { + // A channel that produced no scored results contributes nothing but is + // still a valid (empty) channel. + channels[i] = fuseChannel{scores: map[uint64]float64{}, weight: 1.0} + } else { + scores, err := scoresFromVar(v, nv.Name) + if err != nil { + return varValue{}, err + } + channels[i] = fuseChannel{scores: scores, weight: 1.0} + } + if weights != nil { + channels[i].weight = weights[i] + } + } + + fused := fuseChannels(channels, opts) + + out := varValue{Vals: types.NewShardedMap(), path: sgPath} + uids := make([]uint64, len(fused)) + for i, r := range fused { + uids[i] = r.uid + out.Vals.Set(r.uid, types.Val{Tid: types.FloatID, Value: r.score}) + } + // Emit the uid set in ascending order — the Dgraph value-variable contract (the + // same one bm25 follows): Uids is the unordered candidate set used by uid(var), + // and the fused score lives in Vals. Callers recover ranked order with + // `orderdesc: val(var)`. When `topk` is set, fuseChannels has already selected + // the top-k by fused score before this ascending sort, so top-k + orderdesc is + // correct. (Sorting here by score would break uid(var) set semantics.) + sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) + out.Uids = &pb.List{Uids: uids} + return out, nil +} + +// scoresFromVar extracts a uid->float64 score map from a value variable. The +// variable must carry numeric scores (as bm25/similar_to bind); a uid variable +// without scores cannot be a fusion channel. +func scoresFromVar(v varValue, name string) (map[uint64]float64, error) { + scores := make(map[uint64]float64, v.Vals.Len()) + var convErr error + err := v.Vals.Iterate(func(uid uint64, val types.Val) error { + var s float64 + // bm25 and similar_to bind FloatID scores directly; take that fast path so a + // non-finite value (which types.Convert rejects) is dropped rather than + // mis-reported as "non-numeric". Other numeric types go through Convert. + if val.Tid == types.FloatID { + f, ok := val.Value.(float64) + if !ok { + convErr = errors.Errorf("fuse: channel %q has a malformed score value", name) + return convErr + } + s = f + } else { + f, err := types.Convert(val, types.FloatID) + if err != nil { + convErr = errors.Errorf("fuse: channel %q has non-numeric scores; use a "+ + "ranker such as bm25 or similar_to", name) + return convErr + } + s = f.Value.(float64) + } + // Drop non-finite scores so they can never poison fusion: a NaN/Inf would + // break the sort comparator's strict-weak-ordering and propagate through + // linear sums. A uid dropped here simply doesn't participate via this channel. + if math.IsNaN(s) || math.IsInf(s, 0) { + return nil + } + scores[uid] = s + return nil + }) + if convErr != nil { + return nil, convErr + } + if err != nil { + return nil, err + } + return scores, nil +} diff --git a/query/fuse_test.go b/query/fuse_test.go new file mode 100644 index 00000000000..69529c8beae --- /dev/null +++ b/query/fuse_test.go @@ -0,0 +1,208 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package query + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/types" +) + +// ch is a small helper to build a fusion channel from a uid->score map with a +// default weight of 1.0. +func ch(scores map[uint64]float64) fuseChannel { + return fuseChannel{scores: scores, weight: 1.0} +} + +// asMap collapses a fused result slice into a uid->score map for assertions that +// don't care about ordering. +func asMap(res []scoredUid) map[uint64]float64 { + m := make(map[uint64]float64, len(res)) + for _, r := range res { + m[r.uid] = r.score + } + return m +} + +func TestFuseRRF_BasicRanks(t *testing.T) { + // Channel A order: 10, 20, 30 (ranks 1,2,3) + // Channel B order: 30, 10, 40 (ranks 1,2,3) + a := ch(map[uint64]float64{10: 9.0, 20: 5.0, 30: 1.0}) + b := ch(map[uint64]float64{30: 0.9, 10: 0.5, 40: 0.1}) + + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + + const k = 60.0 + // uid 10: rank1 in A, rank2 in B + require.InDelta(t, 1/(k+1)+1/(k+2), got[10], 1e-9) + // uid 20: rank2 in A only + require.InDelta(t, 1/(k+2), got[20], 1e-9) + // uid 30: rank3 in A, rank1 in B + require.InDelta(t, 1/(k+3)+1/(k+1), got[30], 1e-9) + // uid 40: rank3 in B only + require.InDelta(t, 1/(k+3), got[40], 1e-9) +} + +func TestFuseRRF_OrderingAndUnion(t *testing.T) { + a := ch(map[uint64]float64{10: 9.0, 20: 5.0, 30: 1.0}) + b := ch(map[uint64]float64{30: 0.9, 10: 0.5, 40: 0.1}) + + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + + // Union of all uids is present (outer join, not intersection). + require.Len(t, res, 4) + // Sorted by fused score descending. uid 10 and 30 both appear in both channels + // near the top; 10 is rank1+rank2, 30 is rank3+rank1 -> 10 slightly higher. + require.Equal(t, uint64(10), res[0].uid) + require.Equal(t, uint64(30), res[1].uid) + // Scores must be monotonically non-increasing. + for i := 1; i < len(res); i++ { + require.LessOrEqual(t, res[i].score, res[i-1].score) + } +} + +func TestFuseRRF_DefaultK(t *testing.T) { + a := ch(map[uint64]float64{1: 1.0}) + // k<=0 should fall back to the default of 60. + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 0}) + require.InDelta(t, 1/(60.0+1), res[0].score, 1e-9) +} + +func TestFuseRRF_TieBreakByUidAscending(t *testing.T) { + // Equal scores within a channel -> lower uid gets the better (smaller) rank. + a := ch(map[uint64]float64{2: 5.0, 1: 5.0, 3: 5.0}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + // uid 1 rank1, uid 2 rank2, uid 3 rank3. + require.InDelta(t, 1/(60.0+1), got[1], 1e-9) + require.InDelta(t, 1/(60.0+2), got[2], 1e-9) + require.InDelta(t, 1/(60.0+3), got[3], 1e-9) + // Final output tie-broken by uid ascending when fused scores are equal. + require.Equal(t, uint64(1), res[0].uid) +} + +func TestFuseRRF_DisjointChannels(t *testing.T) { + a := ch(map[uint64]float64{1: 9.0, 2: 8.0}) + b := ch(map[uint64]float64{3: 9.0, 4: 8.0}) + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + require.Len(t, res, 4) + got := asMap(res) + // Each uid scored only by its single channel rank. + require.InDelta(t, 1/(60.0+1), got[1], 1e-9) + require.InDelta(t, 1/(60.0+1), got[3], 1e-9) + require.InDelta(t, 1/(60.0+2), got[2], 1e-9) + require.InDelta(t, 1/(60.0+2), got[4], 1e-9) +} + +func TestFuseLinear_MaxNormalizeAndWeights(t *testing.T) { + // BM25-ish scale vs cosine-ish scale. + text := fuseChannel{scores: map[uint64]float64{1: 10.0, 2: 5.0}, weight: 0.3} + vec := fuseChannel{scores: map[uint64]float64{1: 0.8, 2: 0.4, 3: 0.2}, weight: 0.7} + + res := fuseChannels([]fuseChannel{text, vec}, + fuseOpts{method: fusionLinear, normalize: normalizeMax}) + got := asMap(res) + + // max-normalize: text/10, vec/0.8. + // uid1: 0.3*(10/10) + 0.7*(0.8/0.8) = 0.3 + 0.7 = 1.0 + require.InDelta(t, 1.0, got[1], 1e-9) + // uid2: 0.3*(5/10) + 0.7*(0.4/0.8) = 0.15 + 0.35 = 0.5 + require.InDelta(t, 0.5, got[2], 1e-9) + // uid3: only vec: 0.7*(0.2/0.8) = 0.175 (text contributes 0, not NaN) + require.InDelta(t, 0.175, got[3], 1e-9) + + require.Equal(t, uint64(1), res[0].uid) +} + +func TestFuseLinear_NoNormalize(t *testing.T) { + a := fuseChannel{scores: map[uint64]float64{1: 2.0, 2: 1.0}, weight: 1.0} + b := fuseChannel{scores: map[uint64]float64{1: 3.0}, weight: 2.0} + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionLinear, normalize: normalizeNone}) + got := asMap(res) + // uid1: 1*2 + 2*3 = 8 ; uid2: 1*1 = 1 + require.InDelta(t, 8.0, got[1], 1e-9) + require.InDelta(t, 1.0, got[2], 1e-9) +} + +func TestFuseLinear_ZeroMaxChannelContributesZero(t *testing.T) { + // A channel whose scores are all zero must not divide-by-zero / NaN. + a := fuseChannel{scores: map[uint64]float64{1: 0.0, 2: 0.0}, weight: 1.0} + b := fuseChannel{scores: map[uint64]float64{1: 4.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionLinear, normalize: normalizeMax}) + got := asMap(res) + require.False(t, math.IsNaN(got[1])) + require.False(t, math.IsNaN(got[2])) + // uid1: a contributes 0, b contributes 4/4=1 -> 1.0 + require.InDelta(t, 1.0, got[1], 1e-9) + // uid2: only in a (all-zero) -> 0.0 + require.InDelta(t, 0.0, got[2], 1e-9) +} + +func TestFuse_TopKTruncation(t *testing.T) { + a := ch(map[uint64]float64{1: 9, 2: 8, 3: 7, 4: 6, 5: 5}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60, topk: 3}) + require.Len(t, res, 3) + require.Equal(t, uint64(1), res[0].uid) + require.Equal(t, uint64(3), res[2].uid) +} + +func TestFuse_SingleChannelPassthroughOrder(t *testing.T) { + a := ch(map[uint64]float64{1: 1, 2: 9, 3: 5}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + // Order should reflect channel ranking: 2 (rank1), 3 (rank2), 1 (rank3). + require.Equal(t, []uint64{2, 3, 1}, []uint64{res[0].uid, res[1].uid, res[2].uid}) +} + +func TestFuse_EmptyChannels(t *testing.T) { + res := fuseChannels([]fuseChannel{ch(nil), ch(map[uint64]float64{})}, + fuseOpts{method: fusionRRF, k: 60}) + require.Empty(t, res) +} + +func TestScoresFromVar_DropsNonFinite(t *testing.T) { + // Non-finite scores from a channel must be dropped so they can't break the sort + // comparator or poison linear sums. + m := types.NewShardedMap() + m.Set(1, types.Val{Tid: types.FloatID, Value: 0.5}) + m.Set(2, types.Val{Tid: types.FloatID, Value: math.NaN()}) + m.Set(3, types.Val{Tid: types.FloatID, Value: math.Inf(1)}) + m.Set(4, types.Val{Tid: types.FloatID, Value: math.Inf(-1)}) + m.Set(5, types.Val{Tid: types.FloatID, Value: 2.0}) + + scores, err := scoresFromVar(varValue{Vals: m}, "ch") + require.NoError(t, err) + require.Len(t, scores, 2, "only finite scores should survive") + require.Contains(t, scores, uint64(1)) + require.Contains(t, scores, uint64(5)) + require.NotContains(t, scores, uint64(2)) + require.NotContains(t, scores, uint64(3)) + require.NotContains(t, scores, uint64(4)) +} + +func TestFuseLinear_NonFiniteChannelDoesNotPoison(t *testing.T) { + // Even if a NaN slips into a channel passed directly to the core, max-normalize + // must not produce a NaN denominator that propagates. + bad := fuseChannel{scores: map[uint64]float64{1: math.NaN(), 2: math.NaN()}, weight: 1.0} + good := fuseChannel{scores: map[uint64]float64{1: 4.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{bad, good}, fuseOpts{method: fusionLinear, normalize: normalizeMax}) + for _, r := range res { + require.False(t, math.IsNaN(r.score), "uid %d score must not be NaN", r.uid) + } +} + +func TestFuse_Determinism(t *testing.T) { + a := ch(map[uint64]float64{1: 5, 2: 5, 3: 5}) + b := ch(map[uint64]float64{3: 1, 2: 1, 1: 1}) + first := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + for i := 0; i < 20; i++ { + again := fuseChannels([]fuseChannel{a, b}, fuseOpts{method: fusionRRF, k: 60}) + require.Equal(t, first, again) + } +} diff --git a/query/query.go b/query/query.go index 80ca66b185d..35f005ed451 100644 --- a/query/query.go +++ b/query/query.go @@ -1570,6 +1570,17 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su var ok bool switch { + case sg.SrcFunc != nil && sg.SrcFunc.Name == "fuse": + // Native hybrid search: fuse() combines several already-scored value + // variables (its NeedsVar channels) into one ranked value variable. Fusion is + // a coordinator-side operation over resolved variables — the channel blocks + // have already populated doneVars by the time this block is scheduled. We bind + // both the union uid set (uid(var)) and the uid->fused-score map (val(var)). + fv, err := computeFuse(sg.SrcFunc.Args, sg.Params.NeedsVar, doneVars, sgPath) + if err != nil { + return err + } + doneVars[sg.Params.Var] = fv case len(sg.counts) > 0: // 1. When count of a predicate is assigned a variable, we store the mapping of uid => // count(predicate). @@ -1605,14 +1616,15 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su Value: int64(len(sg.SrcUIDs.Uids)), } doneVars[sg.Params.Var].Vals.Set(math.MaxUint64, val) - case sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(sg.uidMatrix) > 0 && - len(sg.valueMatrix) > 0: - // A query-side ranker (BM25) binds its per-document relevance score as a - // value variable. We populate BOTH the matched uid set and the uid->score - // map so the variable works with uid(var), val(var) and orderdesc: val(var) - // — surfacing and ordering by score without a pseudo-predicate or a - // ParentVars channel. The valueMatrix is positionally aligned with the - // function's returned uidMatrix[0]. + case sg.SrcFunc != nil && (sg.SrcFunc.Name == "bm25" || sg.SrcFunc.Name == "similar_to") && + len(sg.uidMatrix) > 0 && len(sg.valueMatrix) > 0: + // A query-side ranker (BM25 relevance or vector similarity) binds its + // per-document score as a value variable. We populate BOTH the matched uid set + // and the uid->score map so the variable works with uid(var), val(var) and + // orderdesc: val(var) — surfacing and ordering by score without a + // pseudo-predicate or a ParentVars channel. The valueMatrix is positionally + // aligned with the function's returned uidMatrix[0]. For similar_to the score + // is a higher-is-better similarity; this also lets vector results feed fuse(). if v, ok = doneVars[sg.Params.Var]; !ok { v = varValue{Vals: types.NewShardedMap(), path: sgPath, strList: sg.valueMatrix} } @@ -3003,6 +3015,15 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { continue } + // fuse() is a coordinator-side fusion over already-resolved value + // variables; it is never dispatched to a worker. The fused variable is + // produced from doneVars in populateVarMap (the fuse case in + // populateUidValVar) once its channel variables are populated. + if sg.SrcFunc != nil && sg.SrcFunc.Name == "fuse" { + errChan <- nil + continue + } + switch { case sg.Params.Alias == "shortest": // We allow only one shortest path block per query. diff --git a/query/query_hybrid_test.go b/query/query_hybrid_test.go new file mode 100644 index 00000000000..7443fa3791b --- /dev/null +++ b/query/query_hybrid_test.go @@ -0,0 +1,276 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +// hybridResult unmarshals a {uid, val(f)} result list ordered by fused score. +type hybridRow struct { + UID string `json:"uid"` + Score float64 `json:"val(f)"` +} + +func fuseRows(t *testing.T, js string) []hybridRow { + t.Helper() + var resp struct { + Data struct { + Me []hybridRow `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + return resp.Data.Me +} + +func fuseUIDSet(rows []hybridRow) map[string]bool { + set := make(map[string]bool, len(rows)) + for _, r := range rows { + set[r.UID] = true + } + return set +} + +// --- similar_to score surfacing (prerequisite for vector fusion) ------------- + +func TestSimilarToScoreVariable(t *testing.T) { + // similar_to bound to a value variable surfaces a higher-is-better similarity + // score. The query vector equals doc 503's embedding, so 503 is the closest + // (euclidean distance 0 -> score 1.0) and must rank first. + query := ` + { + s as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + me(func: uid(s), orderdesc: val(s)) { + uid + val(s) + } + }` + js := processQueryNoErr(t, query) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(s)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me) + + uid503 := uidHex(t, 503) + require.Equal(t, uid503, resp.Data.Me[0].UID, "closest vector (503) should rank first") + require.InDelta(t, 1.0, resp.Data.Me[0].Score, 1e-6, "exact match should score 1/(1+0)=1.0") + + // Scores must be in descending order. + for i := 1; i < len(resp.Data.Me); i++ { + require.GreaterOrEqual(t, resp.Data.Me[i-1].Score, resp.Data.Me[i].Score) + } +} + +// --- fuse() over BM25 channels ---------------------------------------------- + +func TestFuseRRFTwoBM25Channels(t *testing.T) { + // Fuse two BM25 channels: "fox" (matches 501,502,503,506,507) and "dog" + // (matches 501,502,504,505). The union must include dog-only docs (504,505) + // that the fox channel never returns — proving outer-join (not intersection). + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + set := fuseUIDSet(rows) + + // Union contains both fox-only (503) and dog-only (504, 505) documents. + require.True(t, set[uidHex(t, 503)], "fox-only doc 503 must be present") + require.True(t, set[uidHex(t, 504)], "dog-only doc 504 must be present (union, not intersection)") + require.True(t, set[uidHex(t, 505)], "dog-only doc 505 must be present") + + // Doc 501/502 appear in BOTH channels, so their fused RRF score should exceed a + // document that appears in only one channel at the same rank. + scores := make(map[string]float64) + for _, r := range rows { + scores[r.UID] = r.Score + } + require.Greater(t, scores[uidHex(t, 501)], 0.0) + + // Scores descending. + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +func TestFuseLinearWeights(t *testing.T) { + // Linear fusion with weights. Heavily weight the "dog" channel; a dog-only doc + // should then be present and outscore where appropriate. + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "linear", weights: "0.1,0.9", normalize: "max")) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + set := fuseUIDSet(rows) + require.True(t, set[uidHex(t, 504)], "dog-only doc must appear with linear fusion") + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +func TestFuseSingleChannelPassthrough(t *testing.T) { + // A single-channel fuse should preserve that channel's ranking. "fox fox fox" + // (503) is the top BM25 result for "fox". + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, method: "rrf")) + me(func: uid(f), orderdesc: val(f), first: 1) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.Len(t, rows, 1) + require.Equal(t, uidHex(t, 503), rows[0].UID) +} + +func TestFusePagination(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "rrf")) + me(func: uid(f), orderdesc: val(f), first: 2, offset: 1) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.Len(t, rows, 2, "first:2 offset:1 should return exactly 2 rows") +} + +// --- fuse() hybrid: BM25 + vector ------------------------------------------- + +func TestFuseHybridBM25AndVector(t *testing.T) { + // The headline use case: fuse BM25 text relevance with vector similarity. + // Doc 503 ("fox fox fox", embedding [3,0,0,0]) is rank-1 in BOTH the "fox" BM25 + // channel and the [3,0,0,0] vector channel, so it must be the top fused result. + // The vector channel (k=7) returns all docs, so dog-only docs (504,505) enter + // the union even though the BM25 "fox" channel never returns them. + query := ` + { + txt as var(func: bm25(description_bm25, "fox")) + vec as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + f as var(func: fuse(txt, vec, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + rows := fuseRows(t, processQueryNoErr(t, query)) + require.NotEmpty(t, rows) + require.Equal(t, uidHex(t, 503), rows[0].UID, "503 is rank-1 in both channels -> top fused") + + set := fuseUIDSet(rows) + require.True(t, set[uidHex(t, 504)], "vector-only doc 504 must enter the union") + require.True(t, set[uidHex(t, 505)], "vector-only doc 505 must enter the union") + + for i := 1; i < len(rows); i++ { + require.GreaterOrEqual(t, rows[i-1].Score, rows[i].Score) + } +} + +// --- hybrid() sugar ---------------------------------------------------------- + +func TestHybridSugarEquivalentToFuse(t *testing.T) { + // hybrid() must produce the same top result as the explicit fuse() form. + hybridQ := ` + { + f as var(func: hybrid(description_bm25, "fox", description_vec, "[3.0, 0.0, 0.0, 0.0]", topk: 7, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + explicitQ := ` + { + txt as var(func: bm25(description_bm25, "fox")) + vec as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) + f as var(func: fuse(txt, vec, method: "rrf", k: 60)) + me(func: uid(f), orderdesc: val(f)) { + uid + val(f) + } + }` + hybridRows := fuseRows(t, processQueryNoErr(t, hybridQ)) + explicitRows := fuseRows(t, processQueryNoErr(t, explicitQ)) + + require.NotEmpty(t, hybridRows) + require.Equal(t, len(explicitRows), len(hybridRows), "hybrid and fuse must return the same set") + require.Equal(t, explicitRows[0].UID, hybridRows[0].UID, "same top result") + // Fused scores should match position-for-position. + for i := range explicitRows { + require.Equal(t, explicitRows[i].UID, hybridRows[i].UID, "row %d uid mismatch", i) + require.InDelta(t, explicitRows[i].Score, hybridRows[i].Score, 1e-9, "row %d score mismatch", i) + } +} + +// --- error handling ---------------------------------------------------------- + +func TestFuseUnknownMethod(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, method: "bogus")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "method") +} + +func TestFuseWeightsCountMismatch(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + dog as var(func: bm25(description_bm25, "dog")) + f as var(func: fuse(fox, dog, method: "linear", weights: "0.5")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "weights") +} + +func TestFuseBadK(t *testing.T) { + query := ` + { + fox as var(func: bm25(description_bm25, "fox")) + f as var(func: fuse(fox, k: "-5")) + me(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "k must be") +} diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 39d72d8f5e7..890ca170480 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -214,6 +214,24 @@ type SimilarityType[T c.Float] struct { isSimilarityMetric bool } +// similarityScore converts a heap element's metric-domain value into a +// higher-is-better similarity score suitable for ranking and score fusion. +// +// - Cosine / dot product (isSimilarityMetric): the value is already a similarity +// where higher is better, so it is returned as-is. +// - Euclidean: the value is a squared L2 distance where lower is better. It is +// mapped to 1/(1+d) in (0,1], which is monotonically decreasing in distance, so +// higher is better and the result is well-behaved under linear normalization. +// +// Keeping this orientation in one place lets every caller (and hybrid-search +// fusion) treat vector scores with the same higher-is-better convention as BM25. +func (s SimilarityType[T]) similarityScore(value T) float64 { + if s.isSimilarityMetric { + return float64(value) + } + return 1.0 / (1.0 + float64(value)) +} + func GetSimType[T c.Float](indexType string, floatBits int) SimilarityType[T] { switch { case indexType == Euclidean: diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 5658800e579..e7b8561e24e 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -266,6 +266,20 @@ func (ph *persistentHNSW[T]) SearchWithOptions( maxResults int, opts index.VectorIndexOptions[T], ) ([]uint64, error) { + uids, _, err := ph.SearchWithOptionsScored(ctx, c, query, maxResults, opts) + return uids, err +} + +// SearchWithOptionsScored is SearchWithOptions that also returns a higher-is-better +// similarity score for each returned uid (positionally aligned). See +// index.ScoredSearchOptions. +func (ph *persistentHNSW[T]) SearchWithOptionsScored( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, []float64, error) { if opts.Filter == nil { opts.Filter = index.AcceptAll[T] } @@ -279,7 +293,7 @@ func (ph *persistentHNSW[T]) SearchWithOptions( var startVec []T entry, err := ph.PickStartNode(ctx, c, &startVec) if err != nil { - return nil, err + return nil, nil, err } // Upper layers use efUpper (override if provided) @@ -296,13 +310,13 @@ func (ph *persistentHNSW[T]) SearchWithOptions( layerResult, err := ph.searchPersistentLayer( c, level, entry, startVec, query, filterOut, efUpper, opts.Filter) if err != nil { - return nil, err + return nil, nil, err } layerResult.updateFinalMetrics(r) entry = layerResult.bestNeighbor().index layerResult.updateFinalPath(r) if err = ph.getVecFromUid(entry, c, &startVec); err != nil { - return nil, err + return nil, nil, err } } @@ -315,13 +329,14 @@ func (ph *persistentHNSW[T]) SearchWithOptions( layerResult, err := ph.searchPersistentLayer( c, ph.maxLevels-1, entry, startVec, query, filterOut, candidateK, opts.Filter) if err != nil { - return nil, err + return nil, nil, err } layerResult.updateFinalMetrics(r) layerResult.updateFinalPath(r) // Build final neighbor list with optional threshold, limited to maxResults. res := make([]uint64, 0, maxResults) + scores := make([]float64, 0, maxResults) for _, n := range layerResult.neighbors { if maxResults == 0 { break @@ -347,23 +362,38 @@ func (ph *persistentHNSW[T]) SearchWithOptions( } } res = append(res, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) if len(res) >= maxResults { break } } r.Metrics[searchTime] = uint64(time.Now().UnixMilli() - start) - return res, nil + return res, scores, nil } // SearchWithUidAndOptions is analogous to SearchWithUid but applies per‑call options. func (ph *persistentHNSW[T]) SearchWithUidAndOptions( - _ context.Context, + ctx context.Context, c index.CacheType, queryUid uint64, maxResults int, opts index.VectorIndexOptions[T], ) ([]uint64, error) { + uids, _, err := ph.SearchWithUidAndOptionsScored(ctx, c, queryUid, maxResults, opts) + return uids, err +} + +// SearchWithUidAndOptionsScored is SearchWithUidAndOptions that also returns a +// higher-is-better similarity score for each returned uid (positionally aligned). +// See index.ScoredSearchOptions. +func (ph *persistentHNSW[T]) SearchWithUidAndOptionsScored( + _ context.Context, + c index.CacheType, + queryUid uint64, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, []float64, error) { if opts.Filter == nil { opts.Filter = index.AcceptAll[T] } @@ -373,12 +403,12 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( var queryVec []T if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { if errors.Is(err, errFetchingPostingList) { - return []uint64{}, nil + return []uint64{}, []float64{}, nil } - return []uint64{}, err + return []uint64{}, []float64{}, err } if len(queryVec) == 0 { - return []uint64{}, nil + return []uint64{}, []float64{}, nil } filterOut := !opts.Filter(queryVec, queryVec, queryUid) candidateK := maxResults @@ -388,9 +418,10 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( lr, err := ph.searchPersistentLayer( c, ph.maxLevels-1, queryUid, queryVec, queryVec, filterOut, candidateK, opts.Filter) if err != nil { - return []uint64{}, err + return []uint64{}, []float64{}, err } res := make([]uint64, 0, maxResults) + scores := make([]float64, 0, maxResults) for _, n := range lr.neighbors { if maxResults == 0 { break @@ -413,11 +444,12 @@ func (ph *persistentHNSW[T]) SearchWithUidAndOptions( } } res = append(res, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) if len(res) >= maxResults { break } } - return res, nil + return res, scores, nil } // SearchWithUid searches the HNSW graph for the nearest neighbors of the query UID @@ -548,13 +580,56 @@ func (ph *persistentHNSW[T]) SearchWithPath( } layerResult.updateFinalMetrics(r) layerResult.updateFinalPath(r) - layerResult.addFinalNeighbors(r) + layerResult.addFinalNeighbors(r, ph.simType) t := time.Now().UnixMilli() elapsed := t - start r.Metrics[searchTime] = uint64(elapsed) return r, nil } +// SearchScored is Search that also returns a higher-is-better similarity score for +// each returned uid (positionally aligned with the neighbor uids). It preserves the +// exact candidate-exploration behavior of Search (unlike the options-based path), +// so scoring an otherwise plain query does not change which neighbors are returned. +func (ph *persistentHNSW[T]) SearchScored(ctx context.Context, c index.CacheType, query []T, + maxResults int, filter index.SearchFilter[T]) ([]uint64, []float64, error) { + r, err := ph.SearchWithPath(ctx, c, query, maxResults, filter) + if err != nil { + return nil, nil, err + } + return r.Neighbors, r.Distances, nil +} + +// SearchWithUidScored is SearchWithUid that also returns a higher-is-better +// similarity score for each returned uid (positionally aligned), preserving +// SearchWithUid's exact neighbor selection. +func (ph *persistentHNSW[T]) SearchWithUidScored(_ context.Context, c index.CacheType, + queryUid uint64, maxResults int, filter index.SearchFilter[T]) ([]uint64, []float64, error) { + var queryVec []T + if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { + if errors.Is(err, errFetchingPostingList) { + return []uint64{}, []float64{}, nil + } + return []uint64{}, []float64{}, err + } + if len(queryVec) == 0 { + return []uint64{}, []float64{}, nil + } + shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) + r, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, shouldFilterOutQueryVec, maxResults, filter) + if err != nil { + return []uint64{}, []float64{}, err + } + uids := make([]uint64, 0, len(r.neighbors)) + scores := make([]float64, 0, len(r.neighbors)) + for _, n := range r.neighbors { + uids = append(uids, n.index) + scores = append(scores, ph.simType.similarityScore(n.value)) + } + return uids, scores, nil +} + // InsertToPersistentStorage inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, diff --git a/tok/hnsw/search_layer.go b/tok/hnsw/search_layer.go index 55140e7319a..81c529bd29b 100644 --- a/tok/hnsw/search_layer.go +++ b/tok/hnsw/search_layer.go @@ -113,10 +113,13 @@ func (slr *searchLayerResult[T]) updateFinalPath(r *index.SearchPathResult) { r.Path = append(r.Path, slr.path...) } -func (slr *searchLayerResult[T]) addFinalNeighbors(r *index.SearchPathResult) { +func (slr *searchLayerResult[T]) addFinalNeighbors(r *index.SearchPathResult, simType SimilarityType[T]) { for _, n := range slr.neighbors { if !n.filteredOut { r.Neighbors = append(r.Neighbors, n.index) + // Distances carries the higher-is-better similarity for each neighbor, + // positionally aligned with Neighbors, so scored searches can surface it. + r.Distances = append(r.Distances, simType.similarityScore(n.value)) } } } diff --git a/tok/hnsw/similarity_score_test.go b/tok/hnsw/similarity_score_test.go new file mode 100644 index 00000000000..481c97371b7 --- /dev/null +++ b/tok/hnsw/similarity_score_test.go @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSimilarityScoreOrientation verifies that every metric surfaces a +// higher-is-better similarity score, which native hybrid-search fusion relies on. +func TestSimilarityScoreOrientation(t *testing.T) { + cosine := GetSimType[float32](Cosine, 32) + dot := GetSimType[float32](DotProd, 32) + euclid := GetSimType[float32](Euclidean, 32) + + // Cosine / dot product: returned as-is (already higher-is-better). + require.InDelta(t, 0.9, cosine.similarityScore(0.9), 1e-6) + require.InDelta(t, -0.2, cosine.similarityScore(-0.2), 1e-6) + require.InDelta(t, 12.5, dot.similarityScore(12.5), 1e-6) + + // Euclidean: squared distance mapped to 1/(1+d), so a smaller distance yields a + // larger score (closer = better). + near := euclid.similarityScore(0.0) // distance 0 -> perfect match + mid := euclid.similarityScore(1.0) + far := euclid.similarityScore(9.0) + require.InDelta(t, 1.0, near, 1e-6) + require.InDelta(t, 0.5, mid, 1e-6) + require.InDelta(t, 0.1, far, 1e-6) + require.Greater(t, near, mid) + require.Greater(t, mid, far) +} diff --git a/tok/index/index.go b/tok/index/index.go index 1e981ef189e..050ec85e97e 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -147,6 +147,29 @@ type OptionalSearchOptions[T c.Float] interface { maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) } +// ScoredSearchOptions extends search to also return a higher-is-better similarity +// score for each returned uid (positionally aligned). These power native hybrid +// search: the score is bound to a DQL value variable so vector results can be a +// fusion channel alongside BM25. Scores carry the same higher-is-better convention +// as BM25 (cosine/dot as-is; euclidean as 1/(1+dist)). +// +// SearchScored / SearchWithUidScored preserve the exact neighbor selection of +// Search / SearchWithUid (so scoring a plain query does not change its results), +// while the *Options* variants apply per-call ef/distance-threshold controls. +type ScoredSearchOptions[T c.Float] interface { + SearchScored(ctx context.Context, c CacheType, query []T, + maxResults int, filter SearchFilter[T]) ([]uint64, []float64, error) + + SearchWithUidScored(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, filter SearchFilter[T]) ([]uint64, []float64, error) + + SearchWithOptionsScored(ctx context.Context, c CacheType, query []T, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, []float64, error) + + SearchWithUidAndOptionsScored(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, []float64, error) +} + // A Txn is an interface representation of a persistent storage transaction, // where multiple operations are performed on a database type Txn interface { diff --git a/tok/index/search_path.go b/tok/index/search_path.go index 7e24b7d068c..5387b71ee7f 100644 --- a/tok/index/search_path.go +++ b/tok/index/search_path.go @@ -12,6 +12,10 @@ type SearchPathResult struct { // The collection of nearest neighbors in sorted order after filtering // out neighbors that fail any Filter criteria. Neighbors []uint64 + // Distances holds the higher-is-better similarity score for each entry in + // Neighbors (positionally aligned). It is populated by scored searches and may + // be empty for callers that only need the neighbor uids. + Distances []float64 // The path from the start of search to the closest neighbor vector. Path []uint64 // A collection of captured named counters that occurred for the @@ -24,6 +28,7 @@ type SearchPathResult struct { func NewSearchPathResult() *SearchPathResult { return &SearchPathResult{ Neighbors: []uint64{}, + Distances: []float64{}, Path: []uint64{}, Metrics: make(map[string]uint64), } diff --git a/worker/task.go b/worker/task.go index bcb41b903d2..24b5f4dbe17 100644 --- a/worker/task.go +++ b/worker/task.go @@ -385,6 +385,7 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er return err } var nnUids []uint64 + var nnScores []float64 // Build optional search options if provided filter := index.AcceptAll[float32] opts := index.VectorIndexOptions[float32]{Filter: filter} @@ -395,27 +396,63 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er opts.DistanceThreshold = srcFn.vsDistanceThreshold } hasOptions := opts.EfOverride > 0 || opts.DistanceThreshold != nil - if o, ok := indexer.(index.OptionalSearchOptions[float32]); ok && hasOptions { - if srcFn.vectorInfo != nil { - nnUids, err = o.SearchWithOptions(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) - } else { - nnUids, err = o.SearchWithUidAndOptions(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + // Use the scored search path so the per-uid similarity score can be bound to a + // value variable (powering native hybrid search / fuse()). The scored variants + // mirror their unscored counterparts exactly — SearchScored/SearchWithUidScored + // preserve the plain-query neighbor selection, and the *Options* variants are + // used only when the query supplies ef/distance-threshold — so adding scoring + // does not change which neighbors existing queries return. Indexes that don't + // implement scoring fall back to the unscored path (no scores surfaced). + if so, ok := indexer.(index.ScoredSearchOptions[float32]); ok { + switch { + case hasOptions && srcFn.vectorInfo != nil: + nnUids, nnScores, err = so.SearchWithOptionsScored(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) + case hasOptions: + nnUids, nnScores, err = so.SearchWithUidAndOptionsScored(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + case srcFn.vectorInfo != nil: + nnUids, nnScores, err = so.SearchScored(ctx, qc, srcFn.vectorInfo, int(numNeighbors), index.AcceptAll[float32]) + default: + nnUids, nnScores, err = so.SearchWithUidScored(ctx, qc, srcFn.vectorUid, int(numNeighbors), index.AcceptAll[float32]) } + } else if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) } else { - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) - } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) - } + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) } if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { return err } - sort.Slice(nnUids, func(i, j int) bool { return nnUids[i] < nnUids[j] }) - args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: nnUids}) + + // Emit uids ascending (required by the query pipeline) with positionally + // aligned similarity scores in ValueMatrix. The query layer binds these to a + // value variable so callers can order by and project the score via val(), and + // so vector results can serve as a fuse() channel. Scores are higher-is-better. + order := make([]int, len(nnUids)) + for i := range order { + order[i] = i + } + sort.Slice(order, func(i, j int) bool { return nnUids[order[i]] < nnUids[order[j]] }) + sortedUids := make([]uint64, len(nnUids)) + for i, idx := range order { + sortedUids[i] = nnUids[idx] + } + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: sortedUids}) + + if len(nnScores) == len(nnUids) && len(nnScores) > 0 { + scoreBuf := make([]byte, len(nnUids)*8) + scoreValues := make([]*pb.ValueList, len(nnUids)) + for i, idx := range order { + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(nnScores[idx])) + scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_FLOAT}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) + } return nil } From caef2171d65df7149a2fccfd145ea6e679318c7f Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 3 Jun 2026 23:23:03 -0400 Subject: [PATCH 18/19] fix(review): resolve consensus findings from deep review Auto-fixed issues flagged by a majority (>=3/5) of /judge runs across Claude, GPT5, and Gemini. Minority findings documented in the deep-review report. - hybrid(): bound the generated bm25 channel to topk (first:) so a broad text query no longer scores the entire corpus before fusion (5/5). - fuse(): apply per-channel weights under RRF too, not only linear, so a user passing weights with the default method no longer has them silently ignored; default weight 1.0 keeps standard RRF (4/5). - hybrid(): check the reserved __hybrid var prefix across nested blocks, not just top-level (4/5); reject malformed (odd) option lists instead of dropping them (3/5). - fuse(): distinguish a genuinely missing channel variable (internal invariant error) from a channel that ran but matched nothing (valid empty channel) (3/5). Added unit tests for weighted RRF and the hybrid bm25 bound; updated the hybrid/fuse equivalence test to mirror the bounded bm25 channel. Co-Authored-By: Claude Opus 4.8 (1M context) --- dql/fuse_parser_test.go | 28 ++++++++++++++++++++++++++++ dql/hybrid.go | 37 ++++++++++++++++++++++++++++++++----- query/fuse.go | 22 ++++++++++++++++------ query/fuse_test.go | 21 +++++++++++++++++++++ query/query_hybrid_test.go | 2 +- 5 files changed, 98 insertions(+), 12 deletions(-) diff --git a/dql/fuse_parser_test.go b/dql/fuse_parser_test.go index a445d1cc1d5..a1e223144ca 100644 --- a/dql/fuse_parser_test.go +++ b/dql/fuse_parser_test.go @@ -176,6 +176,34 @@ func TestParseHybrid_ExpandsToThreeBlocks(t *testing.T) { require.NotContains(t, args, "topk") } +func TestParseHybrid_BoundsBM25Channel(t *testing.T) { + // The generated bm25 channel must be bounded to topk so a broad text query does + // not score the whole corpus before fusion. + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]", topk: 25)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + for _, q := range res.Query { + if q.Func != nil && q.Func.Name == "bm25" { + require.Equal(t, "25", q.Args["first"], "bm25 channel should be capped at topk") + } + } +} + +func TestParseHybrid_MalformedOptions(t *testing.T) { + // A trailing option key without a value must be rejected, not silently dropped. + query := ` + { + f as var(func: hybrid(description, "fox", emb, "[0.1]", method)) + result(func: uid(f), orderdesc: val(f)) { uid } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) +} + func TestParseHybrid_DefaultTopK(t *testing.T) { query := ` { diff --git a/dql/hybrid.go b/dql/hybrid.go index 8506847a46b..15db5b742c7 100644 --- a/dql/hybrid.go +++ b/dql/hybrid.go @@ -18,7 +18,7 @@ import ( // // becomes // -// __hybrid0_bm25 as var(func: bm25(textPred, "query")) +// __hybrid0_bm25 as var(func: bm25(textPred, "query"), first: 100) // __hybrid0_vec as var(func: similar_to(vecPred, 100, $vec)) // x as var(func: fuse(__hybrid0_bm25, __hybrid0_vec, method: "rrf", k: 60)) // @@ -51,11 +51,12 @@ func rewriteHybridBlocks(res *Result) error { // Guard against the (extremely unlikely) case of a user variable colliding with // the synthetic channel names we generate, which would otherwise produce a - // confusing "defined multiple times" error the user can't act on. + // confusing "defined multiple times" error the user can't act on. Variables can + // be defined in nested blocks too, so check the whole query tree, not just roots. for _, qu := range res.Query { - if qu != nil && strings.HasPrefix(qu.Var, hybridVarPrefix) { + if v, ok := findReservedHybridVar(qu); ok { return fmt.Errorf("variable %q uses the reserved prefix %q (used internally by hybrid)", - qu.Var, hybridVarPrefix) + v, hybridVarPrefix) } } @@ -77,6 +78,23 @@ func rewriteHybridBlocks(res *Result) error { return nil } +// findReservedHybridVar walks a query block and its children for any variable using +// the reserved hybrid prefix, returning the first one found. +func findReservedHybridVar(qu *GraphQuery) (string, bool) { + if qu == nil { + return "", false + } + if strings.HasPrefix(qu.Var, hybridVarPrefix) { + return qu.Var, true + } + for _, ch := range qu.Children { + if v, ok := findReservedHybridVar(ch); ok { + return v, true + } + } + return "", false +} + // expandHybridBlock turns a single hybrid() block into [bm25, similar_to, fuse]. func expandHybridBlock(qu *GraphQuery, idx int) ([]*GraphQuery, error) { if qu.Var == "" { @@ -101,6 +119,12 @@ func expandHybridBlock(qu *GraphQuery, idx int) ([]*GraphQuery, error) { vecPred := fn.Args[1].Value vecArg := fn.Args[2] + // Options follow the positionals as key/value pairs; an odd remainder means a + // malformed option list rather than something to silently drop. + if (len(fn.Args)-numPositional)%2 != 0 { + return nil, fmt.Errorf("hybrid: malformed options (expected key:value pairs)") + } + // Parse options: topk feeds similar_to's neighbor count; the rest feed fuse. topk := hybridDefaultTopK var fuseArgs []Arg @@ -117,11 +141,14 @@ func expandHybridBlock(qu *GraphQuery, idx int) ([]*GraphQuery, error) { chanBM25 := fmt.Sprintf("%s%d_bm25", hybridVarPrefix, idx) chanVec := fmt.Sprintf("%s%d_vec", hybridVarPrefix, idx) + // Bound the bm25 channel to the same topk candidate budget as the vector channel + // so a broad text query does not score and materialize the entire corpus before + // fusion. bm25 honors `first` with WAND top-k early termination. bm25Block := &GraphQuery{ Alias: "var", Var: chanBM25, Func: &Function{Name: "bm25", Attr: textPred, Args: []Arg{{Value: queryText}}}, - Args: map[string]string{}, + Args: map[string]string{"first": topk}, } simBlock := &GraphQuery{ Alias: "var", diff --git a/query/fuse.go b/query/fuse.go index 23fa62ea80c..3d36c7e4796 100644 --- a/query/fuse.go +++ b/query/fuse.go @@ -129,7 +129,10 @@ func channelRanks(c fuseChannel) map[uint64]int { return ranks } -// fuseRRF computes Reciprocal Rank Fusion over the channels. +// fuseRRF computes (weighted) Reciprocal Rank Fusion over the channels. Each +// channel contributes weight * 1/(k+rank); with the default weight of 1.0 this is +// standard RRF, and per-channel weights let callers bias channels under either +// fusion method rather than silently ignoring weights for rrf. func fuseRRF(channels []fuseChannel, k float64) map[uint64]float64 { if k <= 0 || math.IsNaN(k) || math.IsInf(k, 0) { k = defaultRRFK @@ -137,7 +140,7 @@ func fuseRRF(channels []fuseChannel, k float64) map[uint64]float64 { fused := make(map[uint64]float64) for _, c := range channels { for uid, rank := range channelRanks(c) { - fused[uid] += 1.0 / (k + float64(rank)) + fused[uid] += c.weight * (1.0 / (k + float64(rank))) } } return fused @@ -282,11 +285,18 @@ func computeFuse(args []dql.Arg, needsVar []dql.VarContext, channels := make([]fuseChannel, len(needsVar)) for i, nv := range needsVar { v, ok := doneVars[nv.Name] - if !ok || v.Vals == nil || v.Vals.IsEmpty() { - // A channel that produced no scored results contributes nothing but is - // still a valid (empty) channel. + switch { + case !ok: + // The dependency scheduler guarantees every channel block has run and + // populated doneVars before this fuse block. A genuinely absent channel + // therefore signals an internal invariant violation rather than an empty + // result — surface it instead of silently degrading the fusion. + return varValue{}, errors.Errorf("fuse: channel %q was not produced", nv.Name) + case v.Vals == nil || v.Vals.IsEmpty(): + // A channel that ran but matched nothing is a valid empty channel: it + // contributes nothing but must not drop the other channels' results. channels[i] = fuseChannel{scores: map[uint64]float64{}, weight: 1.0} - } else { + default: scores, err := scoresFromVar(v, nv.Name) if err != nil { return varValue{}, err diff --git a/query/fuse_test.go b/query/fuse_test.go index 69529c8beae..af1450f66a9 100644 --- a/query/fuse_test.go +++ b/query/fuse_test.go @@ -101,6 +101,27 @@ func TestFuseRRF_DisjointChannels(t *testing.T) { require.InDelta(t, 1/(60.0+2), got[4], 1e-9) } +func TestFuseRRF_AppliesWeights(t *testing.T) { + // Weights must affect RRF (not only linear): a uid ranked #1 in a 2x-weighted + // channel should beat a uid ranked #1 in a unit-weighted channel. + heavy := fuseChannel{scores: map[uint64]float64{1: 9.0}, weight: 2.0} + light := fuseChannel{scores: map[uint64]float64{2: 9.0}, weight: 1.0} + res := fuseChannels([]fuseChannel{heavy, light}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + require.InDelta(t, 2.0*(1/(60.0+1)), got[1], 1e-9) + require.InDelta(t, 1.0*(1/(60.0+1)), got[2], 1e-9) + require.Equal(t, uint64(1), res[0].uid, "heavier-weighted channel's top doc wins") +} + +func TestFuseRRF_DefaultWeightIsStandardRRF(t *testing.T) { + // With the default weight of 1.0, weighted RRF reduces to standard RRF. + a := ch(map[uint64]float64{10: 9.0, 20: 5.0}) + res := fuseChannels([]fuseChannel{a}, fuseOpts{method: fusionRRF, k: 60}) + got := asMap(res) + require.InDelta(t, 1/(60.0+1), got[10], 1e-9) + require.InDelta(t, 1/(60.0+2), got[20], 1e-9) +} + func TestFuseLinear_MaxNormalizeAndWeights(t *testing.T) { // BM25-ish scale vs cosine-ish scale. text := fuseChannel{scores: map[uint64]float64{1: 10.0, 2: 5.0}, weight: 0.3} diff --git a/query/query_hybrid_test.go b/query/query_hybrid_test.go index 7443fa3791b..c75786677be 100644 --- a/query/query_hybrid_test.go +++ b/query/query_hybrid_test.go @@ -215,7 +215,7 @@ func TestHybridSugarEquivalentToFuse(t *testing.T) { }` explicitQ := ` { - txt as var(func: bm25(description_bm25, "fox")) + txt as var(func: bm25(description_bm25, "fox"), first: 7) vec as var(func: similar_to(description_vec, 7, "[3.0, 0.0, 0.0, 0.0]")) f as var(func: fuse(txt, vec, method: "rrf", k: 60)) me(func: uid(f), orderdesc: val(f)) { From 109a2295e037eb0d83760f5408da629650957e0b Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 10 Jun 2026 14:17:18 +0000 Subject: [PATCH 19/19] fix(hybrid): register fuse in isValidFuncName; snapshot ranker scores MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues found in deep review of native hybrid search: 1. fuse() rejected at subgraph conversion. The DQL parser accepts `fuse` (validFuncName), but query.go's isValidFuncName did not list it, so every fuse()/hybrid() query failed with "Invalid function name: fuse". The integration tests that exercise this are CI-gated and had not been run. Added "fuse"; all 25 fuse/hybrid/similar_to-score integration tests now pass. (hybrid is rewritten to fuse before this check, so only fuse needs registering.) 2. Score/UID misalignment under @filter. The generalized bm25/similar_to ranker binding zipped uidMatrix[0] with valueMatrix positionally; a later @filter on the ranker block runs updateUidMatrix (and pagination), which shrinks/reorders uidMatrix[0] in place without touching valueMatrix — misbinding scores to UIDs and feeding wrong scores into fuse() channels. Snapshot the aligned worker result into a uid->score map (sg.rankerScores) at result time, before any mutation, and bind by UID. Identical behavior on the tested no-filter paths. --- query/query.go | 56 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/query/query.go b/query/query.go index 35f005ed451..c48cf2d11fe 100644 --- a/query/query.go +++ b/query/query.go @@ -269,6 +269,15 @@ type SubGraph struct { // In graph terms, a list is a slice of outgoing edges from a node. uidMatrix []*pb.List + // rankerScores maps a matched document UID to its ranker score (BM25 relevance + // or vector similarity). It is snapshotted from the (uid-aligned) worker result + // the moment it arrives, before filters or pagination can shrink/reorder + // uidMatrix out of step with valueMatrix. populateUidValVar binds the score + // variable from this map keyed by UID, so the score stays correct even when the + // ranker block carries an @filter. nil unless the source function is a ranker + // (bm25 / similar_to). + rankerScores map[uint64]float64 + // facetsMatrix contains the facet values. There would a list corresponding to each uid in // uidMatrix. facetsMatrix []*pb.FacetsList @@ -1617,29 +1626,24 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su } doneVars[sg.Params.Var].Vals.Set(math.MaxUint64, val) case sg.SrcFunc != nil && (sg.SrcFunc.Name == "bm25" || sg.SrcFunc.Name == "similar_to") && - len(sg.uidMatrix) > 0 && len(sg.valueMatrix) > 0: + sg.rankerScores != nil: // A query-side ranker (BM25 relevance or vector similarity) binds its // per-document score as a value variable. We populate BOTH the matched uid set // and the uid->score map so the variable works with uid(var), val(var) and // orderdesc: val(var) — surfacing and ordering by score without a - // pseudo-predicate or a ParentVars channel. The valueMatrix is positionally - // aligned with the function's returned uidMatrix[0]. For similar_to the score - // is a higher-is-better similarity; this also lets vector results feed fuse(). + // pseudo-predicate or a ParentVars channel. Scores are looked up from the + // uid-keyed snapshot taken at result time (sg.rankerScores), so they remain + // correct even after an @filter on the ranker block shrinks DestUIDs. For + // similar_to the score is a higher-is-better similarity; this also lets vector + // results feed fuse(). if v, ok = doneVars[sg.Params.Var]; !ok { v = varValue{Vals: types.NewShardedMap(), path: sgPath, strList: sg.valueMatrix} } v.Uids = sg.DestUIDs - uids := sg.uidMatrix[0].GetUids() - for idx, uid := range uids { - if idx >= len(sg.valueMatrix) || len(sg.valueMatrix[idx].Values) == 0 { - continue - } - tv := sg.valueMatrix[idx].Values[0] - if len(tv.Val) != 8 { - continue + for _, uid := range sg.DestUIDs.GetUids() { + if score, has := sg.rankerScores[uid]; has { + v.Vals.Set(uid, types.Val{Tid: types.FloatID, Value: score}) } - score := math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) - v.Vals.Set(uid, types.Val{Tid: types.FloatID, Value: score}) } doneVars[sg.Params.Var] = v case len(sg.DestUIDs.Uids) != 0 || (sg.Attr == "uid" && sg.SrcUIDs != nil): @@ -2327,6 +2331,27 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics + // bm25 and similar_to return their per-document scores in valueMatrix + // positionally aligned with uidMatrix[0]. Snapshot them into a uid-keyed + // map now, while the two are still aligned — later filters/pagination + // shrink uidMatrix without touching valueMatrix, which would otherwise + // misbind scores to UIDs (and feed wrong scores into fuse() channels). + if sg.SrcFunc != nil && (sg.SrcFunc.Name == "bm25" || sg.SrcFunc.Name == "similar_to") && + len(result.UidMatrix) > 0 { + uids := result.UidMatrix[0].GetUids() + sg.rankerScores = make(map[uint64]float64, len(uids)) + for idx, uid := range uids { + if idx >= len(result.ValueMatrix) || len(result.ValueMatrix[idx].Values) == 0 { + continue + } + tv := result.ValueMatrix[idx].Values[0] + if len(tv.Val) != 8 { + continue + } + sg.rankerScores[uid] = math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + } + } + if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2807,7 +2832,8 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25", + "fuse": return true } return isInequalityFn(f) || types.IsGeoFunc(f)