From 47f665d99a08d0af2574ea3968f5f909cb4850dd Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Mon, 8 Jun 2026 12:40:39 -0400 Subject: [PATCH 1/2] feat(hnsw): batch neighbor vector reads via badger Txn.MultiGet HNSW search reads each candidate's neighbor vectors one key at a time (getVecFromUid -> CacheType.Get), and each Get becomes a full single-key NewKeyIterator(AllVersions) in badger. For a fixed candidate these sibling reads are independent, so fold them into one batched read. Changes: - index.CacheType / Txn / LocalCache: add MultiGet(keys) (vals, errs), the batched counterpart of Get. - posting: ReadPostingListFromVersions folds a key's version chain from a badger []ItemVersion exactly as ReadPostingList does from an iterator; MemoryLayer.ReadManyData resolves many keys in one badger Txn.MultiGet (warm keys served from the global cache; two-phase read mirroring ReadData); LocalCache.MultiGet adds the per-txn cache/delta layer; viLocalCache/viTxn.MultiGet expose resolved values. - tok/hnsw: getVecsFromUids batch-fetches a frontier's vectors; searchPersistentLayer collects a candidate's unvisited neighbors and reads their vectors in one MultiGet (traversal/heap logic unchanged). - TxnCache/QueryCache and the test mocks implement MultiGet. DEPENDS ON the badger Txn.MultiGet change (dgraph-io/badger#2297). Until a badger release with MultiGet is available, build/test locally with: go mod edit -replace github.com/dgraph-io/badger/v4=/path/to/badger against a checkout of badger branch sp/badger_multiget; then bump the badger version here. Tested (with the local badger replace): posting differential test (ReadManyData + viLocalCache.MultiGet match the per-key ReadPostingList/Get path over an on-disk round-trip of complete/delta/empty/absent posting shapes); tok/hnsw and posting suites pass; full dgraph build clean. Benchmarks (cold cache, real badger): - badger MultiGet vs per-key NewKeyIterator: -5..-16% time, -35..-40% allocs. - posting frontier read (Get loop vs MultiGet): -12..-18% time, -15..-21% allocs. Co-Authored-By: Claude Opus 4.8 (1M context) --- posting/lists.go | 80 +++++++++++++++++ posting/multiget_test.go | 165 ++++++++++++++++++++++++++++++++++++ posting/mvcc.go | 156 ++++++++++++++++++++++++++++++++++ posting/oracle.go | 25 ++++++ tok/hnsw/ef_recall_test.go | 9 ++ tok/hnsw/helper.go | 37 ++++++++ tok/hnsw/persistent_hnsw.go | 23 +++-- tok/hnsw/test_helper.go | 24 ++++++ tok/index/index.go | 11 +++ 9 files changed, 524 insertions(+), 6 deletions(-) create mode 100644 posting/multiget_test.go diff --git a/posting/lists.go b/posting/lists.go index a4bc4fb355b..00413e03f68 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -98,6 +98,31 @@ func (vc *viLocalCache) Get(key []byte) ([]byte, error) { return vc.GetValueFromPostingList(pl) } +// MultiGet resolves many keys to their values in one batched read (see +// LocalCache.MultiGet). vals and errs are aligned with keys; errs[i] is +// ErrNoValue when keys[i] has no value, matching Get's per-key semantics. +func (vc *viLocalCache) MultiGet(keys [][]byte) ([][]byte, []error) { + vals := make([][]byte, len(keys)) + errs := make([]error, len(keys)) + lists, err := vc.delegate.MultiGet(keys) + if err != nil { + for i := range errs { + errs[i] = err + } + return vals, errs + } + for i, pl := range lists { + if pl == nil { + errs[i] = ErrNoValue + continue + } + pl.Lock() + vals[i], errs[i] = vc.GetValueFromPostingList(pl) + pl.Unlock() + } + return vals, errs +} + func (vc *viLocalCache) GetWithLockHeld(key []byte) ([]byte, error) { pl, err := vc.delegate.Get(key) if err != nil { @@ -390,6 +415,61 @@ func (lc *LocalCache) Get(key []byte) (*List, error) { return lc.getInternal(key, true, false) } +// MultiGet is the batched form of Get(readFromDisk=true): it resolves many keys +// to their *List in one shared read. Keys already in the per-txn cache (plists) +// are served from there; the cold remainder is read from disk in a single +// MemoryLayer.ReadManyData (one badger MultiGet), then has any pending txn delta +// applied and is interned via SetIfAbsent — mirroring getInternal per key. +// Returned lists are aligned with keys. +func (lc *LocalCache) MultiGet(keys [][]byte) ([]*List, error) { + lists := make([]*List, len(keys)) + + lc.RLock() + plistsNil := lc.plists == nil + lc.RUnlock() + + var missIdx []int + var missKeys [][]byte + for i, key := range keys { + if !plistsNil { + lc.RLock() + l, ok := lc.plists[string(key)] + lc.RUnlock() + if ok { + lists[i] = l + continue + } + } + missIdx = append(missIdx, i) + missKeys = append(missKeys, key) + } + if len(missKeys) == 0 { + return lists, nil + } + + fetched, err := MemLayerInstance.ReadManyData(missKeys, pstore, lc.startTs) + if err != nil { + return nil, err + } + for j, idx := range missIdx { + pl := fetched[j] + if plistsNil { + // Cache disabled: return the freshly read list directly, exactly as + // getInternal's plists==nil path returns getNew without caching. + lists[idx] = pl + continue + } + skey := string(missKeys[j]) + lc.RLock() + if delta, ok := lc.deltas[skey]; ok && len(delta) > 0 { + pl.setMutation(lc.startTs, delta) + } + lc.RUnlock() + lists[idx] = lc.SetIfAbsent(skey, pl) + } + return lists, nil +} + func (lc *LocalCache) GetUids(key []byte) (*List, error) { return lc.getInternal(key, true, true) } diff --git a/posting/multiget_test.go b/posting/multiget_test.go new file mode 100644 index 00000000000..4385f221f32 --- /dev/null +++ b/posting/multiget_test.go @@ -0,0 +1,165 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package posting + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + bpb "github.com/dgraph-io/badger/v4/pb" + "github.com/dgraph-io/dgraph/v25/protos/pb" + "github.com/dgraph-io/dgraph/v25/x" + "google.golang.org/protobuf/proto" +) + +// TestReadManyDataMatchesReadData verifies that the batched MultiGet-backed +// read path (MemoryLayer.ReadManyData -> ReadPostingListFromVersions) folds a +// key's version chain identically to the per-key path +// (NewKeyIterator -> ReadPostingList) over a real on-disk round-trip. +func TestReadManyDataMatchesReadData(t *testing.T) { + mkpl := func(uid uint64) []byte { + pl := &pb.PostingList{ + Postings: []*pb.Posting{{Uid: uid, Op: uint32(Set)}}, + } + b, err := proto.Marshal(pl) + require.NoError(t, err) + return b + } + + // Build keys with a few different version-chain shapes, each under a unique + // predicate so the process-global cache starts cold. + pred := x.AttrInRootNamespace("multiget-" + uuid.New().String()) + var keys [][]byte + var kvs []*bpb.KV + for i := 0; i < 12; i++ { + key := x.DataKey(pred, uint64(i+1)) + keys = append(keys, key) + switch i % 4 { + case 0: + // complete posting only + kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(i + 1)), + UserMeta: []byte{BitCompletePosting}, Version: 2}) + case 1: + // complete + newer delta on top + kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(i + 1)), + UserMeta: []byte{BitCompletePosting}, Version: 2}) + kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(100 + i)), + UserMeta: []byte{BitDeltaPosting}, Version: 4}) + case 2: + // empty posting + kvs = append(kvs, &bpb.KV{Key: key, Value: []byte{}, + UserMeta: []byte{BitEmptyPosting}, Version: 3}) + case 3: + // absent key (write nothing) + } + } + require.NoError(t, writePostingListToDisk(kvs)) + + readTs := uint64(10) + + // Ground truth: per-key path (bypasses caches, reads raw from disk). + want := make([][]uint64, len(keys)) + for i, key := range keys { + l, err := readPostingListFromDisk(key, ps, readTs) + require.NoError(t, err) + want[i] = listToArray(t, 0, l, readTs) + } + + // Batched path. + got, err := MemLayerInstance.ReadManyData(keys, ps, readTs) + require.NoError(t, err) + require.Equal(t, len(keys), len(got)) + for i, key := range keys { + require.NotNil(t, got[i], "list for key %x", key) + require.Equal(t, want[i], listToArray(t, 0, got[i], readTs), + "uid set mismatch for key index %d", i) + } + + // Also exercise the viLocalCache value path end-to-end: per-key Get vs + // batched MultiGet must agree value-for-value. + refVals := make([][]byte, len(keys)) + refErrs := make([]bool, len(keys)) + for i, key := range keys { + vc := &viLocalCache{delegate: NewLocalCache(readTs)} + v, e := vc.Get(key) + refVals[i] = v + refErrs[i] = e != nil + } + vc := &viLocalCache{delegate: NewLocalCache(readTs)} + gotVals, gotErrs := vc.MultiGet(keys) + for i := range keys { + require.Equal(t, refErrs[i], gotErrs[i] != nil, + "error presence mismatch for key index %d (%v)", i, gotErrs[i]) + require.Equal(t, fmt.Sprintf("%x", refVals[i]), fmt.Sprintf("%x", gotVals[i]), + "value mismatch for key index %d", i) + } +} + +// benchFrontierKeys writes nKeys vector-shaped posting lists (one complete +// posting whose Value holds the vector bytes) under a unique predicate and +// returns their keys. Mirrors how dgraph stores vectors that HNSW reads. +func benchFrontierKeys(b *testing.B, nKeys, dim int) [][]byte { + b.Helper() + pred := x.AttrInRootNamespace("frontier-bench-" + uuid.New().String()) + vec := make([]byte, dim*4) // dim float32s + for i := range vec { + vec[i] = byte(i) + } + pl := &pb.PostingList{Postings: []*pb.Posting{{Uid: 1, Op: uint32(Set), Value: vec}}} + val, err := proto.Marshal(pl) + if err != nil { + b.Fatal(err) + } + keys := make([][]byte, nKeys) + kvs := make([]*bpb.KV, 0, nKeys) + for i := 0; i < nKeys; i++ { + key := x.DataKey(pred, uint64(i+1)) + keys[i] = key + kvs = append(kvs, &bpb.KV{Key: key, Value: val, UserMeta: []byte{BitCompletePosting}, Version: 1}) + } + if err := writePostingListToDisk(kvs); err != nil { + b.Fatal(err) + } + return keys +} + +// BenchmarkHNSWFrontierRead compares the HNSW neighbor-vector read path against +// a real (cache-disabled) badger store: K serial per-key Get calls (today) vs a +// single batched MultiGet (this change), over a fresh per-iteration cache so +// every read is cold — the search-time reality. +func BenchmarkHNSWFrontierRead(b *testing.B) { + const dim = 384 + readTs := uint64(10) + for _, K := range []int{16, 64, 256} { + keys := benchFrontierKeys(b, K, dim) + b.Run(fmt.Sprintf("Get/K=%d", K), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + vc := &viLocalCache{delegate: NewLocalCache(readTs)} + for _, key := range keys { + if _, err := vc.Get(key); err != nil { + b.Fatal(err) + } + } + } + }) + b.Run(fmt.Sprintf("MultiGet/K=%d", K), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + vc := &viLocalCache{delegate: NewLocalCache(readTs)} + _, errs := vc.MultiGet(keys) + for _, e := range errs { + if e != nil { + b.Fatal(e) + } + } + } + }) + } +} diff --git a/posting/mvcc.go b/posting/mvcc.go index 81c5e375553..653f3b9c131 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -776,6 +776,162 @@ func (ml *MemoryLayer) readFromCache(key []byte, readTs uint64) *List { return nil } +// ReadPostingListFromVersions builds a *List from the version chain of a single +// key as returned by badger.Txn.MultiGet. It is the batched-read counterpart of +// ReadPostingList: same folding logic (delta postings layered on top of a +// complete posting, newest-first, stopping at the first complete/empty/deleted +// version), but consuming an already-materialized []badger.ItemVersion instead +// of advancing a live *badger.Iterator. +func ReadPostingListFromVersions(key []byte, versions []badger.ItemVersion) (*List, error) { + pk, err := x.Parse(key) + if err != nil { + return nil, errors.Wrapf(err, "while reading posting list with key [%v]", key) + } + if pk.HasStartUid { + // Multi-part lists must be read via the main key; mirror ReadPostingList. + return nil, ErrInvalidKey + } + + l := new(List) + l.key = key + l.plist = new(pb.PostingList) + l.mutationMap = newMutableLayer() + l.minTs = 0 + + deltaCount := 0 + defer func() { + if deltaCount > 0 { + if deltaCount > 500 { + IncrRollup.addKeyToBatch(key, 0) + } else { + IncrRollup.addKeyToBatch(key, 1) + } + } + }() + + // versions are newest-first (commit ts descending), exactly the order + // ReadPostingList walks the iterator in. + for i := range versions { + v := &versions[i] + l.maxTs = x.Max(l.maxTs, v.Version) + if v.IsDeletedOrExpired() { + break + } + switch v.UserMeta { + case BitEmptyPosting: + return l, nil + case BitCompletePosting: + if len(v.Value) > 0 { + if err := proto.Unmarshal(v.Value, l.plist); err != nil { + return nil, err + } + } + l.minTs = v.Version + return l, nil + case BitDeltaPosting: + pl := &pb.PostingList{} + if err := proto.Unmarshal(v.Value, pl); err != nil { + return nil, err + } + pl.CommitTs = v.Version + l.mutationMap.insertCommittedPostings(pl) + deltaCount++ + case BitSchemaPosting: + return nil, errors.Errorf( + "Trying to read schema in ReadPostingListFromVersions for key: %s", hex.Dump(key)) + default: + return nil, errors.Errorf( + "Unexpected meta: %d for key: %s", v.UserMeta, hex.Dump(key)) + } + if v.DiscardEarlierVersions() { + break + } + } + return l, nil +} + +// ReadManyData is the batched counterpart of ReadData: it resolves many keys to +// their *List in one shot. Keys already warm in the process-global cache are +// served from there; the cold remainder is fetched from disk with a single +// badger.Txn.MultiGet (amortizing iterator construction across the batch) and +// folded via ReadPostingListFromVersions. It mirrors ReadData's two-phase read +// (full chain at MaxUint64 to populate the cache, then a readTs-bounded re-read +// only for keys whose complete posting is newer than readTs). Returned lists are +// aligned with keys. +func (ml *MemoryLayer) ReadManyData(keys [][]byte, pstore *badger.DB, readTs uint64) ([]*List, error) { + if pstore.IsClosed() { + return nil, badger.ErrDBClosed + } + lists := make([]*List, len(keys)) + + // Phase 0: serve warm keys from the global cache. + var missIdx []int + var missKeys [][]byte + for i, key := range keys { + if l := ml.readFromCache(key, readTs); l != nil { + l.mutationMap.setTs(readTs) + lists[i] = l + continue + } + missIdx = append(missIdx, i) + missKeys = append(missKeys, key) + } + if len(missKeys) == 0 { + return lists, nil + } + + // Phase 1: full version chains at MaxUint64 (so the cached list stays valid + // for a range of read timestamps, as ReadData relies on), populate cache. + readChains := func(ks [][]byte, ts uint64) ([]*List, error) { + txn := pstore.NewTransactionAt(ts, false) + defer txn.Discard() + results, err := txn.MultiGet(ks) + if err != nil { + return nil, err + } + out := make([]*List, len(ks)) + for j := range ks { + l, err := ReadPostingListFromVersions(ks[j], results[j].Versions) + if err != nil { + return nil, err + } + out[j] = l + } + return out, nil + } + + full, err := readChains(missKeys, math.MaxUint64) + if err != nil { + return nil, err + } + + // Phase 2: keys whose complete posting is newer than readTs need a + // readTs-bounded re-read (mirrors ReadData's second readFromDisk). + var reIdx []int + var reKeys [][]byte + for j, l := range full { + ml.saveInCache(missKeys[j], l) + if l.minTs == 0 || readTs >= l.minTs { + l.mutationMap.setTs(readTs) + lists[missIdx[j]] = l + continue + } + reIdx = append(reIdx, j) + reKeys = append(reKeys, missKeys[j]) + } + if len(reKeys) > 0 { + bounded, err := readChains(reKeys, readTs) + if err != nil { + return nil, err + } + for k, l := range bounded { + l.mutationMap.setTs(readTs) + lists[missIdx[reIdx[k]]] = l + } + } + return lists, nil +} + func (ml *MemoryLayer) readFromDisk(key []byte, pstore *badger.DB, readTs uint64, readUids bool) (*List, error) { txn := pstore.NewTransactionAt(readTs, false) defer txn.Discard() diff --git a/posting/oracle.go b/posting/oracle.go index d7c3837b4b2..7c7ff2b043b 100644 --- a/posting/oracle.go +++ b/posting/oracle.go @@ -84,6 +84,31 @@ func (vt *viTxn) Get(key []byte) ([]byte, error) { return vt.GetValueFromPostingList(pl) } +// MultiGet resolves many keys to their values in one batched read (see +// LocalCache.MultiGet). vals and errs are aligned with keys; errs[i] is +// ErrNoValue when keys[i] has no value, matching Get's per-key semantics. +func (vt *viTxn) MultiGet(keys [][]byte) ([][]byte, []error) { + vals := make([][]byte, len(keys)) + errs := make([]error, len(keys)) + lists, err := vt.delegate.cache.MultiGet(keys) + if err != nil { + for i := range errs { + errs[i] = err + } + return vals, errs + } + for i, pl := range lists { + if pl == nil { + errs[i] = ErrNoValue + continue + } + pl.Lock() + vals[i], errs[i] = vt.GetValueFromPostingList(pl) + pl.Unlock() + } + return vals, errs +} + func (vt *viTxn) GetWithLockHeld(key []byte) ([]byte, error) { pl, err := vt.delegate.cache.Get(key) if err != nil { diff --git a/tok/hnsw/ef_recall_test.go b/tok/hnsw/ef_recall_test.go index 5ae22282222..9e631aacdb5 100644 --- a/tok/hnsw/ef_recall_test.go +++ b/tok/hnsw/ef_recall_test.go @@ -30,6 +30,15 @@ func (m *memoryCache) Get(key []byte) ([]byte, error) { return nil, nil } +func (m *memoryCache) MultiGet(keys [][]byte) ([][]byte, []error) { + vals := make([][]byte, len(keys)) + errs := make([]error, len(keys)) + for i, key := range keys { + vals[i], errs[i] = m.Get(key) + } + return vals, errs +} + func (m *memoryCache) Ts() uint64 { return 0 } func (m *memoryCache) Find([]byte, func([]byte) bool) (uint64, error) { return 0, nil } diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 39d72d8f5e7..aad6f35281f 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -245,6 +245,10 @@ func (tc *TxnCache) Get(key []byte) (rval []byte, rerr error) { return tc.txn.Get(key) } +func (tc *TxnCache) MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) { + return tc.txn.MultiGet(keys) +} + func (tc *TxnCache) Ts() uint64 { return tc.startTs } @@ -274,6 +278,10 @@ func (qc *QueryCache) Get(key []byte) (rval []byte, rerr error) { return qc.cache.Get(key) } +func (qc *QueryCache) MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) { + return qc.cache.MultiGet(keys) +} + func (qc *QueryCache) Ts() uint64 { return qc.readTs } @@ -383,6 +391,35 @@ func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[ } } +// getVecsFromUids batch-fetches the vectors for uids in a single CacheType +// MultiGet, returning one []T per uid (aligned with uids). This is the batched +// counterpart of getVecFromUid: it folds the per-neighbor point reads on the +// search hot path into one round trip to the store. A missing or unreadable +// vector yields an empty slice (matching getVecFromUid's empty-on-miss +// behavior). Each returned slice is a zero-copy view over its own value bytes, +// so the slices are independent. +func (ph *persistentHNSW[T]) getVecsFromUids(uids []uint64, c index.CacheType) [][]T { + vecs := make([][]T, len(uids)) + if len(uids) == 0 { + return vecs + } + keys := make([][]byte, len(uids)) + for i, uid := range uids { + keys[i] = DataKey(ph.pred, uid) + } + vals, errs := c.MultiGet(keys) + for i := range uids { + var vec []T + if errs[i] == nil && len(vals[i]) > 0 { + index.BytesAsFloatArray(vals[i], &vec, ph.floatBits) + } else { + index.BytesAsFloatArray(emptyVec, &vec, ph.floatBits) + } + vecs[i] = vec + } + return vecs +} + // chooses whether to create the entry and start nodes based on if it already // exists, and if it hasnt been created yet, it adds the startNode to all // levels. diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 5658800e579..dbbf6547ca7 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -197,17 +197,28 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( if !found { continue } - var eVec []T improved := false + // Batch the neighbor vector reads: collect this candidate's unvisited + // neighbors and fetch all their vectors in one MultiGet, instead of a + // serial point read per neighbor. The reads are independent for a fixed + // candidate; the traversal/heap logic below is otherwise unchanged. + toVisit := make([]uint64, 0, len(allLayerEdges[level])) for _, currUid := range allLayerEdges[level] { if r.indexVisited(currUid) { continue } - // iterate over candidate's neighbors distances to get - // best ones - _ = ph.getVecFromUid(currUid, c, &eVec) - // intentionally ignoring error -- we catch it - // indirectly via eVec == nil check. + toVisit = append(toVisit, currUid) + } + eVecs := ph.getVecsFromUids(toVisit, c) + for i, currUid := range toVisit { + // Re-check visited to preserve exact semantics if a neighbor list + // ever contains a duplicate uid within one batch. + if r.indexVisited(currUid) { + continue + } + // candidate's neighbor vector (empty on miss -- caught via the + // len(eVec) == 0 check below, as before). + eVec := eVecs[i] if len(eVec) == 0 { continue } diff --git a/tok/hnsw/test_helper.go b/tok/hnsw/test_helper.go index 04a3f1ac029..efa7e5775a7 100644 --- a/tok/hnsw/test_helper.go +++ b/tok/hnsw/test_helper.go @@ -159,6 +159,18 @@ func (t *inMemTxn) GetWithLockHeld(key []byte) (rval []byte, rerr error) { return val, nil } +// MultiGet reads many keys; the mock simply loops Get under one lock. +func (t *inMemTxn) MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) { + tsDbs[t.startTs].readMu.RLock() + defer tsDbs[t.startTs].readMu.RUnlock() + rvals = make([][]byte, len(keys)) + rerrs = make([]error, len(keys)) + for i, key := range keys { + rvals[i], rerrs[i] = t.GetWithLockHeld(key) + } + return rvals, rerrs +} + // locks the txn and invokes AddMutationWithLockHeld func (t *inMemTxn) AddMutation(ctx context.Context, key []byte, t1 *index.KeyValue) error { tsDbs[t.startTs].writeMu.Lock() @@ -223,6 +235,18 @@ func (c *inMemLocalCache) Find(prefix []byte, filter func([]byte) bool) (uint64, return 0, nil } +// MultiGet reads many keys; the mock simply loops Get under one lock. +func (c *inMemLocalCache) MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) { + tsDbs[c.readTs].readMu.RLock() + defer tsDbs[c.readTs].readMu.RUnlock() + rvals = make([][]byte, len(keys)) + rerrs = make([]error, len(keys)) + for i, key := range keys { + rvals[i], rerrs[i] = c.GetWithLockHeld(key) + } + return rvals, rerrs +} + // reads value from the database at c's readTs func (c *inMemLocalCache) GetWithLockHeld(key []byte) (rval []byte, rerr error) { val, ok := tsDbs[c.readTs].inMemTestDb[string(key[:])] diff --git a/tok/index/index.go b/tok/index/index.go index 1e981ef189e..85ad15d057c 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -154,6 +154,10 @@ type Txn interface { StartTs() uint64 // Get uses a []byte key to return the Value corresponding to the key Get(key []byte) (rval []byte, rerr error) + // MultiGet returns the Values for many keys in one batched read. rvals and + // rerrs are aligned with keys; rerrs[i] is non-nil (e.g. ErrNoValue) when + // keys[i] has no value. + MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) // GetWithLockHeld uses a []byte key to return the Value corresponding to the key with a mutex lock held GetWithLockHeld(key []byte) (rval []byte, rerr error) Find(prefix []byte, filter func(val []byte) bool) (uint64, error) @@ -172,6 +176,8 @@ type Txn interface { type LocalCache interface { // Get uses a []byte key to return the Value corresponding to the key Get(key []byte) (rval []byte, rerr error) + // MultiGet returns the Values for many keys in one batched read (see Txn.MultiGet). + MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) // GetWithLockHeld uses a []byte key to return the Value corresponding to the key with a mutex lock held GetWithLockHeld(key []byte) (rval []byte, rerr error) Find(prefix []byte, filter func(val []byte) bool) (uint64, error) @@ -180,6 +186,11 @@ type LocalCache interface { // CacheType is an interface representation of the cache of a persistent storage system type CacheType interface { Get(key []byte) (rval []byte, rerr error) + // MultiGet returns the Values for many keys in one batched read. rvals and + // rerrs are aligned with keys; rerrs[i] is non-nil when keys[i] has no value. + // It lets fan-out readers (e.g. HNSW search) fetch a whole frontier at once + // instead of issuing one point read per key. + MultiGet(keys [][]byte) (rvals [][]byte, rerrs []error) Ts() uint64 Find(prefix []byte, filter func(val []byte) bool) (uint64, error) } From 8082177b126830e582ecace4b0b3f39113250f71 Mon Sep 17 00:00:00 2001 From: Shaun Patterson Date: Wed, 10 Jun 2026 14:32:44 +0000 Subject: [PATCH 2/2] fix(hnsw): make batched neighbor reads compile against badger v4.9.1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The branch did not build: ReadManyData/ReadPostingListFromVersions referenced badger.Txn.MultiGet and badger.ItemVersion, which do not exist in the pinned badger v4.9.1 (nor in any released version or badger main) — go.mod was never bumped, so `go build ./posting/...` failed with "undefined: badger.ItemVersion" and "txn.MultiGet undefined". Reimplement the batched cold-read without that API: ReadManyData now opens one read transaction and one AllVersions iterator per phase and Seeks to each key, folding the version chain with the existing, proven ReadPostingList (exactly as the single-key readFromDisk does). This still amortizes txn/iterator construction across the whole neighbor frontier — the dgraph-side batching win — while staying correct: the per-txn cache layering, two-phase MaxUint64-then-readTs read, and cache population are unchanged. ReadPostingListFromVersions is removed. Validated: TestReadManyDataMatchesReadData (batched == single-key, value-for- value) and the vector integration suite (similar_to/HNSW search, delete, update, reindex, dot-product) all pass. --- posting/multiget_test.go | 4 +- posting/mvcc.go | 101 +++++++-------------------------------- 2 files changed, 19 insertions(+), 86 deletions(-) diff --git a/posting/multiget_test.go b/posting/multiget_test.go index 4385f221f32..43c0270a4f4 100644 --- a/posting/multiget_test.go +++ b/posting/multiget_test.go @@ -18,8 +18,8 @@ import ( "google.golang.org/protobuf/proto" ) -// TestReadManyDataMatchesReadData verifies that the batched MultiGet-backed -// read path (MemoryLayer.ReadManyData -> ReadPostingListFromVersions) folds a +// TestReadManyDataMatchesReadData verifies that the batched read path +// (MemoryLayer.ReadManyData, a shared iterator reusing ReadPostingList) folds a // key's version chain identically to the per-key path // (NewKeyIterator -> ReadPostingList) over a real on-disk round-trip. func TestReadManyDataMatchesReadData(t *testing.T) { diff --git a/posting/mvcc.go b/posting/mvcc.go index 653f3b9c131..e4055b74c5f 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -776,85 +776,12 @@ func (ml *MemoryLayer) readFromCache(key []byte, readTs uint64) *List { return nil } -// ReadPostingListFromVersions builds a *List from the version chain of a single -// key as returned by badger.Txn.MultiGet. It is the batched-read counterpart of -// ReadPostingList: same folding logic (delta postings layered on top of a -// complete posting, newest-first, stopping at the first complete/empty/deleted -// version), but consuming an already-materialized []badger.ItemVersion instead -// of advancing a live *badger.Iterator. -func ReadPostingListFromVersions(key []byte, versions []badger.ItemVersion) (*List, error) { - pk, err := x.Parse(key) - if err != nil { - return nil, errors.Wrapf(err, "while reading posting list with key [%v]", key) - } - if pk.HasStartUid { - // Multi-part lists must be read via the main key; mirror ReadPostingList. - return nil, ErrInvalidKey - } - - l := new(List) - l.key = key - l.plist = new(pb.PostingList) - l.mutationMap = newMutableLayer() - l.minTs = 0 - - deltaCount := 0 - defer func() { - if deltaCount > 0 { - if deltaCount > 500 { - IncrRollup.addKeyToBatch(key, 0) - } else { - IncrRollup.addKeyToBatch(key, 1) - } - } - }() - - // versions are newest-first (commit ts descending), exactly the order - // ReadPostingList walks the iterator in. - for i := range versions { - v := &versions[i] - l.maxTs = x.Max(l.maxTs, v.Version) - if v.IsDeletedOrExpired() { - break - } - switch v.UserMeta { - case BitEmptyPosting: - return l, nil - case BitCompletePosting: - if len(v.Value) > 0 { - if err := proto.Unmarshal(v.Value, l.plist); err != nil { - return nil, err - } - } - l.minTs = v.Version - return l, nil - case BitDeltaPosting: - pl := &pb.PostingList{} - if err := proto.Unmarshal(v.Value, pl); err != nil { - return nil, err - } - pl.CommitTs = v.Version - l.mutationMap.insertCommittedPostings(pl) - deltaCount++ - case BitSchemaPosting: - return nil, errors.Errorf( - "Trying to read schema in ReadPostingListFromVersions for key: %s", hex.Dump(key)) - default: - return nil, errors.Errorf( - "Unexpected meta: %d for key: %s", v.UserMeta, hex.Dump(key)) - } - if v.DiscardEarlierVersions() { - break - } - } - return l, nil -} - // ReadManyData is the batched counterpart of ReadData: it resolves many keys to -// their *List in one shot. Keys already warm in the process-global cache are -// served from there; the cold remainder is fetched from disk with a single -// badger.Txn.MultiGet (amortizing iterator construction across the batch) and -// folded via ReadPostingListFromVersions. It mirrors ReadData's two-phase read +// their *List in one shared read. Keys already warm in the process-global cache +// are served from there; the cold remainder is read from disk by reusing a single +// transaction and a single AllVersions iterator across the whole batch (amortizing +// txn/iterator construction over the frontier) and folded by the same +// ReadPostingList used for single-key reads. It mirrors ReadData's two-phase read // (full chain at MaxUint64 to populate the cache, then a readTs-bounded re-read // only for keys whose complete posting is newer than readTs). Returned lists are // aligned with keys. @@ -885,13 +812,19 @@ func (ml *MemoryLayer) ReadManyData(keys [][]byte, pstore *badger.DB, readTs uin readChains := func(ks [][]byte, ts uint64) ([]*List, error) { txn := pstore.NewTransactionAt(ts, false) defer txn.Discard() - results, err := txn.MultiGet(ks) - if err != nil { - return nil, err - } + // One AllVersions iterator reused across the batch amortizes iterator/txn + // construction over the whole frontier; ReadPostingList folds each key's + // version chain exactly as the single-key disk read (readFromDisk) does. + // PrefetchValues is off to match readFromDisk. + iterOpts := badger.DefaultIteratorOptions + iterOpts.AllVersions = true + iterOpts.PrefetchValues = false + it := txn.NewIterator(iterOpts) + defer it.Close() out := make([]*List, len(ks)) - for j := range ks { - l, err := ReadPostingListFromVersions(ks[j], results[j].Versions) + for j, key := range ks { + it.Seek(key) + l, err := ReadPostingList(key, it) if err != nil { return nil, err }