diff --git a/dgraph/cmd/bulk/loader.go b/dgraph/cmd/bulk/loader.go index 410b1a4f9f9..a939004513f 100644 --- a/dgraph/cmd/bulk/loader.go +++ b/dgraph/cmd/bulk/loader.go @@ -33,6 +33,7 @@ import ( "github.com/dgraph-io/dgraph/v25/enc" "github.com/dgraph-io/dgraph/v25/filestore" gqlSchema "github.com/dgraph-io/dgraph/v25/graphql/schema" + "github.com/dgraph-io/dgraph/v25/posting" "github.com/dgraph-io/dgraph/v25/protos/pb" "github.com/dgraph-io/dgraph/v25/schema" "github.com/dgraph-io/dgraph/v25/x" @@ -433,6 +434,12 @@ func (ld *loader) mapStage() { close(ld.readerChunkCh) mapperWg.Wait() + // Flush BM25 corpus statistics accumulated across all mappers as one posting per + // bucket, before the mappers are released. Stats must be summed (not unioned like + // postings), so this single merge-and-write avoids the per-mapper double counting a + // union would produce. + ld.flushBM25Stats() + // Allow memory to GC before the reduce phase. for i := range ld.mappers { ld.mappers[i] = nil @@ -444,6 +451,75 @@ func (ld *loader) mapStage() { ld.xids = nil } +// mergeBM25Stats combines every mapper's per-predicate corpus-statistics partials into +// per-predicate bucket totals. Summing across mappers is what makes the final per-bucket +// counts correct; emitting each mapper's partial as its own posting would be unioned (or +// collapsed last-write-wins) at reduce time and undercount. 
+func mergeBM25Stats(mappers []*mapper) map[string]*bm25StatEntry { + merged := make(map[string]*bm25StatEntry) + for _, m := range mappers { + if m == nil { + continue + } + for attr, e := range m.bm25Stats { + me := merged[attr] + if me == nil { + me = &bm25StatEntry{} + merged[attr] = me + } + for i := 0; i < posting.NumBM25StatsBuckets; i++ { + me.count[i] += e.count[i] + me.terms[i] += e.terms[i] + } + } + } + return merged +} + +// flushBM25Stats writes the merged BM25 corpus statistics as one value posting per +// non-empty bucket into the same map shard as the predicate's term postings (keyed by +// shardFor(attr)), so they co-locate through shard merge and land in the predicate's +// output DB. Exactly one posting is written per bucket, so the reduce produces a single +// value posting per bucket — the same format the live and rebuild paths store. +func (ld *loader) flushBM25Stats() { + merged := mergeBM25Stats(ld.mappers) + if len(merged) == 0 { + return + } + + // A fresh mapper supplies clean shard buffers; the running mappers have already + // flushed and released theirs. + writer := newMapper(ld.state) + for attr, e := range merged { + shard := ld.shards.shardFor(attr) + for b := 0; b < posting.NumBM25StatsBuckets; b++ { + if e.count[b] == 0 && e.terms[b] == 0 { + continue + } + writer.addMapEntry( + x.BM25StatsKey(attr, b), + &pb.Posting{ + Uid: math.MaxUint64, + PostingType: pb.Posting_VALUE, + ValType: pb.Posting_BINARY, + Value: posting.EncodeBM25Stats(uint64(e.count[b]), uint64(e.terms[b])), + }, + shard, + ) + } + } + + for i := range writer.shards { + sh := &writer.shards[i] + if sh.cbuf.LenNoPadding() > 0 { + sh.mu.Lock() // writeMapEntriesToFile unlocks and releases the buffer. 
+ writer.writeMapEntriesToFile(sh.cbuf, i) + } else if err := sh.cbuf.Release(); err != nil { + glog.Warningf("error releasing bm25 stats buffer: %v", err) + } + } +} + func parseGqlSchema(s string) map[uint64]string { var schemas []x.ExportedGQLSchema if err := json.Unmarshal([]byte(s), &schemas); err != nil { diff --git a/dgraph/cmd/bulk/mapper.go b/dgraph/cmd/bulk/mapper.go index 1a7299ab2cf..ebf4a70a632 100644 --- a/dgraph/cmd/bulk/mapper.go +++ b/dgraph/cmd/bulk/mapper.go @@ -43,6 +43,19 @@ var ( type mapper struct { *state shards []shardState // shard is based on predicate + + // bm25Stats accumulates per-predicate corpus statistics (document count and total + // term count, bucketed by uid) as documents are mapped. BM25 stats must be summed, + // not unioned like postings, so each mapper accumulates locally and the loader + // merges all mappers and flushes one stats posting per bucket after the map phase + // (see loader.flushBM25Stats). Keyed by namespaced predicate. + bm25Stats map[string]*bm25StatEntry +} + +// bm25StatEntry holds one predicate's bucketed corpus-statistics partials for a mapper. +type bm25StatEntry struct { + count [posting.NumBM25StatsBuckets]int64 + terms [posting.NumBM25StatsBuckets]int64 } type shardState struct { @@ -66,8 +79,9 @@ func newMapper(st *state) *mapper { shards[i].cbuf = newMapperBuffer(st.opt) } return &mapper{ - state: st, - shards: shards, + state: st, + shards: shards, + bm25Stats: make(map[string]*bm25StatEntry), } } @@ -295,8 +309,11 @@ func (m *mapper) addMapEntry(key []byte, p *pb.Posting, shard int) { atomic.AddInt64(&m.prog.mapEdgeCount, 1) uid := p.Uid - if p.PostingType != pb.Posting_REF || len(p.Facets) > 0 { - // Keep p + if p.PostingType != pb.Posting_REF || len(p.Facets) > 0 || len(p.Value) > 0 { + // Keep p. A REF posting that carries a Value (e.g. 
a BM25 term posting packing + // term frequency and document length) must retain that payload — mirroring the + // len(p.Value) > 0 retention clause in List.encode — or it would be reduced to a + // bare UID and the value silently lost. } else { // We only needed the UID. p = nil @@ -456,11 +473,21 @@ func (m *mapper) addIndexMapEntries(nq dql.NQuad, de *pb.DirectedEdge) { // doing edge postings. So okay to be fatal. x.Check(err) + attr := x.NamespaceAttr(nq.Namespace, nq.Predicate) + + // BM25 postings pack (term frequency, document length) into each posting's value + // and require corpus statistics; the generic token path would write bare, + // valueless postings and no stats, leaving bulk-loaded data unsearchable. Handle + // it separately. + if _, isBM25 := toker.(tok.BM25Tokenizer); isBM25 { + m.addBM25IndexMapEntries(attr, nq.Lang, de, schemaVal) + continue + } + // Extract tokens. toks, err := tok.BuildTokens(schemaVal.Value, tok.GetTokenizerForLang(toker, nq.Lang)) x.Check(err) - attr := x.NamespaceAttr(nq.Namespace, nq.Predicate) // Store index posting. for _, t := range toks { m.addMapEntry( @@ -474,3 +501,45 @@ func (m *mapper) addIndexMapEntries(nq dql.NQuad, de *pb.DirectedEdge) { } } } + +// addBM25IndexMapEntries writes the BM25 term postings for one document — one posting +// per distinct term, packing (term frequency, document length) into the value exactly +// as the live index path does — and accumulates the document's contribution to this +// mapper's corpus statistics. The accumulated stats are flushed once per bucket after +// the map phase (loader.flushBM25Stats), because corpus statistics must be summed +// across documents rather than unioned like postings. +func (m *mapper) addBM25IndexMapEntries(attr string, lang string, de *pb.DirectedEdge, + schemaVal types.Val) { + termFreqs, docLen, err := tok.BM25Tokenizer{}.TokensWithFrequency(schemaVal.Value, lang) + x.Check(err) + if docLen == 0 { + // Document tokenizes to zero terms (e.g. 
all stopwords); it contributes no + // postings and no corpus statistics. + return + } + + shard := m.state.shards.shardFor(attr) + uid := de.GetEntity() + for term, tf := range termFreqs { + encodedTerm := string([]byte{tok.IdentBM25}) + term + m.addMapEntry( + x.IndexKey(attr, encodedTerm), + &pb.Posting{ + Uid: uid, + PostingType: pb.Posting_REF, + ValType: pb.Posting_BINARY, + Value: posting.EncodeBM25Value(tf, docLen), + }, + shard, + ) + } + + entry := m.bm25Stats[attr] + if entry == nil { + entry = &bm25StatEntry{} + m.bm25Stats[attr] = entry + } + bucket := uid % posting.NumBM25StatsBuckets + entry.count[bucket]++ + entry.terms[bucket] += int64(docLen) +} diff --git a/dgraph/cmd/bulk/mapper_test.go b/dgraph/cmd/bulk/mapper_test.go new file mode 100644 index 00000000000..8b449ecf0e5 --- /dev/null +++ b/dgraph/cmd/bulk/mapper_test.go @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package bulk + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting" +) + +// TestMergeBM25Stats verifies that per-mapper corpus-statistics partials are summed +// across mappers per (predicate, bucket). This is the linchpin of correct bulk BM25 +// stats: each mapper sees a disjoint subset of documents, so the final doc count and +// total term count for a bucket must be the sum of every mapper's partial — never just +// one mapper's (which a unioned/last-write-wins posting would produce). 
+func TestMergeBM25Stats(t *testing.T) { + mk := func(attr string, bucket int, count, terms int64) *mapper { + m := &mapper{bm25Stats: map[string]*bm25StatEntry{}} + e := &bm25StatEntry{} + e.count[bucket] = count + e.terms[bucket] = terms + m.bm25Stats[attr] = e + return m + } + + mappers := []*mapper{ + mk("name", 1, 3, 30), + mk("name", 1, 2, 25), // same predicate+bucket as above -> must sum + mk("name", 5, 4, 40), // same predicate, different bucket + mk("bio", 1, 7, 70), // different predicate + nil, // released mappers are skipped + } + + merged := mergeBM25Stats(mappers) + require.Len(t, merged, 2) + + require.Equal(t, int64(5), merged["name"].count[1], "bucket 1 doc count must sum across mappers") + require.Equal(t, int64(55), merged["name"].terms[1], "bucket 1 term count must sum across mappers") + require.Equal(t, int64(4), merged["name"].count[5]) + require.Equal(t, int64(40), merged["name"].terms[5]) + require.Equal(t, int64(7), merged["bio"].count[1]) + require.Equal(t, int64(70), merged["bio"].terms[1]) + + // Untouched buckets stay zero. + require.Equal(t, int64(0), merged["name"].count[0]) + require.Equal(t, int64(0), merged["bio"].count[5]) + + // Total doc count across the predicate equals the sum of all contributing mappers. 
+ var nameDocs int64 + for b := 0; b < posting.NumBM25StatsBuckets; b++ { + nameDocs += merged["name"].count[b] + } + require.Equal(t, int64(9), nameDocs) +} diff --git a/dql/parser.go b/dql/parser.go index 0dd6e1db7ac..666c3eacaab 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -1701,7 +1701,7 @@ func validFuncName(name string) bool { switch name { case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return false diff --git a/posting/bm25.go b/posting/bm25.go new file mode 100644 index 00000000000..8bd373f32e9 --- /dev/null +++ b/posting/bm25.go @@ -0,0 +1,306 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "context" + "encoding/binary" + "sync/atomic" + + ostats "go.opencensus.io/stats" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/tok" + "github.com/dgraph-io/dgraph/v25/x" +) + +// BM25Posting is a single materialized entry of a BM25 term posting list: the +// document UID together with the term frequency and document length decoded from +// the posting value. +type BM25Posting struct { + Uid uint64 + TF uint32 + DocLen uint32 +} + +// ReadBM25TermPostings materializes the postings of a BM25 term's standard index +// list at readTs into a UID-ascending slice, decoding (tf, docLen) from each +// posting value. getList reads a posting list for a key (e.g. LocalCache.Get), +// keeping the value encoding encapsulated in this package. 
+func ReadBM25TermPostings(getList func(key []byte) (*List, error), attr, encodedTerm string, + readTs uint64) ([]BM25Posting, error) { + key := x.BM25IndexKey(attr, encodedTerm) + pl, err := getList(key) + if err != nil { + return nil, err + } + var out []BM25Posting + err = pl.Iterate(readTs, 0, func(p *pb.Posting) error { + tf, docLen, ok := decodeBM25Value(p.Value) + if !ok { + // Corrupt/truncated posting value: skip rather than inject a bogus + // zero-frequency match that would silently distort scoring. + return nil + } + out = append(out, BM25Posting{Uid: p.Uid, TF: tf, DocLen: docLen}) + return nil + }) + if err != nil { + return nil, err + } + return out, nil +} + +// NumBM25StatsBuckets is the number of buckets the BM25 corpus statistics (document +// count and total term count) are sharded across, keyed by uid%NumBM25StatsBuckets. +// Sharding spreads the read-modify-write contention of stats maintenance across +// independent posting lists so that concurrent mutations on different documents +// rarely conflict, while same-bucket updates still conflict (and retry) — avoiding +// lost updates. A single hot stats key would serialize all writes to the predicate. +// Exported so the bulk loader buckets corpus statistics identically to the live and +// rebuild paths. +const NumBM25StatsBuckets = 32 + +// numBM25StatsBuckets is the unexported alias retained for readability within this +// package. +const numBM25StatsBuckets = NumBM25StatsBuckets + +// EncodeBM25Value packs a posting's term frequency and document length the same way the +// live index path does, for the bulk loader to write BM25 term postings in the standard +// format. See encodeBM25Value. +func EncodeBM25Value(tf, docLen uint32) []byte { return encodeBM25Value(tf, docLen) } + +// EncodeBM25Stats encodes corpus statistics (document count, total term count) for the +// bulk loader to write the per-bucket stats postings in the standard format. See +// encodeBM25Stats. 
func EncodeBM25Stats(docCount, totalTerms uint64) []byte {
	return encodeBM25Stats(docCount, totalTerms)
}

// encodeBM25Value packs a posting's term frequency and document length into the
// posting Value as two unsigned varints. Storing the document length alongside the
// term frequency makes scoring read (tf, docLen) in a single posting access — no
// separate document-length list (which would be a write-hot key) and no random
// per-candidate lookup at query time. The document length is duplicated across a
// document's unique terms, but a document's postings are always rewritten together
// on update, so they stay consistent.
func encodeBM25Value(tf, docLen uint32) []byte {
	// Two 32-bit varints need at most MaxVarintLen32 bytes each.
	buf := make([]byte, binary.MaxVarintLen32*2)
	n := binary.PutUvarint(buf, uint64(tf))
	n += binary.PutUvarint(buf[n:], uint64(docLen))
	return buf[:n]
}

// decodeBM25Value reverses encodeBM25Value. ok is false when the input does not
// hold two complete varints (e.g. a truncated or corrupt posting value) or when
// either decoded number does not fit in 32 bits, so the caller can skip the posting
// rather than silently scoring it with a bogus frequency or length.
func decodeBM25Value(b []byte) (tf, docLen uint32, ok bool) {
	// encodeBM25Value only ever writes 32-bit quantities, so anything larger can
	// only come from corruption. Converting it with uint32(...) would truncate it
	// into a plausible-looking but wrong value; reject it instead.
	const maxUint32 = 1<<32 - 1

	tf64, n := binary.Uvarint(b)
	if n <= 0 || tf64 > maxUint32 {
		return 0, 0, false
	}
	docLen64, m := binary.Uvarint(b[n:])
	if m <= 0 || docLen64 > maxUint32 {
		return 0, 0, false
	}
	return uint32(tf64), uint32(docLen64), true
}

// encodeBM25Stats encodes corpus statistics (document count, total term count) as
// two unsigned varints.
func encodeBM25Stats(docCount, totalTerms uint64) []byte {
	// Two 64-bit varints need at most MaxVarintLen64 bytes each.
	buf := make([]byte, binary.MaxVarintLen64*2)
	n := binary.PutUvarint(buf, docCount)
	n += binary.PutUvarint(buf[n:], totalTerms)
	return buf[:n]
}

// decodeBM25Stats reverses encodeBM25Stats. It returns (0, 0) on malformed input.
+func decodeBM25Stats(b []byte) (docCount, totalTerms uint64) { + docCount, n := binary.Uvarint(b) + if n <= 0 { + return 0, 0 + } + totalTerms, m := binary.Uvarint(b[n:]) + if m <= 0 { + return docCount, 0 + } + return docCount, totalTerms +} + +// addBM25TermPosting writes (op=SET) or removes (op=DEL) the posting for the given +// (term, uid) pair in the term's standard index posting list. On SET the posting's +// Value packs (tf, docLen); on DEL only the UID matters. The posting is a REF +// posting (ValueId set) that carries a Value — List.encode retains such postings +// through rollup (see the len(p.Value) > 0 clause there). +func (txn *Txn) addBM25TermPosting(ctx context.Context, attr, term string, uid uint64, + tf, docLen uint32, op pb.DirectedEdge_Op) error { + encodedTerm := string([]byte{tok.IdentBM25}) + term + key := x.BM25IndexKey(attr, encodedTerm) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + ValueId: uid, + Attr: attr, + Op: op, + } + if op != pb.DirectedEdge_DEL { + edge.Value = encodeBM25Value(tf, docLen) + edge.ValueType = pb.Posting_BINARY + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + ostats.Record(ctx, x.NumEdges.M(1)) + return nil +} + +// bm25StatsAccum is a concurrency-safe per-bucket accumulator of corpus-statistics +// deltas. Index rebuild routes its per-document stats updates here (across many +// goroutines) instead of the read-modify-write counter, then flushes the buckets once +// as a single writer (see flush). This avoids the undercount that the streaming +// rebuild's independent per-thread caches and periodic resets would otherwise cause, +// where last-write-wins on the value posting drops every thread's partial total but +// one. 
+type bm25StatsAccum struct { + count [numBM25StatsBuckets]atomic.Int64 + terms [numBM25StatsBuckets]atomic.Int64 +} + +func newBM25StatsAccum() *bm25StatsAccum { return &bm25StatsAccum{} } + +// add records a document's contribution in its bucket (uid%numBM25StatsBuckets). +func (a *bm25StatsAccum) add(uid uint64, docCountDelta, totalTermsDelta int64) { + bucket := uid % numBM25StatsBuckets + a.count[bucket].Add(docCountDelta) + a.terms[bucket].Add(totalTermsDelta) +} + +// flush writes the accumulated absolute totals into the stats posting lists for attr +// through txn, one SET value posting per non-empty bucket. It is a single-writer +// operation (one txn writing all buckets), so there is no lost-update window; the +// caller commits txn. Buckets are written as absolute SETs because a rebuild deletes +// the prior stats first, so the buckets start empty. +func (a *bm25StatsAccum) flush(ctx context.Context, txn *Txn, attr string) error { + for bucket := 0; bucket < numBM25StatsBuckets; bucket++ { + docCount := a.count[bucket].Load() + totalTerms := a.terms[bucket].Load() + if docCount <= 0 && totalTerms <= 0 { + continue + } + key := x.BM25StatsKey(attr, bucket) + plist, err := txn.cache.GetFromDelta(key) + if err != nil { + return err + } + edge := &pb.DirectedEdge{ + Attr: attr, + Value: encodeBM25Stats(uint64(docCount), uint64(totalTerms)), + ValueType: pb.Posting_BINARY, + Op: pb.DirectedEdge_SET, + } + if err := plist.addMutation(ctx, txn, edge); err != nil { + return err + } + } + return nil +} + +// updateBM25Stats applies (docCountDelta, totalTermsDelta) to the bucketed corpus +// statistics for attr. The bucket is selected by uid%numBM25StatsBuckets. The +// running totals are stored as a single value posting per bucket; the read at +// txn.StartTs sees this transaction's own earlier writes (read-your-own-writes), +// so multiple documents in the same transaction that land in the same bucket +// accumulate correctly. 
+func (txn *Txn) updateBM25Stats(ctx context.Context, attr string, uid uint64, + docCountDelta, totalTermsDelta int64) error { + // During index rebuild, accumulate into the shared accumulator rather than the + // read-modify-write counter (see Txn.bm25Acc). The rebuild flushes the buckets + // once at the end as a single writer. + if txn.bm25Acc != nil { + txn.bm25Acc.add(uid, docCountDelta, totalTermsDelta) + return nil + } + bucket := int(uid % numBM25StatsBuckets) + key := x.BM25StatsKey(attr, bucket) + // Stats are maintained by read-modify-write: we must read the committed total + // from disk (and merge this transaction's own writes), not just the in-memory + // delta. GetFromDelta skips disk and is only safe for write-only index mutations, + // so each transaction would otherwise overwrite the bucket instead of + // accumulating across transactions. Get reads committed state. + plist, err := txn.cache.Get(key) + if err != nil { + return err + } + + var docCount, totalTerms uint64 + val, err := plist.Value(txn.StartTs) + switch err { + case nil: + if data, ok := val.Value.([]byte); ok { + docCount, totalTerms = decodeBM25Stats(data) + } + case ErrNoValue: + // No stats yet for this bucket; start from zero. + default: + return err + } + + docCount = applyBM25Delta(docCount, docCountDelta) + totalTerms = applyBM25Delta(totalTerms, totalTermsDelta) + + edge := &pb.DirectedEdge{ + Attr: attr, + Value: encodeBM25Stats(docCount, totalTerms), + ValueType: pb.Posting_BINARY, + Op: pb.DirectedEdge_SET, + } + return plist.addMutation(ctx, txn, edge) +} + +// applyBM25Delta adds a signed delta to an unsigned counter, clamping at zero. +func applyBM25Delta(v uint64, delta int64) uint64 { + if delta >= 0 { + return v + uint64(delta) + } + dec := uint64(-delta) + if dec > v { + return 0 + } + return v - dec +} + +// ReadBM25Stats sums the bucketed corpus statistics for attr at readTs, returning +// the document count and total term count. avgDL = totalTerms / docCount. 
The +// getList closure reads a posting list for a key (e.g. LocalCache.Get) so the +// caller controls caching and the read timestamp. +func ReadBM25Stats(getList func(key []byte) (*List, error), attr string, + readTs uint64) (docCount, totalTerms uint64, err error) { + for b := 0; b < numBM25StatsBuckets; b++ { + key := x.BM25StatsKey(attr, b) + pl, perr := getList(key) + if perr != nil { + return 0, 0, perr + } + val, verr := pl.Value(readTs) + if verr == ErrNoValue { + continue + } + if verr != nil { + return 0, 0, verr + } + data, ok := val.Value.([]byte) + if !ok || len(data) == 0 { + continue + } + dc, tt := decodeBM25Stats(data) + docCount += dc + totalTerms += tt + } + return docCount, totalTerms, nil +} diff --git a/posting/bm25_test.go b/posting/bm25_test.go new file mode 100644 index 00000000000..9e4913f555c --- /dev/null +++ b/posting/bm25_test.go @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "bytes" + "context" + "math" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" +) + +// TestBM25DropClearsStats guards against orphaned corpus statistics when the BM25 +// index is dropped or rebuilt. BM25 term postings live under the IdentBM25 token +// prefix, but the stats buckets live under a separate reserved token prefix. The +// drop/rebuild machinery deletes by token prefix, so unless the stats prefix is +// also returned by prefixesToDeleteTokensFor, the stats survive a drop and then +// double-count when the index is rebuilt on top of them. +func TestBM25DropClearsStats(t *testing.T) { + attr := x.AttrInRootNamespace("bm25dropstats") + prefixes, err := prefixesToDeleteTokensFor(attr, "bm25", false) + require.NoError(t, err) + + // Every stats bucket key must be covered by one of the deletion prefixes. 
+ for bucket := 0; bucket < 3; bucket++ { + statsKey := x.BM25StatsKey(attr, bucket) + covered := false + for _, p := range prefixes { + if bytes.HasPrefix(statsKey, p) { + covered = true + break + } + } + require.Truef(t, covered, + "stats bucket %d key not covered by any bm25 deletion prefix (orphaned on drop)", bucket) + } +} + +func TestBM25ValueCodecRoundTrip(t *testing.T) { + cases := [][2]uint32{{0, 0}, {1, 1}, {3, 12}, {7, 200}, {65535, 1 << 20}, {1 << 24, 1 << 24}} + for _, c := range cases { + tf, dl, ok := decodeBM25Value(encodeBM25Value(c[0], c[1])) + require.True(t, ok) + require.Equal(t, c[0], tf) + require.Equal(t, c[1], dl) + } + // Malformed/truncated input is reported as invalid so callers can skip it. + _, _, ok := decodeBM25Value(nil) + require.False(t, ok) + _, _, ok = decodeBM25Value([]byte{0x80}) // varint continuation byte with no terminator + require.False(t, ok) +} + +// TestBM25ValueSurvivesRollup verifies the linchpin of the BM25 redesign: a REF +// index posting that carries a packed (tf, docLen) value is retained — value and +// all — through rollup, instead of being collapsed to a UID-only Pack entry (the +// default behavior for REF postings before the len(p.Value) > 0 retention clause +// in List.encode). 
+func TestBM25ValueSurvivesRollup(t *testing.T) { + attr := x.AttrInRootNamespace("bm25rollup") + encodedTerm := string([]byte{0x10}) + "fox" // IdentBM25 || term + key := x.BM25IndexKey(attr, encodedTerm) + + docs := []struct { + uid uint64 + tf uint32 + docLen uint32 + }{ + {uid: 5, tf: 3, docLen: 12}, + {uid: 9, tf: 1, docLen: 40}, + {uid: 100, tf: 7, docLen: 200}, + } + + ts := uint64(1) + for _, d := range docs { + l, err := GetNoStore(key, ts) + require.NoError(t, err) + edge := &pb.DirectedEdge{ + ValueId: d.uid, + Attr: attr, + Value: encodeBM25Value(d.tf, d.docLen), + ValueType: pb.Posting_BINARY, + } + addMutation(t, l, edge, Set, ts, ts+1, false) + ts += 2 + } + + // Force a rollup and decode the resulting posting list directly. + l, err := getNew(key, pstore, math.MaxUint64, false) + require.NoError(t, err) + kvs, err := l.Rollup(nil, math.MaxUint64) + require.NoError(t, err) + require.NotEmpty(t, kvs) + + var plist pb.PostingList + require.NoError(t, proto.Unmarshal(kvs[0].Value, &plist)) + + got := make(map[uint64][2]uint32) + for _, p := range plist.Postings { + tf, docLen, ok := decodeBM25Value(p.Value) + require.True(t, ok) + got[p.Uid] = [2]uint32{tf, docLen} + } + for _, d := range docs { + v, ok := got[d.uid] + require.Truef(t, ok, "uid %d posting missing after rollup (value stripped?)", d.uid) + require.Equal(t, d.tf, v[0], "tf for uid %d", d.uid) + require.Equal(t, d.docLen, v[1], "docLen for uid %d", d.uid) + } + + // Reading the list back materializes the same (uid, tf, docLen) triples. 
+ posts, err := ReadBM25TermPostings(func(k []byte) (*List, error) { + return getNew(k, pstore, math.MaxUint64, false) + }, attr, encodedTerm, math.MaxUint64) + require.NoError(t, err) + require.Len(t, posts, len(docs)) + for _, p := range posts { + require.Equal(t, got[p.Uid][0], p.TF) + require.Equal(t, got[p.Uid][1], p.DocLen) + } +} + +// TestBM25StatsConflictKeyValueIndependent guards the invariant that makes the +// corpus-stats read-modify-write counter safe under concurrent live transactions: +// two overlapping transactions that read the same bucket base and write DIFFERENT +// resulting totals must still be detected as conflicting (so one retries instead of +// silently losing an update). This holds because on a scalar (non-list) predicate +// fingerprintEdge returns MaxUint64 for every untagged value, so all stats writes to +// a bucket share the conflict key getKey(statsKey, MaxUint64) — independent of the +// value bytes. It is precisely why @index(bm25) is rejected on list predicates, where +// fingerprintEdge would become value-dependent and let differing totals slip past +// conflict detection. +func TestBM25StatsConflictKeyValueIndependent(t *testing.T) { + attr := x.AttrInRootNamespace("bm25statsconflict") + key := x.BM25StatsKey(attr, 3) + pk, err := x.Parse(key) + require.NoError(t, err) + + // Mirror addMutationInternal: a value posting (no ValueId) gets its ValueId set + // from fingerprintEdge before the conflict key is computed. 
+ mkEdge := func(val []byte) *pb.DirectedEdge { + e := &pb.DirectedEdge{Attr: attr, Value: val, ValueType: pb.Posting_BINARY, Op: pb.DirectedEdge_SET} + if NewPosting(e).PostingType != pb.Posting_REF { + e.ValueId = fingerprintEdge(e) + } + return e + } + e1 := mkEdge(encodeBM25Stats(10, 100)) + e2 := mkEdge(encodeBM25Stats(11, 137)) // different totals + require.Equal(t, uint64(math.MaxUint64), e1.ValueId, + "scalar untagged stats values must carry the MaxUint64 sentinel ValueId") + require.Equal(t, e1.ValueId, e2.ValueId, "stats ValueId must not depend on the value bytes") + + ck1 := GetConflictKey(pk, key, e1) + ck2 := GetConflictKey(pk, key, e2) + require.NotZero(t, ck1, "stats writes must register a conflict key") + require.Equal(t, ck1, ck2, + "two writers to the same stats bucket must share a conflict key regardless of the value") +} + +// TestBM25StatsRebuildAccumulator covers the fix for stats undercounting during index +// rebuild. The streaming rebuild processes documents across many goroutines, each with +// its own transaction/cache that is periodically reset — so a per-transaction +// read-modify-write counter loses updates (the last writer's partial total wins on +// merge). Routing rebuild stats through a shared accumulator and flushing the buckets +// once (single writer) must reproduce the exact corpus totals. This test models the +// rebuild by feeding documents (several sharing a bucket) through SEPARATE +// transactions that all share one accumulator, then flushing and reading back. 
+func TestBM25StatsRebuildAccumulator(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25rebuildacc") + acc := newBM25StatsAccum() + + docs := []struct { + uid uint64 + dl int64 + }{{1, 10}, {2, 20}, {33, 5}, {65, 7}, {3, 8}, {35, 4}, {97, 9}, {4, 6}} + var wantCount, wantTerms int64 + for _, d := range docs { + txn := NewTxn(900) + txn.cache = NewLocalCache(900) + txn.bm25Acc = acc + require.NoError(t, txn.updateBM25Stats(ctx, attr, d.uid, 1, d.dl)) + wantCount++ + wantTerms += d.dl + } + + // Single-writer finalize: flush the accumulator through one transaction and commit. + txn := Oracle().RegisterStartTs(901) + txn.cache = NewLocalCache(901) + require.NoError(t, acc.flush(ctx, txn, attr)) + txn.Update() + txn.UpdateCachedKeys(902) + writer := NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, 902)) + require.NoError(t, writer.Flush()) + + get := func(k []byte) (*List, error) { return GetNoStore(k, 903) } + dc, tt, err := ReadBM25Stats(get, attr, 903) + require.NoError(t, err) + require.Equal(t, uint64(wantCount), dc, "rebuilt doc count must equal the total across all transactions") + require.Equal(t, uint64(wantTerms), tt, "rebuilt total terms must equal the total across all transactions") +} + +// TestBM25StatsBucketed verifies that bucketed corpus statistics accumulate +// correctly across documents (including two documents that hash to the same +// bucket, exercising in-transaction read-your-own-writes) and that deletes +// subtract correctly. +func TestBM25StatsBucketed(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25stats") + ts := uint64(101) + txn := Oracle().RegisterStartTs(ts) + + // uid 1 and uid 33 both fall in bucket 1 (mod 32), exercising same-bucket + // accumulation within a single transaction. 
+ docs := []struct { + uid uint64 + dl int64 + }{{1, 10}, {2, 20}, {33, 5}, {64, 7}, {100, 8}} + + var wantCount, wantTerms int64 + for _, d := range docs { + require.NoError(t, txn.updateBM25Stats(ctx, attr, d.uid, 1, d.dl)) + wantCount++ + wantTerms += d.dl + } + + get := func(k []byte) (*List, error) { return txn.cache.GetFromDelta(k) } + dc, tt, err := ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount), dc) + require.Equal(t, uint64(wantTerms), tt) + + // Delete uid 2: docCount and totalTerms drop accordingly. + require.NoError(t, txn.updateBM25Stats(ctx, attr, 2, -1, -20)) + dc, tt, err = ReadBM25Stats(get, attr, ts) + require.NoError(t, err) + require.Equal(t, uint64(wantCount-1), dc) + require.Equal(t, uint64(wantTerms-20), tt) +} + +// TestBM25StatsAccumulateAcrossTxns verifies that stats accumulate across +// separately-committed transactions (not just within one). This guards against +// the read-modify-write reading only the in-memory delta instead of committed +// disk state, which would make each transaction overwrite its bucket and collapse +// the corpus document count. +func TestBM25StatsAccumulateAcrossTxns(t *testing.T) { + ctx := context.Background() + attr := x.AttrInRootNamespace("bm25statsxtxn") + + // Two documents in the SAME bucket (uid 5 and uid 37 → bucket 5), committed in + // two separate transactions. + commitDoc := func(startTs, commitTs, uid uint64, docLen int64) { + txn := Oracle().RegisterStartTs(startTs) + txn.cache = NewLocalCache(startTs) + require.NoError(t, txn.updateBM25Stats(ctx, attr, uid, 1, docLen)) + txn.Update() + txn.UpdateCachedKeys(commitTs) + writer := NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + } + + commitDoc(201, 202, 5, 10) + commitDoc(203, 204, 37, 6) + + // A fresh reader at a later ts must see BOTH documents (count 2, terms 16), + // not just the most recently committed one. 
+ get := func(k []byte) (*List, error) { return GetNoStore(k, 205) } + dc, tt, err := ReadBM25Stats(get, attr, 205) + require.NoError(t, err) + require.Equal(t, uint64(2), dc, "doc count must accumulate across transactions") + require.Equal(t, uint64(16), tt, "total terms must accumulate across transactions") +} diff --git a/posting/index.go b/posting/index.go index ae6c3352a44..6412a59960c 100644 --- a/posting/index.go +++ b/posting/index.go @@ -68,6 +68,10 @@ func indexTokens(ctx context.Context, info *indexMutationInfo) ([]string, error) var tokens []string for _, it := range info.tokenizers { + // BM25 tokenizer is handled separately in addBM25IndexMutations. + if it.Identifier() == tok.IdentBM25 { + continue + } toks, err := tok.BuildTokens(sv.Value, tok.GetTokenizerForLang(it, lang)) if err != nil { return tokens, err @@ -179,6 +183,17 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) } } + // Route BM25 to its dedicated path, detected by the same identifier test indexTokens uses to skip it. + for _, it := range info.tokenizers { + if it.Identifier() == tok.IdentBM25 { + if err := txn.addBM25IndexMutations(ctx, info); err != nil { + return []*pb.DirectedEdge{}, err + } + // BM25 is applied once per edge; the remaining tokenizers are handled by indexTokens below. + break + } + } + tokens, err := indexTokens(ctx, info) if err != nil { // This data is not indexable @@ -215,6 +230,58 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return nil } +// addBM25IndexMutations handles index mutations for the BM25 tokenizer. Unlike +// other tokenizers, each BM25 index posting carries a value that packs the term +// frequency together with the document length (see encodeBM25Value). The postings +// are written through the standard delta path (plist.addMutation), so BM25 rides +// Dgraph's normal posting-list machinery — MVCC, deltas, rollup, splits, backup — +// with no separate storage path. 
Corpus statistics (document count and total term +// count, from which the average document length is derived) are kept in bucketed +// stats posting lists keyed by uid%NumBM25StatsBuckets to avoid a single write-hot +// key while preserving conflict detection per bucket. +// +// Updates are driven entirely by the caller (AddMutationWithIndex), which issues a +// DEL for the previous value followed by a SET for the new one. The DEL re-tokenizes +// the old value and removes its postings and stats contribution; the SET adds the new +// ones. We therefore never need to detect updates here. +func (txn *Txn) addBM25IndexMutations(ctx context.Context, info *indexMutationInfo) error { + attr := info.edge.Attr + uid := info.edge.Entity + lang := info.edge.GetLang() + + schemaType, err := schema.State().TypeOf(attr) + if err != nil || !schemaType.IsScalar() { + return errors.Errorf("Cannot BM25 index attribute %s of type object.", attr) + } + + sv, err := types.Convert(info.val, schemaType) + if err != nil { + return err + } + + bm25Tok := tok.BM25Tokenizer{} + termFreqs, docLen, err := bm25Tok.TokensWithFrequency(sv.Value, lang) + if err != nil { + return err + } + + // Skip documents that tokenize to zero terms (e.g., all stopwords). + if docLen == 0 { + return nil + } + + for term, tf := range termFreqs { + if err := txn.addBM25TermPosting(ctx, attr, term, uid, tf, docLen, info.op); err != nil { + return err + } + } + + if info.op == pb.DirectedEdge_DEL { + return txn.updateBM25Stats(ctx, attr, uid, -1, -int64(docLen)) + } + return txn.updateBM25Stats(ctx, attr, uid, 1, int64(docLen)) +} + // countParams is sent to updateCount function. It is used to update the count index. // It deletes the uid from the key corresponding to <attr, countBefore> and adds it // to <attr, countAfter>. 
@@ -666,6 +733,14 @@ func prefixesToDeleteTokensFor(attr, tokenizerName string, hasLang bool) ([][]by prefix = append(prefix, tokenizer.Identifier()) prefixes = append(prefixes, prefix) + // BM25 stores corpus statistics under a reserved token prefix separate from its + // term-posting (IdentBM25) prefix, so deleting only the tokenizer-identifier + // prefix above would orphan the stats and double-count on rebuild. Delete the + // stats prefix (and its split variant) alongside the term postings. + if tokenizerName == "bm25" { + prefixes = append(prefixes, x.BM25StatsPrefix(attr, false), x.BM25StatsPrefix(attr, true)) + } + return prefixes, nil } @@ -678,6 +753,20 @@ type rebuilder struct { // The posting list passed here is the on disk version. It is not coming // from the LRU cache. fn func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) + + // bm25Acc, when non-nil, collects BM25 corpus statistics across the rebuild's + // per-thread transactions so they can be flushed once as a single writer. Set by + // rebuildTokIndex when a BM25 tokenizer is being rebuilt. See Txn.bm25Acc. + bm25Acc *bm25StatsAccum +} + +// newRebuildTxn creates a transaction for the streaming rebuild, propagating the +// shared BM25 stats accumulator (nil for non-BM25 rebuilds) so per-thread stats +// updates are collected rather than lost across cache resets. 
+func (r *rebuilder) newRebuildTxn() *Txn { + txn := NewTxn(r.startTs) + txn.bm25Acc = r.bm25Acc + return txn } func (r *rebuilder) RunWithoutTemp(ctx context.Context) error { @@ -941,7 +1030,7 @@ func (r *rebuilder) Run(ctx context.Context) error { txns := make([]*Txn, maxThreadIds) for i := range txns { - txns[i] = NewTxn(r.startTs) + txns[i] = r.newRebuildTxn() } stream.FinishThread = func(threadId int) (*bpb.KVList, error) { @@ -961,7 +1050,7 @@ func (r *rebuilder) Run(ctx context.Context) error { } kvs = append(kvs, &kv) } - txns[threadId] = NewTxn(r.startTs) + txns[threadId] = r.newRebuildTxn() return &bpb.KVList{Kv: kvs}, nil } @@ -1020,7 +1109,7 @@ func (r *rebuilder) Run(ctx context.Context) error { kvs = append(kvs, &kv) } - txns[threadId] = NewTxn(r.startTs) + txns[threadId] = r.newRebuildTxn() return &bpb.KVList{Kv: kvs}, nil } @@ -1041,6 +1130,25 @@ func (r *rebuilder) Run(ctx context.Context) error { if err := stream.Orchestrate(ctx); err != nil { return err } + // Flush the BM25 corpus statistics accumulated across all rebuild threads as a + // single writer. They are written into the temp store as ordinary delta postings + // so the second phase rolls them up into pstore at r.startTs alongside the term + // postings — no separate commit path, and no per-thread last-write-wins loss. 
+ if r.bm25Acc != nil { + flushTxn := NewTxn(r.startTs) + flushTxn.cache = NewLocalCache(r.startTs) + if err := r.bm25Acc.flush(ctx, flushTxn, r.attr); err != nil { + return err + } + flushTxn.Update() + for key, data := range flushTxn.cache.deltas { + counter++ + e := &badger.Entry{Key: []byte(key), Value: data, UserMeta: BitDeltaPosting} + if err := tmpWriter.SetEntryAt(e, counter); err != nil { + return errors.Wrap(err, "error writing bm25 stats to temp index") + } + } + } if err := tmpWriter.Flush(); err != nil { return err } @@ -1446,6 +1554,14 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { pk := x.ParsedKey{Attr: rb.Attr} builder := rebuilder{attr: rb.Attr, prefix: pk.DataPrefix(), startTs: rb.StartTs} + // BM25 corpus statistics must be aggregated across the rebuild's per-thread + // transactions and flushed once; a per-thread read-modify-write counter would + // undercount. Route stats through a shared accumulator when rebuilding bm25. + for _, tokenizer := range tokenizers { + if tokenizer.Identifier() == tok.IdentBM25 { + builder.bm25Acc = newBM25StatsAccum() + } + } builder.fn = func(uid uint64, pl *List, txn *Txn) ([]*pb.DirectedEdge, error) { edge := pb.DirectedEdge{Attr: rb.Attr, Entity: uid} edges := []*pb.DirectedEdge{} diff --git a/posting/list.go b/posting/list.go index 1c0c7a0fc55..610eaf5b5c2 100644 --- a/posting/list.go +++ b/posting/list.go @@ -1627,7 +1627,14 @@ func (l *List) encode(out *rollupOutput, readTs uint64, split bool) error { } enc.Add(p.Uid) - if p.Facets != nil || p.PostingType != pb.Posting_REF { + // Retain the full posting (not just its UID in the Pack) whenever it + // carries facets, is not a plain UID reference, or carries a value. + // BM25 index postings are REF postings that pack (term-frequency, + // doc-length) into Value; without the len(p.Value) > 0 clause that + // value would be stripped at rollup, silently losing all term + // frequencies. 
This mirrors how faceted postings already coexist in + // both Pack (UID) and Postings (payload). + if p.Facets != nil || p.PostingType != pb.Posting_REF || len(p.Value) > 0 { plist.Postings = append(plist.Postings, p) } return nil diff --git a/posting/mvcc.go b/posting/mvcc.go index 81c5e375553..108cdfc3b3e 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -318,6 +318,7 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { return err } } + return nil } diff --git a/posting/oracle.go b/posting/oracle.go index d7c3837b4b2..6c68193a202 100644 --- a/posting/oracle.go +++ b/posting/oracle.go @@ -54,6 +54,13 @@ type Txn struct { lastUpdate time.Time cache *LocalCache // This pointer does not get modified. + + // bm25Acc, when non-nil, redirects BM25 corpus-statistics updates into a shared + // accumulator instead of the per-transaction read-modify-write counter. Index + // rebuild sets this on its per-thread transactions so stats survive the streaming + // rebuild's independent caches and periodic resets, which would otherwise drop + // updates and undercount the corpus. nil on normal live transactions. + bm25Acc *bm25StatsAccum } // struct to implement Txn interface from vector-indexer diff --git a/query/common_test.go b/query/common_test.go index e36211f7a18..32a3e65a81b 100644 --- a/query/common_test.go +++ b/query/common_test.go @@ -390,6 +390,11 @@ func populateCluster(dc dgraphapi.Cluster) { testSchema += "\ndescription: string @index(ngram) ." } + // BM25 indexing - uses same version gate as ngram for now + if ngramSupport { + testSchema += "\ndescription_bm25: string @index(bm25) ." + } + setSchema(testSchema) err = addTriplesToCluster(` @@ -1007,4 +1012,16 @@ func populateCluster(dc dgraphapi.Cluster) { <415> <description> "Linguistic analysis helps understand text meaning" . 
`) x.Panic(err) + + // Add data for BM25 tests - uses separate predicate to avoid conflicts + err = addTriplesToCluster(` + <501> <description_bm25> "The quick brown fox jumps over the lazy dog" . + <502> <description_bm25> "A quick brown fox leaps over a sleeping dog" . + <503> <description_bm25> "fox fox fox" . + <504> <description_bm25> "The lazy dog sleeps under the warm sun all day long in the garden" . + <505> <description_bm25> "Dogs are loyal companions to humans and families everywhere" . + <506> <description_bm25> "Quick movements help foxes catch their prey in the wild" . + <507> <description_bm25> "Brown foxes are quick and agile animals in the forest" . + `) + x.Panic(err) } diff --git a/query/query.go b/query/query.go index 6926e2ac6ed..ed7492f95f0 100644 --- a/query/query.go +++ b/query/query.go @@ -7,6 +7,7 @@ package query import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -268,6 +269,14 @@ type SubGraph struct { // In graph terms, a list is a slice of outgoing edges from a node. uidMatrix []*pb.List + // bm25Scores maps a matched document UID to its BM25 relevance score. It is + // snapshotted from the (uid-aligned) worker result the moment it arrives, before + // filters or pagination can shrink/reorder uidMatrix out of step with valueMatrix. + // populateUidValVar binds the score variable from this map keyed by UID, so the + // score stays correct even when the bm25 block carries an @filter. nil unless the + // source function is bm25. + bm25Scores map[uint64]float64 + // facetsMatrix contains the facet values. There would a list corresponding to each uid in // uidMatrix. facetsMatrix []*pb.FacetsList @@ -1591,6 +1600,24 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su Value: int64(len(sg.SrcUIDs.Uids)), } doneVars[sg.Params.Var].Vals.Set(math.MaxUint64, val) + case sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && sg.bm25Scores != nil: + // A query-side ranker (BM25) binds its per-document relevance score as a + // value variable. 
We populate BOTH the matched uid set and the uid->score + // map so the variable works with uid(var), val(var) and orderdesc: val(var) + // — surfacing and ordering by score without a pseudo-predicate or a + // ParentVars channel. Scores are looked up from the uid-keyed snapshot taken + // at result time (sg.bm25Scores), so they remain correct even after an + // @filter on the bm25 block shrinks DestUIDs. + if v, ok = doneVars[sg.Params.Var]; !ok { + v = varValue{Vals: types.NewShardedMap(), path: sgPath, strList: sg.valueMatrix} + } + v.Uids = sg.DestUIDs + for _, uid := range sg.DestUIDs.GetUids() { + if score, has := sg.bm25Scores[uid]; has { + v.Vals.Set(uid, types.Val{Tid: types.FloatID, Value: score}) + } + } + doneVars[sg.Params.Var] = v case len(sg.DestUIDs.Uids) != 0 || (sg.Attr == "uid" && sg.SrcUIDs != nil): // 3. A uid variable. The variable could be defined in one of two places. // a) Either on the actual predicate. @@ -2173,6 +2200,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { rch <- nil return } + var err error switch { case parent == nil && sg.SrcFunc != nil && sg.SrcFunc.Name == "uid": @@ -2275,6 +2303,25 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { sg.List = result.List sg.vectorMetrics = result.VectorMetrics + // bm25 returns its per-document scores in valueMatrix positionally aligned + // with uidMatrix[0]. Snapshot them into a uid-keyed map now, while the two + // are still aligned — later filters/pagination shrink uidMatrix without + // touching valueMatrix, which would otherwise misbind scores to UIDs. 
+ if sg.SrcFunc != nil && sg.SrcFunc.Name == "bm25" && len(result.UidMatrix) > 0 { + uids := result.UidMatrix[0].GetUids() + sg.bm25Scores = make(map[uint64]float64, len(uids)) + for idx, uid := range uids { + if idx >= len(result.ValueMatrix) || len(result.ValueMatrix[idx].Values) == 0 { + continue + } + tv := result.ValueMatrix[idx].Values[0] + if len(tv.Val) != 8 { + continue + } + sg.bm25Scores[uid] = math.Float64frombits(binary.LittleEndian.Uint64(tv.Val)) + } + } + if sg.Params.DoCount { if len(sg.Filters) == 0 { // If there is a filter, we need to do more work to get the actual count. @@ -2373,9 +2420,12 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } if len(sg.Params.Order) == 0 && len(sg.Params.FacetsOrder) == 0 { - // for `has` function when there is no filtering and ordering, we fetch - // correct paginated results so no need to apply pagination here. - if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && sg.SrcFunc.Name == "has") { + // For `has` and `bm25`, the worker already returns correctly paginated + // results (bm25 paginates over score order, which the uid-sorted query-layer + // pagination cannot reproduce), so applying pagination again here would + // double-apply first/offset. Skip it when there is no filtering/ordering. + if !(len(sg.Filters) == 0 && sg.SrcFunc != nil && + (sg.SrcFunc.Name == "has" || sg.SrcFunc.Name == "bm25")) { // There is no ordering. Just apply pagination and return. if err = sg.applyPagination(ctx); err != nil { rch <- err @@ -2452,6 +2502,7 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } child.SrcUIDs = sg.DestUIDs // Make the connection. + if child.IsInternal() { // We dont have to execute these nodes. 
continue @@ -2751,7 +2802,7 @@ func isValidArg(a string) bool { func isValidFuncName(f string) bool { switch f { case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext", "ngram", - "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to": + "has", "uid", "uid_in", "anyof", "allof", "type", "match", "similar_to", "bm25": return true } return isInequalityFn(f) || types.IsGeoFunc(f) diff --git a/query/query_bm25_test.go b/query/query_bm25_test.go new file mode 100644 index 00000000000..a69fd0dca1e --- /dev/null +++ b/query/query_bm25_test.go @@ -0,0 +1,1146 @@ +//go:build integration || cloud + +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +//nolint:lll +package query + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strings" + "sync" + "testing" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/stretchr/testify/require" +) + +// uidHex queries Dgraph for the hex UID string of a given decimal UID. +// This avoids hardcoding hex values that depend on UID assignment order. +func uidHex(t *testing.T, decimalUID int) string { + t.Helper() + js := processQueryNoErr(t, fmt.Sprintf(`{ me(func: uid(%d)) { uid } }`, decimalUID)) + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.NotEmpty(t, resp.Data.Me, "UID %d should exist", decimalUID) + return resp.Data.Me[0].UID +} + +func TestBM25Basic(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick brown fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return documents containing "quick", "brown", or "fox" + require.Contains(t, js, "quick brown fox jumps") + require.Contains(t, js, "quick brown fox leaps") +} + +func TestBM25Ordering(t *testing.T) { + // BM25 returns all matching documents. 
Use first:1 to verify the highest-scored + // document is "fox fox fox" (tf=3, short doc). + query := ` + { + me(func: bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should contain all fox-mentioning documents. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + + // first:1 should return the top-ranked document. + topQuery := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + topJs := processQueryNoErr(t, topQuery) + require.Contains(t, topJs, "fox fox fox", + "top-1 BM25 result for 'fox' should be 'fox fox fox' (highest tf, shortest doc)") +} + +func TestBM25WithParams(t *testing.T) { + // Custom k and b parameters + query := ` + { + me(func: bm25(description_bm25, "fox", "1.5", "0.5")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") +} + +func TestBM25InvalidParams(t *testing.T) { + // Negative k should be rejected. + query := ` + { + me(func: bm25(description_bm25, "fox", "-1.0", "0.75")) { + uid + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: k must be a positive finite number") + + // b > 1 should be rejected. + query2 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "1.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query2) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") + + // b < 0 should be rejected. 
+ query3 := ` + { + me(func: bm25(description_bm25, "fox", "1.2", "-0.5")) { + uid + } + } + ` + _, err = processQuery(context.Background(), t, query3) + require.Error(t, err) + require.Contains(t, err.Error(), "bm25: b must be between 0 and 1") +} + +func TestBM25AsFilter(t *testing.T) { + query := ` + { + me(func: has(description_bm25)) @filter(bm25(description_bm25, "fox")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "fox") + // Should not contain documents without "fox" + require.NotContains(t, js, "Dogs are loyal") +} + +func TestBM25NoResults(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "xyznonexistent")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25SingleTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "dog")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "dog") +} + +func TestBM25MultiTerm(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "quick lazy")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should find docs with "quick" or "lazy" (scores accumulate). + // Doc 501 has both "quick" and "lazy", so it should rank high. + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25AllStopwords(t *testing.T) { + // A query consisting entirely of stopwords should return no results. 
+ query := ` + { + me(func: bm25(description_bm25, "the a an")) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25EmptyPredicate(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "")) { + uid + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[]}}`, js) +} + +func TestBM25WithCount(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox")) { + count(uid) + } + } + ` + js := processQueryNoErr(t, query) + // Should have at least 2 results (docs with "fox") + require.Contains(t, js, "count") +} + +func TestBM25Pagination(t *testing.T) { + query := ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // With first:1, should return exactly one result (the highest-scoring). + // Doc 503 "fox fox fox" should be the top result. + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrdering(t *testing.T) { + // Bind the bm25 score to a value variable and order results by it via val(). + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // "fox fox fox" (doc 503) has the highest BM25 score (tf=3, shortest doc). + require.Contains(t, js, "fox fox fox") +} + +func TestBM25ScoreOrderingMultiTerm(t *testing.T) { + // Multi-term query with score ordering: "quick lazy" should rank doc 501 highest + // since it contains both terms. 
+ query := ` + { + score as var(func: bm25(description_bm25, "quick lazy")) + me(func: uid(score), orderdesc: val(score), first: 1) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + require.Contains(t, js, "quick brown fox jumps over the lazy dog") +} + +func TestBM25ScoreOrderingAllResults(t *testing.T) { + // Verify all results are returned in score-descending order via val(score). + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + description_bm25 + val(score) + } + } + ` + js := processQueryNoErr(t, query) + // All fox-containing docs should appear. + require.Contains(t, js, "fox fox fox") + require.Contains(t, js, "quick brown fox jumps") + // Score values should be present. + require.Contains(t, js, "val(score)") +} + +func TestBM25ScoreWithPagination(t *testing.T) { + // Use offset with score ordering. + query := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score), first: 1, offset: 1) { + uid + description_bm25 + } + } + ` + js := processQueryNoErr(t, query) + // Should return the second-highest scored document (not "fox fox fox"). + require.NotContains(t, js, "fox fox fox") + require.Contains(t, js, "fox") +} + +// parseScoresFromJSON extracts uid → score from JSON responses containing val(score). +func parseScoresFromJSON(t *testing.T, js string) map[string]float64 { + t.Helper() + var resp struct { + Data struct { + Me []struct { + UID string `json:"uid"` + Score float64 `json:"val(score)"` + } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + scores := make(map[string]float64) + for _, item := range resp.Data.Me { + scores[item.UID] = item.Score + } + return scores +} + +func TestBM25IncrementalAddBatch(t *testing.T) { + batch1 := ` + <600> "alpha bravo charlie" . + <601> "delta echo foxtrot" . + ` + batch2 := ` + <602> "golf hotel india" . 
+ <603> "juliet kilo lima" . + <604> "mike november oscar" . + ` + batch3 := ` + <605> "papa quebec romeo" . + <606> "sierra tango uniform" . + <607> "victor whiskey xray" . + ` + cleanup := func() { + deleteTriplesInCluster(` + <600> * . + <601> * . + <602> * . + <603> * . + <604> * . + <605> * . + <606> * . + <607> * . + `) + } + t.Cleanup(cleanup) + + countQuery := ` + { + me(func: bm25(description_bm25, "alpha bravo delta echo golf juliet mike papa sierra victor")) { + count(uid) + } + } + ` + + // Batch 1: add 2 docs. + require.NoError(t, addTriplesToCluster(batch1)) + js := processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":2`) + + // Batch 2: add 3 more docs → total 5. + require.NoError(t, addTriplesToCluster(batch2)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":5`) + + // Batch 3: add 3 more docs → total 8. + require.NoError(t, addTriplesToCluster(batch3)) + js = processQueryNoErr(t, countQuery) + require.Contains(t, js, `"count":8`) + + // Verify specific new terms are searchable. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "whiskey")) { uid description_bm25 } }`) + require.Contains(t, js, "whiskey") +} + +func TestBM25CorpusStatsAffectIDF(t *testing.T) { + // Capture baseline score for "fox" query. + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + jsBefore := processQueryNoErr(t, scoreQuery) + scoresBefore := parseScoresFromJSON(t, jsBefore) + require.NotEmpty(t, scoresBefore, "baseline should have fox results") + + // Add 10 non-fox docs → N grows, df("fox") stays same → IDF should increase. + var triples string + for i := 610; i < 620; i++ { + triples += fmt.Sprintf(`<%d> "completely unrelated document about cats and dogs number %d" . 
+`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 610; i < 620; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + jsAfter := processQueryNoErr(t, scoreQuery) + scoresAfter := parseScoresFromJSON(t, jsAfter) + + // Compare score for UID 503 ("fox fox fox") — should increase. + uid503 := uidHex(t, 503) + before, ok1 := scoresBefore[uid503] + after, ok2 := scoresAfter[uid503] + require.True(t, ok1 && ok2, "UID 503 should appear in both before and after results") + require.Greater(t, after, before, + "IDF should increase when corpus grows with non-matching docs (before=%f, after=%f)", before, after) +} + +func TestBM25DocumentUpdate(t *testing.T) { + // Add a doc with lots of "fox". + require.NoError(t, addTriplesToCluster(`<620> "fox fox fox fox" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<620> * .`) + }) + + uid620 := uidHex(t, 620) + + // Should rank top for "fox". + js := processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox"), first: 1) { + uid + } + }`) + require.Contains(t, js, `"`+uid620+`"`) + + // Update to remove "fox", add "cat". + deleteTriplesInCluster(`<620> "fox fox fox fox" .`) + require.NoError(t, addTriplesToCluster(`<620> "the cat sat on the mat" .`)) + + // Should no longer appear in "fox" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "fox")) { + uid + } + }`) + require.NotContains(t, js, `"`+uid620+`"`) + + // Should appear in "cat" results. + js = processQueryNoErr(t, ` + { + me(func: bm25(description_bm25, "cat")) { + uid + } + }`) + require.Contains(t, js, `"`+uid620+`"`) +} + +func TestBM25DocumentDeletion(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<625> "unique elephant term" .`)) + t.Cleanup(func() { + // Cleanup in case test fails before explicit delete. + deleteTriplesInCluster(`<625> * .`) + }) + + uid625 := uidHex(t, 625) + + // Should find the elephant doc. 
+ js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.Contains(t, js, `"`+uid625+`"`) + + // Delete it. + deleteTriplesInCluster(`<625> "unique elephant term" .`) + + // Should return empty for "elephant". + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`) + require.JSONEq(t, `{"data": {"me":[]}}`, js) + + // Baseline "fox" results should be unaffected. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "fox")) { uid description_bm25 } }`) + require.Contains(t, js, "fox") +} + +func TestBM25ScoreStabilityAsCorpusGrows(t *testing.T) { + scoreQuery := ` + { + score as var(func: bm25(description_bm25, "fox")) + me(func: uid(score), orderdesc: val(score)) { + uid + val(score) + } + } + ` + uid503 := uidHex(t, 503) + + // Phase 1: baseline score. + js1 := processQueryNoErr(t, scoreQuery) + scores1 := parseScoresFromJSON(t, js1) + score1, ok := scores1[uid503] + require.True(t, ok, "UID 503 must appear in baseline") + + // Phase 2: add 5 fox docs → IDF decreases. + var foxTriples string + for i := 630; i < 635; i++ { + foxTriples += fmt.Sprintf(`<%d> "the fox runs quickly across the field number %d" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(foxTriples)) + t.Cleanup(func() { + var del string + for i := 630; i < 640; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + js2 := processQueryNoErr(t, scoreQuery) + scores2 := parseScoresFromJSON(t, js2) + score2, ok := scores2[uid503] + require.True(t, ok, "UID 503 must appear after adding fox docs") + require.Greater(t, score1, score2, + "Adding fox docs should decrease IDF and thus score (phase1=%f, phase2=%f)", score1, score2) + + // Phase 3: add 5 non-fox docs → IDF increases relative to phase 2. + var nonFoxTriples string + for i := 635; i < 640; i++ { + nonFoxTriples += fmt.Sprintf(`<%d> "unrelated content about birds and fish number %d" . 
+`, i, i) + } + require.NoError(t, addTriplesToCluster(nonFoxTriples)) + + js3 := processQueryNoErr(t, scoreQuery) + scores3 := parseScoresFromJSON(t, js3) + score3, ok := scores3[uid503] + require.True(t, ok, "UID 503 must appear after adding non-fox docs") + require.Greater(t, score3, score2, + "Adding non-fox docs should increase IDF relative to phase2 (phase2=%f, phase3=%f)", score2, score3) +} + +func TestBM25LargeCorpus(t *testing.T) { + // Add 100 docs: 50 with "alpha", 50 with "beta". + var triples string + for i := 700; i < 750; i++ { + triples += fmt.Sprintf(`<%d> "alpha document content number %d with some padding words" . +`, i, i) + } + for i := 750; i < 800; i++ { + triples += fmt.Sprintf(`<%d> "beta document content number %d with some padding words" . +`, i, i) + } + require.NoError(t, addTriplesToCluster(triples)) + t.Cleanup(func() { + var del string + for i := 700; i < 800; i++ { + del += fmt.Sprintf("<%d> * .\n", i) + } + deleteTriplesInCluster(del) + }) + + // Count alpha docs. + js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Count beta docs. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "beta")) { count(uid) } }`) + require.Contains(t, js, `"count":50`) + + // Union count: "alpha beta" should match all 100. + js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "alpha beta")) { count(uid) } }`) + require.Contains(t, js, `"count":100`) + + // Pagination: first:10, offset:40 for alpha should return 10 results. 
+ js = processQueryNoErr(t, ` + { + score as var(func: bm25(description_bm25, "alpha")) + me(func: uid(score), orderdesc: val(score), first: 10, offset: 40) { + uid + } + }`) + var resp struct { + Data struct { + Me []struct{ UID string } `json:"me"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(js), &resp)) + require.Len(t, resp.Data.Me, 10, "pagination first:10 offset:40 should return exactly 10 results") +} + +func TestBM25EdgeCaseSingleCharTerm(t *testing.T) { + require.NoError(t, addTriplesToCluster(`<640> "x y z" .`)) + t.Cleanup(func() { + deleteTriplesInCluster(`<640> * .`) + }) + + // Single-char terms may or may not be indexed depending on tokenizer. + // Just verify no panic/error. + _, err := processQuery(context.Background(), t, ` + { + me(func: bm25(description_bm25, "x")) { + uid + } + }`) + require.NoError(t, err) +} + +func TestBM25EdgeCaseLongDocument(t *testing.T) { + // Build a ~500-word document with "fox" appearing once. + words := make([]string, 500) + for i := range words { + words[i] = "padding" + } + words[250] = "fox" + longDoc := strings.Join(words, " ") + + require.NoError(t, addTriplesToCluster(fmt.Sprintf(`<645> %q .`, longDoc))) + t.Cleanup(func() { + deleteTriplesInCluster(`<645> * .`) + }) + + // Get scores for "fox" query. 
+	scoreQuery := `
+	{
+		score as var(func: bm25(description_bm25, "fox"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}
+	`
+	js := processQueryNoErr(t, scoreQuery)
+	scores := parseScoresFromJSON(t, js)
+
+	uid503 := uidHex(t, 503) // "fox fox fox" (doclen=3)
+	uid645 := uidHex(t, 645) // long doc (doclen~500)
+	s503, ok1 := scores[uid503]
+	s645, ok2 := scores[uid645]
+	require.True(t, ok1, "UID 503 must appear in fox results")
+	require.True(t, ok2, "UID 645 must appear in fox results")
+	require.Greater(t, s503, s645,
+		"Short doc with high tf should score higher than long doc with low tf (503=%f, 645=%f)", s503, s645)
+}
+
+func TestBM25EdgeCaseUnicode(t *testing.T) {
+	triples := `
+		<650> <description_bm25> "der schnelle braune Fuchs springt" .
+		<651> <description_bm25> "le renard brun rapide saute" .
+		<652> <description_bm25> "el zorro marrón rápido salta" .
+	`
+	require.NoError(t, addTriplesToCluster(triples))
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`
+			<650> <description_bm25> * .
+			<651> <description_bm25> * .
+			<652> <description_bm25> * .
+		`)
+	})
+
+	uid650 := uidHex(t, 650)
+	uid651 := uidHex(t, 651)
+	uid652 := uidHex(t, 652)
+
+	// Query German term.
+	js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "Fuchs")) { uid } }`)
+	require.Contains(t, js, `"`+uid650+`"`)
+
+	// Query French term.
+	js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "renard")) { uid } }`)
+	require.Contains(t, js, `"`+uid651+`"`)
+
+	// Query Spanish term.
+	js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "zorro")) { uid } }`)
+	require.Contains(t, js, `"`+uid652+`"`)
+}
+
+func TestBM25EdgeCaseAllStopwordsDoc(t *testing.T) {
+	require.NoError(t, addTriplesToCluster(`<655> <description_bm25> "the a an is are was were" .`))
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`<655> <description_bm25> * .`)
+	})
+
+	uid655 := uidHex(t, 655)
+
+	// Query "the" — should return empty since "the" is a stopword.
+	js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "the")) { uid } }`)
+	require.NotContains(t, js, `"`+uid655+`"`) // 655 must not match a stopword-only query
+
+	// The document itself must still be stored and reachable via has().
+	js = processQueryNoErr(t, `
+	{
+		me(func: has(description_bm25)) @filter(uid(655)) {
+			uid
+		}
+	}`)
+	require.Contains(t, js, `"`+uid655+`"`)
+}
+
+func TestBM25WithUidFilter(t *testing.T) {
+	// BM25 at the query root, narrowed to UIDs 501 and 503 by a uid filter.
+	query := `
+	{
+		me(func: bm25(description_bm25, "fox")) @filter(uid(501, 503)) {
+			uid
+			description_bm25
+		}
+	}
+	`
+	js := processQueryNoErr(t, query)
+	uid501 := uidHex(t, 501)
+	uid502 := uidHex(t, 502)
+	uid503 := uidHex(t, 503)
+	uid506 := uidHex(t, 506)
+	// Should contain only UIDs 501 and 503.
+	require.Contains(t, js, `"`+uid501+`"`)
+	require.Contains(t, js, `"`+uid503+`"`)
+	// Should NOT contain the fox docs excluded by the filter, such as 502 and 506.
+	require.NotContains(t, js, `"`+uid502+`"`)
+	require.NotContains(t, js, `"`+uid506+`"`)
+}
+
+func TestBM25ScoreValuesAreValidFloats(t *testing.T) {
+	scoreQuery := `
+	{
+		score as var(func: bm25(description_bm25, "fox"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}
+	`
+	js := processQueryNoErr(t, scoreQuery)
+	scores := parseScoresFromJSON(t, js)
+	require.NotEmpty(t, scores, "should have at least one result")
+
+	var prevScore float64
+	first := true
+	// Decode the response again as an ordered list (the query is orderdesc by score).
+	var resp struct {
+		Data struct {
+			Me []struct {
+				UID   string  `json:"uid"`
+				Score float64 `json:"val(score)"`
+			} `json:"me"`
+		} `json:"data"`
+	}
+	require.NoError(t, json.Unmarshal([]byte(js), &resp))
+
+	for _, item := range resp.Data.Me {
+		score := item.Score
+		require.False(t, math.IsNaN(score), "score should not be NaN for uid %s", item.UID)
+		require.False(t, math.IsInf(score, 0), "score should not be Inf for uid %s", item.UID)
+		require.Greater(t, score, 0.0, "score should be positive for uid %s", item.UID)
+
+		if !first {
+			require.GreaterOrEqual(t, prevScore, score,
+				"scores should be in descending order: %f >= %f", prevScore, score)
+		}
+		prevScore = score
+		first = false
+	}
+}
+
+func TestBM25IncrementalAddThenDeleteThenReadd(t *testing.T) {
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`<670> <description_bm25> * .`)
+	})
+
+	// Phase 1: add with "elephant".
+	require.NoError(t, addTriplesToCluster(`<670> <description_bm25> "elephant roams the savanna" .`))
+	uid670 := uidHex(t, 670)
+	js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`)
+	require.Contains(t, js, `"`+uid670+`"`)
+
+	// Phase 2: delete.
+	deleteTriplesInCluster(`<670> <description_bm25> "elephant roams the savanna" .`)
+	js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`)
+	require.NotContains(t, js, `"`+uid670+`"`)
+
+	// Phase 3: re-add with different content.
+	require.NoError(t, addTriplesToCluster(`<670> <description_bm25> "penguin waddles on the ice" .`))
+	js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "penguin")) { uid } }`)
+	require.Contains(t, js, `"`+uid670+`"`)
+
+	// "elephant" should still not match 670.
+	js = processQueryNoErr(t, `{ me(func: bm25(description_bm25, "elephant")) { uid } }`)
+	require.NotContains(t, js, `"`+uid670+`"`)
+}
+
+func TestBM25NonIndexedPredicateError(t *testing.T) {
+	// "name" predicate does not have @index(bm25).
+	query := `
+	{
+		me(func: bm25(name, "alice")) {
+			uid
+		}
+	}
+	`
+	_, err := processQuery(context.Background(), t, query)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "bm25")
+}
+
+func TestBM25ConcurrentBatchAdd(t *testing.T) {
+	// Add 5 batches of 4 docs each (UIDs 680-699) back-to-back.
+	t.Cleanup(func() {
+		var del string
+		for i := 680; i < 700; i++ {
+			del += fmt.Sprintf("<%d> <description_bm25> * .\n", i)
+		}
+		deleteTriplesInCluster(del)
+	})
+
+	for batch := 0; batch < 5; batch++ {
+		var triples string
+		for j := 0; j < 4; j++ {
+			uid := 680 + batch*4 + j
+			triples += fmt.Sprintf(`<%d> <description_bm25> "searchterm batch%d doc%d content here" .
+`, uid, batch, j)
+		}
+		require.NoError(t, addTriplesToCluster(triples))
+	}
+
+	// All 20 docs should be findable.
+	js := processQueryNoErr(t, `{ me(func: bm25(description_bm25, "searchterm")) { count(uid) } }`)
+	require.Contains(t, js, `"count":20`)
+
+	// Spot-check the first doc of each batch via its batch-specific term.
+	for batch := 0; batch < 5; batch++ {
+		decUID := 680 + batch*4
+		hexUID := uidHex(t, decUID)
+		term := fmt.Sprintf("batch%d", batch)
+		js = processQueryNoErr(t, fmt.Sprintf(`{ me(func: bm25(description_bm25, "%s")) { uid } }`, term))
+		require.Contains(t, js, `"`+hexUID+`"`, "doc %d from batch %d should be searchable", decUID, batch)
+	}
+}
+
+// parseCorpusCount returns the total number of documents with the description_bm25 predicate.
+func parseCorpusCount(t *testing.T) float64 {
+	t.Helper()
+	js := processQueryNoErr(t, `{ me(func: has(description_bm25)) { count(uid) } }`)
+	var resp struct {
+		Data struct {
+			Me []struct {
+				Count int `json:"count"`
+			} `json:"me"`
+		} `json:"data"`
+	}
+	require.NoError(t, json.Unmarshal([]byte(js), &resp))
+	require.NotEmpty(t, resp.Data.Me)
+	n := float64(resp.Data.Me[0].Count)
+	require.Greater(t, n, 0.0, "corpus must have documents")
+	return n
+}
+
+func TestBM25ExactScoreValues(t *testing.T) {
+	// Exact score verification using b=0 (BM15 variant) to eliminate avgDL dependency.
+	// With b=0: score = idf * (k+1) * tf / (k + tf)
+	// This validates the core BM25 formula computes correct numerical values.
+	triples := `
+		<850> <description_bm25> "quasar quasar quasar" .
+		<851> <description_bm25> "quasar nebula pulsar" .
+	`
+	require.NoError(t, addTriplesToCluster(triples))
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`
+			<850> <description_bm25> * .
+			<851> <description_bm25> * .
+		`)
+	})
+
+	N := parseCorpusCount(t)
+
+	// Query "quasar" with b=0 so score depends only on tf, k, and IDF (not avgDL).
+	scoreQuery := `
+	{
+		score as var(func: bm25(description_bm25, "quasar", "1.2", "0"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`
+	js := processQueryNoErr(t, scoreQuery)
+	scores := parseScoresFromJSON(t, js)
+
+	k := 1.2
+	df := 2.0 // both 850 and 851 contain "quasar"
+	idf := math.Log1p((N - df + 0.5) / (df + 0.5))
+
+	// Doc 850 "quasar quasar quasar": tf=3, b=0 → score = idf * 2.2 * 3 / 4.2
+	expected850 := idf * (k + 1) * 3.0 / (k + 3.0)
+	// Doc 851 "quasar nebula pulsar": tf=1, b=0 → score = idf * 2.2 * 1 / 2.2 = idf
+	expected851 := idf * (k + 1) * 1.0 / (k + 1.0)
+
+	uid850 := uidHex(t, 850)
+	uid851 := uidHex(t, 851)
+	actual850, ok := scores[uid850]
+	require.True(t, ok, "UID 850 (%s) must be in results", uid850)
+	actual851, ok := scores[uid851]
+	require.True(t, ok, "UID 851 (%s) must be in results", uid851)
+
+	require.InEpsilon(t, expected850, actual850, 1e-6,
+		"Doc 850 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)",
+		expected850, actual850, N, df, idf)
+	require.InEpsilon(t, expected851, actual851, 1e-6,
+		"Doc 851 score mismatch: expected %f, got %f (N=%f, df=%f, idf=%f)",
+		expected851, actual851, N, df, idf)
+
+	// Verify ordering: higher tf should yield higher score.
+	require.Greater(t, actual850, actual851)
+}
+
+func TestBM25BM15NoLengthNormalization(t *testing.T) {
+	// With b=0 (BM15 variant), document length should NOT affect the score.
+	// Two docs with the same term frequency but different lengths must score identically.
+	triples := `
+		<860> <description_bm25> "vortex" .
+		<861> <description_bm25> "vortex alpha bravo charlie delta echo foxtrot golf hotel india" .
+	`
+	require.NoError(t, addTriplesToCluster(triples))
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`
+			<860> <description_bm25> * .
+			<861> <description_bm25> * .
+		`)
+	})
+
+	// Query with b=0: length normalization disabled.
+	scoreQuery := `
+	{
+		score as var(func: bm25(description_bm25, "vortex", "1.2", "0"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`
+	js := processQueryNoErr(t, scoreQuery)
+	scores := parseScoresFromJSON(t, js)
+
+	uid860 := uidHex(t, 860)
+	uid861 := uidHex(t, 861)
+	score860, ok1 := scores[uid860]
+	score861, ok2 := scores[uid861]
+	require.True(t, ok1, "UID 860 must be in results")
+	require.True(t, ok2, "UID 861 must be in results")
+
+	// With b=0 and same tf=1, scores must be equal regardless of document length.
+	require.InDelta(t, score860, score861, 1e-9,
+		"b=0 should disable length normalization: short doc score=%f, long doc score=%f",
+		score860, score861)
+
+	// Now verify that with default b=0.75, the shorter doc scores higher.
+	scoreQueryDefault := `
+	{
+		score as var(func: bm25(description_bm25, "vortex"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`
+	js = processQueryNoErr(t, scoreQueryDefault)
+	scoresDefault := parseScoresFromJSON(t, js)
+
+	defScore860, ok1 := scoresDefault[uid860]
+	defScore861, ok2 := scoresDefault[uid861]
+	require.True(t, ok1, "UID 860 must be in default results")
+	require.True(t, ok2, "UID 861 must be in default results")
+	require.Greater(t, defScore860, defScore861,
+		"With b=0.75, shorter doc (doclen=1) should score higher than longer doc (doclen=10)")
+}
+
+func TestBM25SingleMatchingDocument(t *testing.T) {
+	// Edge case: a single document matching the query term (df=1).
+	// IDF should be high since the term is very rare.
+	triples := `<865> <description_bm25> "aardvark" .`
+	require.NoError(t, addTriplesToCluster(triples))
+	t.Cleanup(func() {
+		deleteTriplesInCluster(`<865> <description_bm25> * .`)
+	})
+
+	N := parseCorpusCount(t)
+
+	// Query with b=0 for exact verification.
+	scoreQuery := `
+	{
+		score as var(func: bm25(description_bm25, "aardvark", "1.2", "0"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`
+	js := processQueryNoErr(t, scoreQuery)
+	scores := parseScoresFromJSON(t, js)
+
+	require.Len(t, scores, 1, "exactly one document should match 'aardvark'")
+
+	uid865 := uidHex(t, 865)
+	actual, ok := scores[uid865]
+	require.True(t, ok, "UID 865 (%s) must be in results", uid865)
+
+	// With df=1, tf=1, b=0, k=1.2:
+	//   idf = log1p((N - 1 + 0.5) / (1 + 0.5)) = log1p((N - 0.5) / 1.5)
+	//   score = idf * 2.2 * 1 / (1.2 + 1) = idf * 2.2 / 2.2 = idf
+	k := 1.2
+	df := 1.0
+	idf := math.Log1p((N - df + 0.5) / (df + 0.5))
+	expected := idf * (k + 1) * 1.0 / (k + 1.0) // simplifies to idf
+
+	require.InEpsilon(t, expected, actual, 1e-6,
+		"Single-doc score mismatch: expected %f, got %f (N=%f, idf=%f)",
+		expected, actual, N, idf)
+	require.Greater(t, actual, 0.0, "score must be positive")
+	require.False(t, math.IsInf(actual, 0), "score must be finite")
+}
+
+// TestBM25RebuildOnExistingData loads documents BEFORE a BM25 index exists, then adds
+// @index(bm25) to trigger an index rebuild over the existing data. The streaming
+// rebuild aggregates corpus statistics across many threads; a naive per-thread
+// read-modify-write counter would undercount N/avgDL. This verifies bm25 ranks
+// exactly as if the documents had been indexed live.
+func TestBM25RebuildOnExistingData(t *testing.T) {
+	pred := "bm25_rebuild"
+	t.Cleanup(func() { dropPredicate(pred) })
+
+	require.NoError(t, addTriplesToCluster(fmt.Sprintf(`
+		<920> <%[1]s> "fox fox fox" .
+		<921> <%[1]s> "fox dog" .
+		<922> <%[1]s> "dog cat bird fish" .
+	`, pred)))
+
+	// Add the index only now, so it is built by a rebuild over the three existing documents.
+	setSchema(fmt.Sprintf("%s: string @index(bm25) .", pred))
+
+	query := fmt.Sprintf(`
+	{
+		score as var(func: bm25(%s, "fox"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`, pred)
+	scores := parseScoresFromJSON(t, processQueryNoErr(t, query))
+
+	require.Len(t, scores, 2, "both documents containing 'fox' must match after rebuild")
+	require.Greater(t, scores[uidHex(t, 920)], scores[uidHex(t, 921)],
+		"the denser, shorter document must rank higher after rebuild")
+	require.Greater(t, scores[uidHex(t, 920)], 0.0, "rebuilt stats must yield a positive score")
+}
+
+// TestBM25DropThenReaddNoDoubleCount verifies that dropping the BM25 index clears its
+// corpus statistics so re-adding it rebuilds from scratch. If the stats survived the
+// drop, the rebuild would accumulate on top of stale counters, inflating N/avgDL and
+// shifting every score.
+func TestBM25DropThenReaddNoDoubleCount(t *testing.T) {
+	pred := "bm25_dropreadd"
+	t.Cleanup(func() { dropPredicate(pred) })
+
+	require.NoError(t, addTriplesToCluster(fmt.Sprintf(`
+		<930> <%[1]s> "alpha alpha beta" .
+		<931> <%[1]s> "alpha gamma" .
+		<932> <%[1]s> "beta gamma delta" .
+	`, pred)))
+	setSchema(fmt.Sprintf("%s: string @index(bm25) .", pred))
+
+	query := fmt.Sprintf(`
+	{
+		score as var(func: bm25(%s, "alpha"))
+		me(func: uid(score), orderdesc: val(score)) {
+			uid
+			val(score)
+		}
+	}`, pred)
+	before := parseScoresFromJSON(t, processQueryNoErr(t, query))
+	require.NotEmpty(t, before)
+
+	// Drop only the index (the data stays), then re-add it to force a second rebuild.
+	setSchema(fmt.Sprintf("%s: string .", pred))
+	setSchema(fmt.Sprintf("%s: string @index(bm25) .", pred))
+
+	after := parseScoresFromJSON(t, processQueryNoErr(t, query))
+
+	require.Equal(t, len(before), len(after), "result set size must be stable across drop/re-add")
+	for uid, score := range before {
+		require.InEpsilon(t, score, after[uid], 1e-9,
+			"score for %s must be identical after drop/re-add (no stale-stats double counting)", uid)
+	}
+}
+
+// TestBM25ConcurrentOverlappingTxns adds documents whose UIDs share BM25 stats buckets
+// from many goroutines at once. Corpus stats are guarded by a value-independent
+// conflict key, so overlapping transactions to the same bucket conflict and one
+// retries rather than silently overwriting the other's contribution. Every document
+// must end up indexed and searchable.
+func TestBM25ConcurrentOverlappingTxns(t *testing.T) {
+	pred := "bm25_concurrent"
+	t.Cleanup(func() { dropPredicate(pred) })
+	setSchema(fmt.Sprintf("%s: string @index(bm25) .", pred))
+
+	ctx := context.Background()
+	// UIDs chosen so several collide on uid%32 (the stats bucket count).
+	uids := []int{1000, 1032, 1064, 1001, 1033, 1065, 1002, 1034}
+
+	// maxAttempts bounds the per-document retry loop so that a persistent,
+	// non-conflict error fails the test instead of spinning until the global timeout.
+	const maxAttempts = 100
+
+	var wg sync.WaitGroup
+	for _, uid := range uids {
+		wg.Add(1)
+		go func(uid int) {
+			defer wg.Done()
+			// Retry on conflict, as a client would, up to maxAttempts times.
+			var err error
+			for attempt := 0; attempt < maxAttempts; attempt++ {
+				txn := client.NewTxn()
+				_, err = txn.Mutate(ctx, &api.Mutation{
+					SetNquads: []byte(fmt.Sprintf(
+						`<%d> <%s> "concurrent indexing term doc%d" .`, uid, pred, uid)),
+					CommitNow: true,
+				})
+				// Best-effort cleanup; the outcome of the mutation is already known.
+				_ = txn.Discard(ctx)
+				if err == nil {
+					return
+				}
+			}
+			// Errorf (not require/FailNow) because this is not the test goroutine.
+			t.Errorf("uid %d: mutation still failing after %d attempts: %v", uid, maxAttempts, err)
+		}(uid)
+	}
+	wg.Wait()
+
+	js := processQueryNoErr(t, fmt.Sprintf(
+		`{ me(func: bm25(%s, "concurrent")) { count(uid) } }`, pred))
+	require.Contains(t, js, fmt.Sprintf(`"count":%d`, len(uids)),
+		"all concurrently-indexed documents must be searchable (no lost stats updates)")
+}
diff --git a/systest/integration2/bulk_loader_test.go b/systest/integration2/bulk_loader_test.go
index 48b59b01d8a..bd005476d02 100644
--- a/systest/integration2/bulk_loader_test.go
+++ b/systest/integration2/bulk_loader_test.go
@@ -10,6 +10,7 @@ package main
 import (
 	"os"
 	"path/filepath"
+	"strings"
 	"testing"
 	"time"
 
@@ -122,6 +123,62 @@ func TestBulkLoaderSkipReducePhase(t *testing.T) {
 	}`, string(data)))
 }
 
+// TestBulkLoaderBM25 verifies the BM25 index is correctly built by the bulk loader:
+// term postings must carry their packed (term frequency, document length) value, and
+// the corpus statistics must be written, or bm25() queries return nothing. It loads a
+// small corpus via bulk, then checks that all documents containing the term are found
+// and that the densest/shortest document ranks first.
+func TestBulkLoaderBM25(t *testing.T) {
+	conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).
+		WithACL(time.Hour).WithReplicas(1).WithBulkLoadOutDir(t.TempDir())
+	c, err := dgraphtest.NewLocalCluster(conf)
+	require.NoError(t, err)
+	defer func() { c.Cleanup(t.Failed()) }()
+
+	require.NoError(t, c.StartZero(0))
+	require.NoError(t, c.HealthCheck(true))
+
+	baseDir := t.TempDir()
+	schemaFile := filepath.Join(baseDir, "bm25.schema")
+	require.NoError(t, os.WriteFile(schemaFile,
+		[]byte("description_bm25: string @index(bm25) .\n"), os.ModePerm))
+
+	dataFile := filepath.Join(baseDir, "bm25.rdf")
+	rdf := `
+		<0x1> <description_bm25> "the quick brown fox jumps over the lazy dog" .
+		<0x2> <description_bm25> "fox fox fox" .
+		<0x3> <description_bm25> "the lazy dog sleeps in the warm sun all day" .
+		<0x4> <description_bm25> "quick brown foxes are agile animals" .
+	`
+	require.NoError(t, os.WriteFile(dataFile, []byte(rdf), os.ModePerm))
+
+	require.NoError(t, c.BulkLoad(dgraphtest.BulkOpts{
+		DataFiles:   []string{dataFile},
+		SchemaFiles: []string{schemaFile},
+	}))
+
+	require.NoError(t, c.Start())
+
+	hc, err := c.HTTPClient()
+	require.NoError(t, err)
+	require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
+		dgraphapi.DefaultPassword, x.RootNamespace))
+
+	// All three documents whose text stems to "fox" (0x1, 0x2, and 0x4 via "foxes") must
+	// be found — proves the term postings and corpus statistics survived the bulk build.
+	data, err := hc.PostDqlQuery(`{ q(func: bm25(description_bm25, "fox")) { count(uid) } }`)
+	require.NoError(t, err)
+	require.Contains(t, string(data), `"count":3`,
+		"bulk-loaded bm25 index must find every document containing the term")
+
+	// The all-"fox" document (0x2, tf=3, shortest) must rank first.
+	data, err = hc.PostDqlQuery(
+		`{ q(func: bm25(description_bm25, "fox"), first: 1) { description_bm25 } }`)
+	require.NoError(t, err)
+	require.True(t, strings.Contains(string(data), "fox fox fox"),
+		"densest, shortest document must rank first after bulk load; got: %s", string(data))
+}
+
 func TestBulkLoaderNoDqlSchema(t *testing.T) {
 	conf := dgraphtest.NewClusterConfig().WithNumAlphas(2).WithNumZeros(1).
WithACL(time.Hour).WithReplicas(1).WithBulkLoadOutDir(t.TempDir()) diff --git a/tok/tok.go b/tok/tok.go index c1da3e991d7..cb50b0a369e 100644 --- a/tok/tok.go +++ b/tok/tok.go @@ -50,6 +50,7 @@ const ( IdentBigFloat = 0xD IdentVFloat = 0xE IdentNGram = 0xF + IdentBM25 = 0x10 IdentCustom = 0x80 IdentDelimiter = 0x1f // ASCII 31 - Unit separator ) @@ -101,6 +102,7 @@ func init() { registerTokenizer(TermTokenizer{}) registerTokenizer(FullTextTokenizer{}) registerTokenizer(NGramTokenizer{}) + registerTokenizer(BM25Tokenizer{}) registerTokenizer(Sha256Tokenizer{}) setupBleve() } @@ -576,6 +578,47 @@ func (t FullTextTokenizer) Identifier() byte { return IdentFullText } func (t FullTextTokenizer) IsSortable() bool { return false } func (t FullTextTokenizer) IsLossy() bool { return true } +// BM25Tokenizer generates tokens for BM25 ranked text search. +// It uses the same pipeline as FullTextTokenizer (normalize, stopwords, stem) +// but preserves duplicates for term frequency counting. +type BM25Tokenizer struct{ lang string } + +func (t BM25Tokenizer) Name() string { return "bm25" } +func (t BM25Tokenizer) Type() string { return "string" } +func (t BM25Tokenizer) Tokens(v interface{}) ([]string, error) { + str, ok := v.(string) + if !ok || str == "" { + return []string{}, nil + } + lang := LangBase(t.lang) + tokens := fulltextAnalyzer.Analyze([]byte(str)) + tokens = filterStopwords(lang, tokens) + tokens = filterStemmers(lang, tokens) + // Return all tokens with duplicates preserved (for TF counting). + result := make([]string, 0, len(tokens)) + for _, t := range tokens { + result = append(result, string(t.Term)) + } + return result, nil +} +func (t BM25Tokenizer) Identifier() byte { return IdentBM25 } +func (t BM25Tokenizer) IsSortable() bool { return false } +func (t BM25Tokenizer) IsLossy() bool { return true } + +// TokensWithFrequency tokenizes the input and returns term frequencies and doc length. 
+func (t BM25Tokenizer) TokensWithFrequency(v interface{}, lang string) (map[string]uint32, uint32, error) {
+	tok := BM25Tokenizer{lang: lang} // the lang argument, not the receiver's lang, drives tokenization
+	allTokens, err := tok.Tokens(v)
+	if err != nil {
+		return nil, 0, err
+	}
+	termFreqs := make(map[string]uint32, len(allTokens))
+	for _, term := range allTokens { // named term so the receiver t is not shadowed
+		termFreqs[term]++
+	}
+	return termFreqs, uint32(len(allTokens)), nil
+}
+
 // Sha256Tokenizer generates tokens for the sha256 hash part from string data.
 type Sha256Tokenizer struct{ _ string }
 
diff --git a/tok/tok_test.go b/tok/tok_test.go
index 4c95094e577..b9fbc4dd1a5 100644
--- a/tok/tok_test.go
+++ b/tok/tok_test.go
@@ -652,6 +652,146 @@ func TestNGramTokenizerNonStringInput(t *testing.T) {
 	require.Equal(t, 0, len(tokens2), "Expected empty tokens for nil input")
 }
 
+func TestBM25Tokenizer(t *testing.T) {
+	tokenizer, has := GetTokenizer("bm25")
+	require.True(t, has)
+	require.NotNil(t, tokenizer)
+	require.Equal(t, "bm25", tokenizer.Name())
+	require.Equal(t, "string", tokenizer.Type())
+	require.Equal(t, byte(IdentBM25), tokenizer.Identifier())
+	require.True(t, tokenizer.IsLossy())
+	require.False(t, tokenizer.IsSortable())
+}
+
+func TestBM25TokensPreservesDuplicates(t *testing.T) {
+	tok := BM25Tokenizer{lang: "en"}
+	tokens, err := tok.Tokens("fox fox fox dog")
+	require.NoError(t, err)
+	// "fox" should appear 3 times (duplicates preserved), "dog" once
+	foxCount := 0
+	dogCount := 0
+	for _, token := range tokens {
+		if token == "fox" {
+			foxCount++
+		}
+		if token == "dog" {
+			dogCount++
+		}
+	}
+	require.Equal(t, 3, foxCount, "Expected 3 occurrences of 'fox'")
+	require.Equal(t, 1, dogCount, "Expected 1 occurrence of 'dog'")
+}
+
+func TestBM25TokensWithFrequency(t *testing.T) {
+	tok := BM25Tokenizer{}
+	termFreqs, docLen, err := tok.TokensWithFrequency("the quick brown fox fox fox", "en")
+	require.NoError(t, err)
+	// "the" is a stopword and should be removed
+	_, hasThe := termFreqs["the"]
+	require.False(t, hasThe, "'the' should be removed as stopword")
+	// "fox" should have tf=3
+	require.Equal(t, uint32(3), termFreqs["fox"])
+	// "quick" -> "quick" (stemmed)
+	require.Contains(t, termFreqs, "quick")
+	require.Equal(t, uint32(1), termFreqs["quick"])
+	// "brown" -> "brown" (stemmed)
+	require.Contains(t, termFreqs, "brown")
+	require.Equal(t, uint32(1), termFreqs["brown"])
+	// docLen should be total tokens after stopword removal
+	require.Equal(t, uint32(5), docLen)
+}
+
+func TestBM25TokensEmpty(t *testing.T) {
+	tok := BM25Tokenizer{lang: "en"}
+	tokens, err := tok.Tokens("")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(tokens))
+
+	termFreqs, docLen, err := tok.TokensWithFrequency("", "en")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(termFreqs))
+	require.Equal(t, uint32(0), docLen)
+}
+
+func TestBM25TokensSingleWord(t *testing.T) {
+	tok := BM25Tokenizer{lang: "en"}
+	tokens, err := tok.Tokens("hello")
+	require.NoError(t, err)
+	require.Equal(t, 1, len(tokens))
+	require.Equal(t, "hello", tokens[0])
+}
+
+func TestBM25TokensStemming(t *testing.T) {
+	tok := BM25Tokenizer{lang: "en"}
+	tokens, err := tok.Tokens("running jumping swimming")
+	require.NoError(t, err)
+	require.Equal(t, 3, len(tokens))
+	require.Contains(t, tokens, "run")
+	require.Contains(t, tokens, "jump")
+	require.Contains(t, tokens, "swim")
+}
+
+func TestGetBM25QueryTokens(t *testing.T) {
+	tokens, err := GetBM25QueryTokens([]string{"quick brown fox fox"}, "en")
+	require.NoError(t, err)
+	// Query tokens should be deduplicated
+	require.Equal(t, 3, len(tokens))
+	// Each token should be encoded with the BM25 identifier prefix
+	for _, token := range tokens {
+		require.Equal(t, byte(IdentBM25), token[0], "Token should start with BM25 identifier")
+	}
+}
+
+func TestGetBM25QueryTokensEmpty(t *testing.T) {
+	tokens, err := GetBM25QueryTokens([]string{""}, "en")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(tokens))
+}
+
+func TestBM25TokenizerForLang(t *testing.T) {
+	tokenizer, has := GetTokenizer("bm25")
+	require.True(t, has)
+	langTok := GetTokenizerForLang(tokenizer, "de")
+	bm25Tok, ok := langTok.(BM25Tokenizer)
+	require.True(t, ok)
+	// German: "Katzen" -> "katz" (stemmed)
+	tokens, err := bm25Tok.Tokens("Katzen und Katzen")
+	require.NoError(t, err)
+	// "und" is a German stopword, so only the two stemmed "katz" tokens are counted
+	katzCount := 0
+	for _, token := range tokens {
+		if token == "katz" {
+			katzCount++
+		}
+	}
+	require.Equal(t, 2, katzCount, "Expected 2 occurrences of stemmed 'katz'")
+}
+
+func TestBM25AllStopwords(t *testing.T) {
+	tok := BM25Tokenizer{lang: "en"}
+	tokens, err := tok.Tokens("the a an is")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(tokens))
+
+	termFreqs, docLen, err := tok.TokensWithFrequency("the a an is", "en")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(termFreqs))
+	require.Equal(t, uint32(0), docLen)
+}
+
+func TestGetBM25QueryTokensAllStopwords(t *testing.T) {
+	tokens, err := GetBM25QueryTokens([]string{"the a an"}, "en")
+	require.NoError(t, err)
+	require.Equal(t, 0, len(tokens))
+}
+
+func TestGetBM25QueryTokensWrongArgCount(t *testing.T) {
+	_, err := GetBM25QueryTokens([]string{}, "en")
+	require.Error(t, err)
+	_, err = GetBM25QueryTokens([]string{"a", "b"}, "en")
+	require.Error(t, err)
+}
+
 func BenchmarkTermTokenizer(b *testing.B) {
 	b.Skip() // tmp
 }
diff --git a/tok/tokens.go b/tok/tokens.go
index bda9a04e743..f089a3f4344 100644
--- a/tok/tokens.go
+++ b/tok/tokens.go
@@ -25,6 +25,8 @@ func GetTokenizerForLang(t Tokenizer, lang string) Tokenizer {
 		// We must return a new instance because another goroutine might be calling this
 		// with a different lang.
 		return FullTextTokenizer{lang: lang}
+	case BM25Tokenizer:
+		return BM25Tokenizer{lang: lang}
 	case TermTokenizer:
 		return TermTokenizer{lang: lang}
 	case ExactTokenizer:
@@ -67,6 +69,29 @@ func GetNGramQueryTokens(funcArgs []string, lang string) ([]string, error) {
 	return BuildNGramQueryTokens(funcArgs[0], NGramTokenizer{lang: lang})
 }
 
+// GetBM25QueryTokens tokenizes the single query-text argument with BM25Tokenizer, drops
+// repeated terms (first-occurrence order is kept), and encodes each with the BM25 identifier.
+func GetBM25QueryTokens(funcArgs []string, lang string) ([]string, error) {
+	if l := len(funcArgs); l != 1 {
+		return nil, errors.Errorf("Function requires 1 arguments, but got %d", l)
+	}
+	tok := BM25Tokenizer{lang: lang}
+	allTokens, err := tok.Tokens(funcArgs[0])
+	if err != nil {
+		return nil, err
+	}
+	// Deduplicate for query: the result stays nil when no token survives tokenization.
+	seen := make(map[string]struct{}, len(allTokens))
+	var unique []string
+	for _, t := range allTokens {
+		if _, ok := seen[t]; !ok {
+			seen[t] = struct{}{}
+			unique = append(unique, encodeToken(t, tok.Identifier()))
+		}
+	}
+	return unique, nil
+}
+
 // GetFullTextTokens returns the full-text tokens for the given value.
 func GetFullTextTokens(funcArgs []string, lang string) ([]string, error) {
 	if l := len(funcArgs); l != 1 {
diff --git a/worker/bm25wand.go b/worker/bm25wand.go
new file mode 100644
index 00000000000..02b5f9affd2
--- /dev/null
+++ b/worker/bm25wand.go
@@ -0,0 +1,433 @@
+/*
+ * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package worker
+
+import (
+	"container/heap"
+	"math"
+	"sort"
+
+	"github.com/dgraph-io/dgraph/v25/posting"
+)
+
+// wandBlockSize is the number of postings grouped into one logical block for
+// Block-Max WAND upper bounds. The postings come from a standard Dgraph posting
+// list (already resident in memory once loaded); these blocks exist only to give
+// WAND per-block score bounds for pruning — they are not a storage format.
+const wandBlockSize = 128
+
+// termCursor is an in-memory cursor over one query term's posting list,
+// materialized from the standard posting List as UID-ascending (uid, tf, docLen)
+// entries. Document length travels with each posting, so scoring needs no separate
+// lookup. Per-block max bounds drive Block-Max WAND pruning.
+type termCursor struct {
+	postings []posting.BM25Posting // binary-searched by Uid in skipTo
+	idf      float64               // multiplied into every pre-IDF bound
+	pos      int                   // index of the current posting; >= len(postings) means exhausted
+
+	// blockUBPre[i] is the pre-IDF BM25 upper bound for block i (max term
+	// frequency, min document length in the block). suffixUBPre[i] = max over
+	// j >= i of blockUBPre[j], for the remaining-list upper bound.
+	blockUBPre  []float64
+	suffixUBPre []float64
+}
+
+// ubPre computes the pre-IDF BM25 contribution upper bound for a block, using the
+// block's maximum term frequency and minimum document length (the score is
+// increasing in tf and decreasing in dl, so this is a safe upper bound).
+func ubPre(maxTF, minDL uint32, k, b, avgDL float64) float64 {
+	if avgDL <= 0 {
+		avgDL = 1 // keeps dl/avgDL below well-defined when no average is available
+	}
+	tf := float64(maxTF)
+	dl := float64(minDL)
+	denom := k*(1-b+b*dl/avgDL) + tf
+	if denom <= 0 {
+		return 0 // avoids dividing by a zero or negative denominator
+	}
+	return (k + 1) * tf / denom
+}
+
+// newTermCursor builds a cursor and precomputes its per-block upper bounds.
+func newTermCursor(postings []posting.BM25Posting, idf, k, b, avgDL float64) *termCursor {
+	c := &termCursor{postings: postings, idf: idf}
+	numBlocks := (len(postings) + wandBlockSize - 1) / wandBlockSize // ceiling division
+	c.blockUBPre = make([]float64, numBlocks)
+	for blk := 0; blk < numBlocks; blk++ {
+		start := blk * wandBlockSize
+		end := start + wandBlockSize
+		if end > len(postings) {
+			end = len(postings) // the last block may be partial
+		}
+		var maxTF uint32
+		minDL := uint32(math.MaxUint32)
+		for i := start; i < end; i++ {
+			if postings[i].TF > maxTF {
+				maxTF = postings[i].TF
+			}
+			dl := postings[i].DocLen
+			if dl == 0 {
+				dl = 1 // NOTE(review): bound stays safe only if scoring clamps DocLen 0 the same way — confirm
+			}
+			if dl < minDL {
+				minDL = dl
+			}
+		}
+		c.blockUBPre[blk] = ubPre(maxTF, minDL, k, b, avgDL)
+	}
+	c.suffixUBPre = make([]float64, numBlocks)
+	var running float64
+	for blk := numBlocks - 1; blk >= 0; blk-- { // right-to-left running maximum
+		if c.blockUBPre[blk] > running {
+			running = c.blockUBPre[blk]
+		}
+		c.suffixUBPre[blk] = running
+	}
+	return c
+}
+
+func (c *termCursor) exhausted() bool { return c.pos >= len(c.postings) }
+
+func (c *termCursor) currentDoc() uint64 {
+	if c.exhausted() {
+		return math.MaxUint64 // sentinel: larger than or equal to any UID
+	}
+	return c.postings[c.pos].Uid
+}
+
+func (c *termCursor) currentTF() uint32 {
+	if c.exhausted() {
+		return 0
+	}
+	return c.postings[c.pos].TF
+}
+
+func (c *termCursor) currentDocLen() uint32 {
+	if c.exhausted() {
+		return 0
+	}
+	return c.postings[c.pos].DocLen
+}
+
+// remainingUB returns the IDF-weighted upper-bound score over the remainder of the
+// list from the current position.
+func (c *termCursor) remainingUB() float64 {
+	if c.exhausted() || len(c.suffixUBPre) == 0 {
+		return 0
+	}
+	blk := c.pos / wandBlockSize // block containing the current posting
+	if blk >= len(c.suffixUBPre) {
+		return 0
+	}
+	return c.idf * c.suffixUBPre[blk]
+}
+
+// next advances by one posting.
+func (c *termCursor) next() bool {
+	c.pos++
+	return !c.exhausted()
+}
+
+// skipTo advances to the first posting with UID >= target.
+func (c *termCursor) skipTo(target uint64) bool { + if c.exhausted() { + return false + } + if c.postings[c.pos].Uid >= target { + return true + } + rel := sort.Search(len(c.postings)-c.pos, func(i int) bool { + return c.postings[c.pos+i].Uid >= target + }) + c.pos += rel + return !c.exhausted() +} + +// skipToWithBMW is skipTo with Block-Max WAND pruning: blocks whose upper bound +// combined with otherUB cannot beat theta are skipped wholesale. +func (c *termCursor) skipToWithBMW(target uint64, theta, otherUB float64) bool { + if !c.skipTo(target) { + return false + } + for !c.exhausted() { + blk := c.pos / wandBlockSize + if c.idf*c.blockUBPre[blk]+otherUB > theta { + return true + } + // This block can't produce a winner; jump to the start of the next block. + c.pos = (blk + 1) * wandBlockSize + } + return false +} + +// scoredDoc holds a UID and its BM25 score for the min-heap. +type scoredDoc struct { + uid uint64 + score float64 +} + +// topKHeap is a min-heap of scored documents for top-k tracking. +type topKHeap struct { + docs []scoredDoc + k int +} + +func (h *topKHeap) Len() int { return len(h.docs) } +func (h *topKHeap) Less(i, j int) bool { return h.docs[i].score < h.docs[j].score } +func (h *topKHeap) Swap(i, j int) { h.docs[i], h.docs[j] = h.docs[j], h.docs[i] } +func (h *topKHeap) Push(x interface{}) { h.docs = append(h.docs, x.(scoredDoc)) } +func (h *topKHeap) Pop() interface{} { + old := h.docs + n := len(old) + item := old[n-1] + h.docs = old[:n-1] + return item +} + +// threshold returns the minimum score in the heap (the score to beat). +func (h *topKHeap) threshold() float64 { + if len(h.docs) < h.k { + return 0 + } + return h.docs[0].score +} + +// tryPush adds a doc if it beats the current threshold. 
+func (h *topKHeap) tryPush(uid uint64, score float64) { + if len(h.docs) < h.k { + heap.Push(h, scoredDoc{uid: uid, score: score}) + return + } + if score > h.docs[0].score { + h.docs[0] = scoredDoc{uid: uid, score: score} + heap.Fix(h, 0) + } +} + +// sorted returns all docs sorted by score descending, then UID ascending. +func (h *topKHeap) sorted() []scoredDoc { + result := make([]scoredDoc, len(h.docs)) + copy(result, h.docs) + sort.Slice(result, func(i, j int) bool { + if result[i].score != result[j].score { + return result[i].score > result[j].score + } + return result[i].uid < result[j].uid + }) + return result +} + +// bm25TopK returns how many top-scored documents the WAND search should retain for a +// first/offset query window. With a first limit it is first+offset (so the offset can +// be dropped afterward while still bounding work and memory to the window); with no +// first limit it is 0, meaning every matching document is scored. Returning first+offset +// rather than 0 whenever an offset is present is what keeps a deep-paginated query from +// materializing and scoring the entire corpus. +func bm25TopK(first, offset int) int { + if first <= 0 { + return 0 + } + return first + offset +} + +// bm25PaginateScored slices score-descending results to the [offset, offset+first) +// window. first <= 0 means no upper bound. It clamps offset to the slice length so an +// offset past the end yields an empty result instead of panicking. +func bm25PaginateScored(results []scoredDoc, first, offset int) []scoredDoc { + if offset > len(results) { + offset = len(results) + } + results = results[offset:] + if first > 0 && first < len(results) { + results = results[:first] + } + return results +} + +// bm25Score computes the BM25 contribution of a single term occurrence. 
+func bm25Score(idf, tf, dl, avgDL, k, b float64) float64 { + if avgDL <= 0 { + avgDL = 1 + } + if dl <= 0 { + dl = 1 + } + return idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) +} + +// wandSearch performs a WAND / Block-Max WAND top-k BM25 search over standard +// posting lists. queryTokens must already carry the BM25 tokenizer identifier +// byte. getList reads a posting list for a key. If topK <= 0, every matching +// document is scored (no early termination). +func wandSearch(getList func(key []byte) (*posting.List, error), attr string, readTs uint64, + queryTokens []string, k, b, avgDL, N float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) ([]scoredDoc, error) { + + var cursors []*termCursor + for _, token := range queryTokens { + postings, err := posting.ReadBM25TermPostings(getList, attr, token, readTs) + if err != nil { + return nil, err + } + df := uint64(len(postings)) + if df == 0 { + continue + } + // N comes from bucketed stats and df from the term's posting list; if stats + // ever lag the postings, clamp N >= df for this term so the smoothed IDF + // stays non-negative and finite instead of producing a negative/NaN score. + dfN := float64(df) + nDocs := N + if nDocs < dfN { + nDocs = dfN + } + idf := math.Log1p((nDocs - dfN + 0.5) / (dfN + 0.5)) + cursors = append(cursors, newTermCursor(postings, idf, k, b, avgDL)) + } + + if len(cursors) == 0 { + return nil, nil + } + + if topK <= 0 { + return scoreAllDocs(cursors, k, b, avgDL, filterSet), nil + } + return wandTopK(cursors, k, b, avgDL, topK, filterSet, useBMW), nil +} + +// wandTopK runs the WAND / Block-Max WAND main loop over prepared cursors and +// returns the top-k documents sorted by score descending. It is the core scoring +// loop, separated from posting-list I/O so it can be exercised directly. 
+func wandTopK(cursors []*termCursor, k, b, avgDL float64, topK int, + filterSet map[uint64]struct{}, useBMW bool) []scoredDoc { + + h := &topKHeap{k: topK} + heap.Init(h) + + for { + // Drop exhausted cursors. + active := cursors[:0] + for _, c := range cursors { + if !c.exhausted() { + active = append(active, c) + } + } + cursors = active + if len(cursors) == 0 { + break + } + + // Sort cursors by current document ascending. + sort.Slice(cursors, func(i, j int) bool { + return cursors[i].currentDoc() < cursors[j].currentDoc() + }) + + theta := h.threshold() + + // Find pivot: accumulate upper bounds until they exceed theta. + var sumUB float64 + pivot := -1 + var pivotDoc uint64 + for i, c := range cursors { + sumUB += c.remainingUB() + if sumUB > theta && pivot == -1 { + pivot = i + pivotDoc = c.currentDoc() + } + } + if pivot == -1 { + break // sum of all upper bounds can't beat theta + } + + // Advance all cursors before the pivot up to pivotDoc. + allAtPivot := true + for i := 0; i < pivot; i++ { + if cursors[i].currentDoc() < pivotDoc { + var ok bool + if useBMW { + otherUB := sumUB - cursors[i].remainingUB() + ok = cursors[i].skipToWithBMW(pivotDoc, theta, otherUB) + } else { + ok = cursors[i].skipTo(pivotDoc) + } + if !ok { + allAtPivot = false + break + } + if cursors[i].currentDoc() != pivotDoc { + allAtPivot = false + } + } + } + if !allAtPivot { + continue + } + + // Score the pivot document. 
+ if filterSet != nil { + if _, ok := filterSet[pivotDoc]; !ok { + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() + } + } + continue + } + } + + var score float64 + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + dl := float64(c.currentDocLen()) + score += bm25Score(c.idf, float64(c.currentTF()), dl, avgDL, k, b) + } + } + h.tryPush(pivotDoc, score) + + for _, c := range cursors { + if c.currentDoc() == pivotDoc { + c.next() + } + } + } + + return h.sorted() +} + +// scoreAllDocs scores every matching document without early termination. Used when +// no top-k limit is specified. +func scoreAllDocs(cursors []*termCursor, k, b, avgDL float64, + filterSet map[uint64]struct{}) []scoredDoc { + + scores := make(map[uint64]float64) + + for _, c := range cursors { + for !c.exhausted() { + uid := c.currentDoc() + if filterSet != nil { + if _, ok := filterSet[uid]; !ok { + c.next() + continue + } + } + scores[uid] += bm25Score(c.idf, float64(c.currentTF()), float64(c.currentDocLen()), + avgDL, k, b) + c.next() + } + } + + results := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + results = append(results, scoredDoc{uid: uid, score: s}) + } + sort.Slice(results, func(i, j int) bool { + if results[i].score != results[j].score { + return results[i].score > results[j].score + } + return results[i].uid < results[j].uid + }) + return results +} diff --git a/worker/bm25wand_test.go b/worker/bm25wand_test.go new file mode 100644 index 00000000000..93031a16301 --- /dev/null +++ b/worker/bm25wand_test.go @@ -0,0 +1,332 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "container/heap" + "math" + "math/rand" + "sort" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/posting" +) + +func TestTopKHeapBasic(t *testing.T) { + h := &topKHeap{k: 3} + heap.Init(h) + + require.Equal(t, 0.0, h.threshold()) + + h.tryPush(1, 5.0) + h.tryPush(2, 3.0) + require.Equal(t, 0.0, h.threshold()) // not full yet + + h.tryPush(3, 7.0) + require.InEpsilon(t, 3.0, h.threshold(), 1e-9) // full, min is 3.0 + + h.tryPush(4, 4.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) // 3.0 evicted, min is now 4.0 + + // 2.0 shouldn't be accepted. + h.tryPush(5, 2.0) + require.InEpsilon(t, 4.0, h.threshold(), 1e-9) + + sorted := h.sorted() + require.Len(t, sorted, 3) + require.Equal(t, uint64(3), sorted[0].uid) // highest score (7.0) + require.Equal(t, uint64(1), sorted[1].uid) // 5.0 + require.Equal(t, uint64(4), sorted[2].uid) // 4.0 +} + +func TestTopKHeapTieBreaking(t *testing.T) { + h := &topKHeap{k: 5} + heap.Init(h) + + // Same score, different UIDs — should sort by UID ascending. + h.tryPush(10, 5.0) + h.tryPush(5, 5.0) + h.tryPush(15, 5.0) + + sorted := h.sorted() + require.Equal(t, uint64(5), sorted[0].uid) + require.Equal(t, uint64(10), sorted[1].uid) + require.Equal(t, uint64(15), sorted[2].uid) +} + +func TestBm25TopK(t *testing.T) { + // No first limit: score every matching document (0 means "no early termination"). + require.Equal(t, 0, bm25TopK(0, 0)) + require.Equal(t, 0, bm25TopK(0, 100)) + + // With a first limit, WAND must retain first+offset documents so the offset can be + // dropped afterward — NOT 0 (which would fall back to scoring the entire corpus + // just because an offset was supplied, the memory blow-up this guards against). 
+ require.Equal(t, 10, bm25TopK(10, 0)) + require.Equal(t, 15, bm25TopK(10, 5)) + require.Equal(t, 1001, bm25TopK(1, 1000)) +} + +func TestBm25PaginateScored(t *testing.T) { + mk := func(uids ...uint64) []scoredDoc { + out := make([]scoredDoc, len(uids)) + for i, u := range uids { + out[i] = scoredDoc{uid: u, score: float64(len(uids) - i)} // already score-descending + } + return out + } + ids := func(ds []scoredDoc) []uint64 { + out := make([]uint64, len(ds)) + for i, d := range ds { + out[i] = d.uid + } + return out + } + + full := mk(1, 2, 3, 4, 5) + require.Equal(t, []uint64{1, 2, 3, 4, 5}, ids(bm25PaginateScored(full, 0, 0))) + require.Equal(t, []uint64{1, 2}, ids(bm25PaginateScored(mk(1, 2, 3, 4, 5), 2, 0))) + require.Equal(t, []uint64{3, 4}, ids(bm25PaginateScored(mk(1, 2, 3, 4, 5), 2, 2))) + require.Equal(t, []uint64{4, 5}, ids(bm25PaginateScored(mk(1, 2, 3, 4, 5), 10, 3))) + // Offset past the end yields nothing rather than panicking. + require.Empty(t, bm25PaginateScored(mk(1, 2, 3), 2, 10)) +} + +func TestBm25ScoreFunction(t *testing.T) { + k, b := 1.2, 0.75 + avgDL := 10.0 + + // idf * (k+1) * tf / (k*(1-b+b*dl/avgDL) + tf) + idf := 1.5 + tf := 3.0 + dl := 10.0 + + expected := idf * (k + 1) * tf / (k*(1-b+b*dl/avgDL) + tf) + got := bm25Score(idf, tf, dl, avgDL, k, b) + require.InEpsilon(t, expected, got, 1e-9) + + // With b=0: no length normalization. + expected0 := idf * (k + 1) * tf / (k + tf) + got0 := bm25Score(idf, tf, dl, avgDL, k, 0) + require.InEpsilon(t, expected0, got0, 1e-9) + + // Score should be positive for positive inputs. + require.Greater(t, bm25Score(1.0, 1.0, 5.0, 10.0, k, b), 0.0) + + // Higher tf should produce higher score (same dl). + s1 := bm25Score(idf, 1.0, dl, avgDL, k, b) + s3 := bm25Score(idf, 3.0, dl, avgDL, k, b) + require.Greater(t, s3, s1) + + // Shorter doc should score higher (same tf). 
+ sShort := bm25Score(idf, tf, 5.0, avgDL, k, b) + sLong := bm25Score(idf, tf, 20.0, avgDL, k, b) + require.Greater(t, sShort, sLong) +} + +func TestBm25ScoreNaN(t *testing.T) { + // Ensure no NaN/Inf for edge-case inputs. + score := bm25Score(0.5, 1.0, 0.0, 10.0, 1.2, 0.75) + require.False(t, math.IsNaN(score)) + require.False(t, math.IsInf(score, 0)) + require.Greater(t, score, 0.0) +} + +// brute force scores every doc across all cursors (ground truth for WAND). When +// filterSet is non-nil, only documents in it are scored — mirroring @filter(bm25(...)). +func bruteForceTopK(termPostings [][]posting.BM25Posting, idfs []float64, + k, b, avgDL float64, topK int) []scoredDoc { + return bruteForceTopKFiltered(termPostings, idfs, k, b, avgDL, topK, nil) +} + +func bruteForceTopKFiltered(termPostings [][]posting.BM25Posting, idfs []float64, + k, b, avgDL float64, topK int, filterSet map[uint64]struct{}) []scoredDoc { + scores := map[uint64]float64{} + for ti, ps := range termPostings { + for _, p := range ps { + if filterSet != nil { + if _, ok := filterSet[p.Uid]; !ok { + continue + } + } + scores[p.Uid] += bm25Score(idfs[ti], float64(p.TF), float64(p.DocLen), avgDL, k, b) + } + } + out := make([]scoredDoc, 0, len(scores)) + for uid, s := range scores { + out = append(out, scoredDoc{uid: uid, score: s}) + } + sort.Slice(out, func(i, j int) bool { + if out[i].score != out[j].score { + return out[i].score > out[j].score + } + return out[i].uid < out[j].uid + }) + if topK > 0 && len(out) > topK { + out = out[:topK] + } + return out +} + +// TestWandFilteredMatchesBruteForce checks that WAND/Block-Max WAND and the +// score-all path honor a filter set identically to exhaustive filtered scoring. The +// filter must never change which documents or scores are produced (only which are +// considered), so WAND pruning driven by a threshold built from filtered-in documents +// must still be sound. 
+func TestWandFilteredMatchesBruteForce(t *testing.T) { + rng := rand.New(rand.NewSource(7)) + k, b, avgDL := 1.2, 0.75, 9.0 + + for trial := 0; trial < 200; trial++ { + numTerms := 1 + rng.Intn(4) + termPostings := make([][]posting.BM25Posting, numTerms) + idfs := make([]float64, numTerms) + allUids := map[uint64]bool{} + for ti := 0; ti < numTerms; ti++ { + n := rng.Intn(400) + seen := map[uint64]bool{} + var ps []posting.BM25Posting + for j := 0; j < n; j++ { + uid := uint64(1 + rng.Intn(500)) + if seen[uid] { + continue + } + seen[uid] = true + ps = append(ps, posting.BM25Posting{ + Uid: uid, + TF: uint32(1 + rng.Intn(10)), + DocLen: uint32(1 + rng.Intn(30)), + }) + allUids[uid] = true + } + sort.Slice(ps, func(i, j int) bool { return ps[i].Uid < ps[j].Uid }) + termPostings[ti] = ps + idfs[ti] = 0.5 + rng.Float64()*2 + } + + // Random filter subset (may be empty). + filterSet := map[uint64]struct{}{} + for uid := range allUids { + if rng.Intn(2) == 0 { + filterSet[uid] = struct{}{} + } + } + + build := func() []*termCursor { + cs := make([]*termCursor, 0, numTerms) + for ti, ps := range termPostings { + if len(ps) == 0 { + continue + } + cs = append(cs, newTermCursor(ps, idfs[ti], k, b, avgDL)) + } + return cs + } + + // score-all path with filter must reproduce the full filtered ranking exactly. + wantAll := bruteForceTopKFiltered(termPostings, idfs, k, b, avgDL, 0, filterSet) + gotAll := scoreAllDocs(build(), k, b, avgDL, filterSet) + require.Lenf(t, gotAll, len(wantAll), "trial %d filtered score-all len", trial) + for i := range wantAll { + require.InEpsilonf(t, wantAll[i].score, gotAll[i].score, 1e-9, + "trial %d filtered score-all rank %d score", trial, i) + } + + // top-k WAND/BMW with filter must match the filtered top-k scores. 
+ topK := 1 + rng.Intn(8) + want := bruteForceTopKFiltered(termPostings, idfs, k, b, avgDL, topK, filterSet) + wantPlus := bruteForceTopKFiltered(termPostings, idfs, k, b, avgDL, topK+1, filterSet) + for _, useBMW := range []bool{false, true} { + got := wandTopK(build(), k, b, avgDL, topK, filterSet, useBMW) + require.Lenf(t, got, len(want), "trial %d filtered bmw=%v len", trial, useBMW) + for i := range want { + require.InEpsilonf(t, want[i].score, got[i].score, 1e-9, + "trial %d filtered bmw=%v rank %d score", trial, useBMW, i) + tied := (i > 0 && wantPlus[i].score == wantPlus[i-1].score) || + (i+1 < len(wantPlus) && wantPlus[i].score == wantPlus[i+1].score) + if !tied { + require.Equalf(t, want[i].uid, got[i].uid, + "trial %d filtered bmw=%v rank %d uid", trial, useBMW, i) + } + } + } + } +} + +// TestWandMatchesBruteForce checks that WAND and Block-Max WAND return exactly the +// same top-k documents and scores as exhaustive scoring, across many randomized +// posting lists. This is the core correctness guarantee: pruning must never change +// the result, only the work done. +func TestWandMatchesBruteForce(t *testing.T) { + rng := rand.New(rand.NewSource(42)) + k, b, avgDL := 1.2, 0.75, 12.0 + + for trial := 0; trial < 200; trial++ { + numTerms := 1 + rng.Intn(4) + termPostings := make([][]posting.BM25Posting, numTerms) + idfs := make([]float64, numTerms) + for ti := 0; ti < numTerms; ti++ { + n := rng.Intn(400) // spans multiple wandBlockSize blocks + seen := map[uint64]bool{} + var ps []posting.BM25Posting + for j := 0; j < n; j++ { + uid := uint64(1 + rng.Intn(500)) + if seen[uid] { + continue + } + seen[uid] = true + ps = append(ps, posting.BM25Posting{ + Uid: uid, + TF: uint32(1 + rng.Intn(10)), + DocLen: uint32(1 + rng.Intn(30)), + }) + } + sort.Slice(ps, func(i, j int) bool { return ps[i].Uid < ps[j].Uid }) + termPostings[ti] = ps + // Vary IDF per term so different terms carry different weight. 
+ idfs[ti] = 0.5 + rng.Float64()*2 + } + + topK := 1 + rng.Intn(10) + want := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK) + // One extra result lets us detect a tie between the cutoff rank and the + // first excluded document (a boundary tie outside the top-k window). + wantPlus := bruteForceTopK(termPostings, idfs, k, b, avgDL, topK+1) + + build := func() []*termCursor { + cs := make([]*termCursor, 0, numTerms) + for ti, ps := range termPostings { + if len(ps) == 0 { + continue + } + cs = append(cs, newTermCursor(ps, idfs[ti], k, b, avgDL)) + } + return cs + } + + for _, useBMW := range []bool{false, true} { + got := wandTopK(build(), k, b, avgDL, topK, nil, useBMW) + require.Lenf(t, got, len(want), "trial %d bmw=%v len", trial, useBMW) + for i := range want { + // The score at each rank must match exactly: WAND/BMW pruning must + // never change which scores make the top-k, only the work done. + require.InEpsilonf(t, want[i].score, got[i].score, 1e-9, + "trial %d bmw=%v rank %d score", trial, useBMW, i) + // The uid is only guaranteed when this rank's score is not tied with + // a neighbor (including the first excluded doc); tied-boundary docs + // are interchangeable in the ranking. 
+ tied := (i > 0 && wantPlus[i].score == wantPlus[i-1].score) || + (i+1 < len(wantPlus) && wantPlus[i].score == wantPlus[i+1].score) + if !tied { + require.Equalf(t, want[i].uid, got[i].uid, + "trial %d bmw=%v rank %d uid", trial, useBMW, i) + } + } + } + } +} diff --git a/worker/mutation.go b/worker/mutation.go index fdac2a41c1b..076beed185f 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -410,6 +410,19 @@ func checkSchema(s *pb.SchemaUpdate) error { x.ParseAttr(s.Predicate)) } + // BM25 scores a single document (one value) per UID: per-document length and + // corpus statistics are not well-defined for a list predicate, and the bucketed + // stats maintenance relies on conflict detection that a list predicate's + // value-dependent conflict key would not provide. Reject @index(bm25) on lists. + if s.List { + for _, tokenizer := range s.Tokenizer { + if tokenizer == "bm25" { + return errors.Errorf("Tokenizer 'bm25' cannot be applied to list predicate: %s", + x.ParseAttr(s.Predicate)) + } + } + } + // If schema update has upsert directive, it should have index directive. if s.Upsert && len(s.Tokenizer) == 0 && !s.Unique { return errors.Errorf("Index tokenizer is mandatory for: [%s] when specifying @upsert directive", diff --git a/worker/mutation_integration_test.go b/worker/mutation_integration_test.go index 99a2a1eed01..f1f4b81695b 100644 --- a/worker/mutation_integration_test.go +++ b/worker/mutation_integration_test.go @@ -93,6 +93,25 @@ func TestCheckSchema(t *testing.T) { } require.NoError(t, checkSchema(s1)) + // bm25 on a scalar string predicate is allowed. + s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("bio"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + } + require.NoError(t, checkSchema(s1)) + + // bm25 on a list predicate is rejected. 
+ s1 = &pb.SchemaUpdate{ + Predicate: x.AttrInRootNamespace("tags"), + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"bm25"}, + List: true, + } + require.Error(t, checkSchema(s1)) + s1 = &pb.SchemaUpdate{ Predicate: x.AttrInRootNamespace("friend"), ValueType: pb.Posting_UID, diff --git a/worker/task.go b/worker/task.go index 409ec3f0fc4..5098acdc0c8 100644 --- a/worker/task.go +++ b/worker/task.go @@ -7,6 +7,7 @@ package worker import ( "context" + "encoding/binary" "fmt" "math" "sort" @@ -224,6 +225,7 @@ const ( customIndexFn matchFn similarToFn + bm25SearchFn standardFn = 100 ) @@ -266,6 +268,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) { return uidInFn, f case "similar_to": return similarToFn, f + case "bm25": + return bm25SearchFn, f case "anyof", "allof": return customIndexFn, f case "match": @@ -292,6 +296,8 @@ func needsIndex(fnType FuncType, uidList *pb.List) bool { return true case similarToFn: return true + case bm25SearchFn: + return true } return false } @@ -314,7 +320,7 @@ type funcArgs struct { // The function tells us whether we want to fetch value posting lists or uid posting lists. 
func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error) { switch srcFn.fnType { - case aggregatorFn, passwordFn, similarToFn: + case aggregatorFn, passwordFn, similarToFn, bm25SearchFn: return true, nil case compareAttrFn: if len(srcFn.tokens) > 0 { @@ -351,11 +357,15 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er attribute.String("srcFn", x.SafeUTF8(fmt.Sprintf("%+v", args.srcFn))))) switch srcFn.fnType { - case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn: + case notAFunction, aggregatorFn, passwordFn, compareAttrFn, similarToFn, bm25SearchFn: default: return errors.Errorf("Unhandled function in handleValuePostings: %s", srcFn.fname) } + if srcFn.fnType == bm25SearchFn { + return qs.handleBM25Search(ctx, args) + } + if srcFn.fnType == similarToFn { numNeighbors, err := strconv.ParseInt(q.SrcFunc.Args[0], 10, 32) if err != nil { @@ -1219,6 +1229,115 @@ func needsStringFiltering(srcFn *functionContext, langs []string, attr string) b srcFn.fnType == customIndexFn || srcFn.fnType == ngramFn) } +func (qs *queryState) handleBM25Search(ctx context.Context, args funcArgs) error { + q := args.q + attr := q.Attr + + // 1. Parse args: query text, optional k (default 1.2), b (default 0.75). 
+ if len(q.SrcFunc.Args) < 1 { + return errors.Errorf("bm25 requires at least 1 argument (query text)") + } + queryText := q.SrcFunc.Args[0] + k := 1.2 + b := 0.75 + if len(q.SrcFunc.Args) >= 2 { + var err error + k, err = strconv.ParseFloat(q.SrcFunc.Args[1], 64) + if err != nil { + return errors.Errorf("bm25: invalid k parameter: %s", q.SrcFunc.Args[1]) + } + } + if len(q.SrcFunc.Args) >= 3 { + var err error + b, err = strconv.ParseFloat(q.SrcFunc.Args[2], 64) + if err != nil { + return errors.Errorf("bm25: invalid b parameter: %s", q.SrcFunc.Args[2]) + } + } + if math.IsNaN(k) || math.IsInf(k, 0) || k <= 0 { + return errors.Errorf("bm25: k must be a positive finite number, got %v", k) + } + if math.IsNaN(b) || math.IsInf(b, 0) || b < 0 || b > 1 { + return errors.Errorf("bm25: b must be between 0 and 1, got %v", b) + } + + // 2. Tokenize query (deduplicated) using the fulltext pipeline. The returned + // tokens already carry the BM25 tokenizer identifier byte. + lang := langForFunc(q.Langs) + queryTokens, err := tok.GetBM25QueryTokens([]string{queryText}, lang) + if err != nil { + return err + } + if len(queryTokens) == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + + // 3. Read bucketed corpus stats and derive N and the average document length. + docCount, totalTerms, err := posting.ReadBM25Stats(qs.cache.Get, attr, q.ReadTs) + if err != nil { + return err + } + if docCount == 0 || totalTerms == 0 { + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{}) + return nil + } + avgDL := float64(totalTerms) / float64(docCount) + N := float64(docCount) + + // Build a filter set if bm25 is used as a filter (@filter(bm25(...))). + var filterSet map[uint64]struct{} + if q.UidList != nil && len(q.UidList.Uids) > 0 { + filterSet = make(map[uint64]struct{}, len(q.UidList.Uids)) + for _, uid := range q.UidList.Uids { + filterSet[uid] = struct{}{} + } + } + + // 4. 
Use WAND top-k early termination whenever a first limit is set, retaining + // first+offset documents so the offset can be dropped afterward. This bounds work + // and memory to the requested window instead of scoring (and materializing) every + // matching document just because an offset was supplied. With no first limit, + // topK is 0 and every match is scored (the caller explicitly asked for all results). + topK := bm25TopK(int(q.First), int(q.Offset)) + + // 5. Run WAND / Block-Max WAND over the standard posting lists. + results, err := wandSearch(qs.cache.Get, attr, q.ReadTs, queryTokens, k, b, avgDL, N, + topK, filterSet, true) + if err != nil { + return err + } + + // 6. Apply the first/offset window over the score-descending results. wandSearch + // returns results sorted by score (descending), whether or not it top-k'd them, so + // the same slice is correct for both the top-k and score-all paths. + results = bm25PaginateScored(results, int(q.First), int(q.Offset)) + + // 7. Emit UIDs ascending (required by the query pipeline) with positionally- + // aligned scores in ValueMatrix; the query layer binds these to a value + // variable so callers order by and project the score via val(). + sort.Slice(results, func(i, j int) bool { return results[i].uid < results[j].uid }) + uids := make([]uint64, len(results)) + for i, r := range results { + uids[i] = r.uid + } + args.out.UidMatrix = append(args.out.UidMatrix, &pb.List{Uids: uids}) + + scoreBuf := make([]byte, len(results)*8) + scoreValues := make([]*pb.ValueList, len(results)) + for i, r := range results { + off := i * 8 + binary.LittleEndian.PutUint64(scoreBuf[off:off+8], math.Float64bits(r.score)) + // Three-index slice caps capacity at 8 so a downstream append can't corrupt + // adjacent scores in the shared backing array. 
+ scoreValues[i] = &pb.ValueList{ + Values: []*pb.TaskValue{{Val: scoreBuf[off : off+8 : off+8], ValType: pb.Posting_FLOAT}}, + } + } + args.out.ValueMatrix = append(args.out.ValueMatrix, scoreValues...) + return nil +} + func (qs *queryState) handleCompareScalarFunction(ctx context.Context, arg funcArgs) error { attr := arg.q.Attr if ok := schema.State().HasCount(ctx, attr); !ok { @@ -2167,6 +2286,18 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { return nil, err } checkRoot(q, fc) + case bm25SearchFn: + // bm25(pred, "query text") or bm25(pred, "query text", "k", "b") + if len(q.SrcFunc.Args) < 1 || len(q.SrcFunc.Args) > 3 { + return nil, errors.Errorf("Function 'bm25' requires 1-3 arguments (query [, k, b]), but got %d", + len(q.SrcFunc.Args)) + } + required, found := verifyStringIndex(ctx, attr, fnType) + if !found { + return nil, errors.Errorf("Attribute %s is not indexed with type %s", x.ParseAttr(attr), + required) + } + checkRoot(q, fc) case similarToFn: // similar_to accepts 2 mandatory args: k, vector_or_uid followed by optional key:value pairs // Example: similar_to(vpred, 3, $vec, ef: 64, distance_threshold: 0.5) diff --git a/worker/tokens.go b/worker/tokens.go index 2740d29f447..b8c85a22816 100644 --- a/worker/tokens.go +++ b/worker/tokens.go @@ -25,6 +25,8 @@ func verifyStringIndex(ctx context.Context, attr string, funcType FuncType) (str requiredTokenizer = tok.NGramTokenizer{} case fullTextSearchFn: requiredTokenizer = tok.FullTextTokenizer{} + case bm25SearchFn: + requiredTokenizer = tok.BM25Tokenizer{} case matchFn: requiredTokenizer = tok.TrigramTokenizer{} default: @@ -65,6 +67,9 @@ func getStringTokens(funcArgs []string, lang string, funcType FuncType, query bo if funcType == fullTextSearchFn { return tok.GetFullTextTokens(funcArgs, lang) } + if funcType == bm25SearchFn { + return tok.GetBM25QueryTokens(funcArgs, lang) + } if funcType == ngramFn { if query { return tok.GetNGramQueryTokens(funcArgs, lang) diff 
--git a/x/keys.go b/x/keys.go index 94112d07c03..138551bb7bc 100644 --- a/x/keys.go +++ b/x/keys.go @@ -291,6 +291,44 @@ func CountKey(attr string, count uint32, reverse bool) []byte { return buf } +// BM25IndexKey generates the index key for a BM25 term posting list. The +// encodedToken already carries the BM25 tokenizer identifier byte, so BM25 term +// postings live at the same standard index key as every other tokenizer — +// IndexKey(attr, identifier || term) — and inherit rollup, splits, backup, and +// index-rebuild handling for free. This is a thin alias of IndexKey so the index +// write path and the query read path share one definition. +func BM25IndexKey(attr string, encodedToken string) []byte { + return IndexKey(attr, encodedToken) +} + +// bm25StatsPrefix namespaces the BM25 corpus-statistics keys. These hold the +// document count and total term count (used to derive the average document +// length); they are auxiliary metadata, not term postings, so they use a reserved +// token that cannot collide with any stemmed BM25 term. +const bm25StatsPrefix = "\x00_bm25stats_" + +// BM25StatsKey generates the key for one bucket of BM25 corpus statistics. Stats +// are sharded across buckets (keyed by uid%numBuckets) to spread write contention. +func BM25StatsKey(attr string, bucket int) []byte { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], uint16(bucket)) + return IndexKey(attr, bm25StatsPrefix+string(buf[:])) +} + +// BM25StatsPrefix returns the key prefix covering every BM25 corpus-statistics +// bucket for attr. Stats live under a reserved token prefix that is distinct from +// the IdentBM25 term-posting prefix, so dropping/rebuilding the BM25 index must +// delete this prefix too — otherwise the stale stats survive a drop and +// double-count when the index is rebuilt on top of them. The split argument +// selects the ByteSplit-prefixed variant used by multi-part lists. 
+func BM25StatsPrefix(attr string, split bool) []byte { + prefix := IndexKey(attr, bm25StatsPrefix) + if split { + prefix[0] = ByteSplit + } + return prefix +} + // ParsedKey represents a key that has been parsed into its multiple attributes. type ParsedKey struct { Attr string