Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions posting/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ func (vc *viLocalCache) Get(key []byte) ([]byte, error) {
return vc.GetValueFromPostingList(pl)
}

// MultiGet resolves many keys to their values in one batched read (see
// LocalCache.MultiGet). vals and errs are aligned with keys; errs[i] is
// ErrNoValue when keys[i] has no value, matching Get's per-key semantics.
func (vc *viLocalCache) MultiGet(keys [][]byte) ([][]byte, []error) {
vals := make([][]byte, len(keys))
errs := make([]error, len(keys))
lists, err := vc.delegate.MultiGet(keys)
if err != nil {
for i := range errs {
errs[i] = err
}
return vals, errs
}
for i, pl := range lists {
if pl == nil {
errs[i] = ErrNoValue
continue
}
pl.Lock()
vals[i], errs[i] = vc.GetValueFromPostingList(pl)
pl.Unlock()
}
return vals, errs
}

func (vc *viLocalCache) GetWithLockHeld(key []byte) ([]byte, error) {
pl, err := vc.delegate.Get(key)
if err != nil {
Expand Down Expand Up @@ -390,6 +415,61 @@ func (lc *LocalCache) Get(key []byte) (*List, error) {
return lc.getInternal(key, true, false)
}

// MultiGet is the batched form of Get(readFromDisk=true): it resolves many keys
// to their *List in one shared read. Keys already in the per-txn cache (plists)
// are served from there; the cold remainder is read from disk in a single
// MemoryLayer.ReadManyData (one badger MultiGet), then has any pending txn delta
// applied and is interned via SetIfAbsent — mirroring getInternal per key.
// Returned lists are aligned with keys.
func (lc *LocalCache) MultiGet(keys [][]byte) ([]*List, error) {
lists := make([]*List, len(keys))

lc.RLock()
plistsNil := lc.plists == nil
lc.RUnlock()

var missIdx []int
var missKeys [][]byte
for i, key := range keys {
if !plistsNil {
lc.RLock()
l, ok := lc.plists[string(key)]
lc.RUnlock()
if ok {
lists[i] = l
continue
}
}
missIdx = append(missIdx, i)
missKeys = append(missKeys, key)
}
if len(missKeys) == 0 {
return lists, nil
}

fetched, err := MemLayerInstance.ReadManyData(missKeys, pstore, lc.startTs)
if err != nil {
return nil, err
}
for j, idx := range missIdx {
pl := fetched[j]
if plistsNil {
// Cache disabled: return the freshly read list directly, exactly as
// getInternal's plists==nil path returns getNew without caching.
lists[idx] = pl
continue
}
skey := string(missKeys[j])
lc.RLock()
if delta, ok := lc.deltas[skey]; ok && len(delta) > 0 {
pl.setMutation(lc.startTs, delta)
}
lc.RUnlock()
lists[idx] = lc.SetIfAbsent(skey, pl)
}
return lists, nil
}

func (lc *LocalCache) GetUids(key []byte) (*List, error) {
return lc.getInternal(key, true, true)
}
Expand Down
165 changes: 165 additions & 0 deletions posting/multiget_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
* SPDX-License-Identifier: Apache-2.0
*/

package posting

import (
"fmt"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

bpb "github.com/dgraph-io/badger/v4/pb"
"github.com/dgraph-io/dgraph/v25/protos/pb"
"github.com/dgraph-io/dgraph/v25/x"
"google.golang.org/protobuf/proto"
)

// TestReadManyDataMatchesReadData verifies that the batched read path
// (MemoryLayer.ReadManyData, a shared iterator reusing ReadPostingList) folds a
// key's version chain identically to the per-key path
// (NewKeyIterator -> ReadPostingList) over a real on-disk round-trip.
func TestReadManyDataMatchesReadData(t *testing.T) {
mkpl := func(uid uint64) []byte {
pl := &pb.PostingList{
Postings: []*pb.Posting{{Uid: uid, Op: uint32(Set)}},
}
b, err := proto.Marshal(pl)
require.NoError(t, err)
return b
}

// Build keys with a few different version-chain shapes, each under a unique
// predicate so the process-global cache starts cold.
pred := x.AttrInRootNamespace("multiget-" + uuid.New().String())
var keys [][]byte
var kvs []*bpb.KV
for i := 0; i < 12; i++ {
key := x.DataKey(pred, uint64(i+1))
keys = append(keys, key)
switch i % 4 {
case 0:
// complete posting only
kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(i + 1)),
UserMeta: []byte{BitCompletePosting}, Version: 2})
case 1:
// complete + newer delta on top
kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(i + 1)),
UserMeta: []byte{BitCompletePosting}, Version: 2})
kvs = append(kvs, &bpb.KV{Key: key, Value: mkpl(uint64(100 + i)),
UserMeta: []byte{BitDeltaPosting}, Version: 4})
case 2:
// empty posting
kvs = append(kvs, &bpb.KV{Key: key, Value: []byte{},
UserMeta: []byte{BitEmptyPosting}, Version: 3})
case 3:
// absent key (write nothing)
}
}
require.NoError(t, writePostingListToDisk(kvs))

readTs := uint64(10)

// Ground truth: per-key path (bypasses caches, reads raw from disk).
want := make([][]uint64, len(keys))
for i, key := range keys {
l, err := readPostingListFromDisk(key, ps, readTs)
require.NoError(t, err)
want[i] = listToArray(t, 0, l, readTs)
}

// Batched path.
got, err := MemLayerInstance.ReadManyData(keys, ps, readTs)
require.NoError(t, err)
require.Equal(t, len(keys), len(got))
for i, key := range keys {
require.NotNil(t, got[i], "list for key %x", key)
require.Equal(t, want[i], listToArray(t, 0, got[i], readTs),
"uid set mismatch for key index %d", i)
}

// Also exercise the viLocalCache value path end-to-end: per-key Get vs
// batched MultiGet must agree value-for-value.
refVals := make([][]byte, len(keys))
refErrs := make([]bool, len(keys))
for i, key := range keys {
vc := &viLocalCache{delegate: NewLocalCache(readTs)}
v, e := vc.Get(key)
refVals[i] = v
refErrs[i] = e != nil
}
vc := &viLocalCache{delegate: NewLocalCache(readTs)}
gotVals, gotErrs := vc.MultiGet(keys)
for i := range keys {
require.Equal(t, refErrs[i], gotErrs[i] != nil,
"error presence mismatch for key index %d (%v)", i, gotErrs[i])
require.Equal(t, fmt.Sprintf("%x", refVals[i]), fmt.Sprintf("%x", gotVals[i]),
"value mismatch for key index %d", i)
}
}

// benchFrontierKeys writes nKeys vector-shaped posting lists (one complete
// posting whose Value holds the vector bytes) under a unique predicate and
// returns their keys. Mirrors how dgraph stores vectors that HNSW reads.
func benchFrontierKeys(b *testing.B, nKeys, dim int) [][]byte {
b.Helper()
pred := x.AttrInRootNamespace("frontier-bench-" + uuid.New().String())
vec := make([]byte, dim*4) // dim float32s
for i := range vec {
vec[i] = byte(i)
}
pl := &pb.PostingList{Postings: []*pb.Posting{{Uid: 1, Op: uint32(Set), Value: vec}}}
val, err := proto.Marshal(pl)
if err != nil {
b.Fatal(err)
}
keys := make([][]byte, nKeys)
kvs := make([]*bpb.KV, 0, nKeys)
for i := 0; i < nKeys; i++ {
key := x.DataKey(pred, uint64(i+1))
keys[i] = key
kvs = append(kvs, &bpb.KV{Key: key, Value: val, UserMeta: []byte{BitCompletePosting}, Version: 1})
}
if err := writePostingListToDisk(kvs); err != nil {
b.Fatal(err)
}
return keys
}

// BenchmarkHNSWFrontierRead compares the HNSW neighbor-vector read path against
// a real (cache-disabled) badger store: K serial per-key Get calls (today) vs a
// single batched MultiGet (this change), over a fresh per-iteration cache so
// every read is cold — the search-time reality.
func BenchmarkHNSWFrontierRead(b *testing.B) {
const dim = 384
readTs := uint64(10)
for _, K := range []int{16, 64, 256} {
keys := benchFrontierKeys(b, K, dim)
b.Run(fmt.Sprintf("Get/K=%d", K), func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
vc := &viLocalCache{delegate: NewLocalCache(readTs)}
for _, key := range keys {
if _, err := vc.Get(key); err != nil {
b.Fatal(err)
}
}
}
})
b.Run(fmt.Sprintf("MultiGet/K=%d", K), func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
vc := &viLocalCache{delegate: NewLocalCache(readTs)}
_, errs := vc.MultiGet(keys)
for _, e := range errs {
if e != nil {
b.Fatal(e)
}
}
}
})
}
}
89 changes: 89 additions & 0 deletions posting/mvcc.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,95 @@ func (ml *MemoryLayer) readFromCache(key []byte, readTs uint64) *List {
return nil
}

// ReadManyData is the batched counterpart of ReadData: it resolves many keys to
// their *List in one shared read. Keys already warm in the process-global cache
// are served from there; the cold remainder is read from disk by reusing a single
// transaction and a single AllVersions iterator across the whole batch (amortizing
// txn/iterator construction over the frontier) and folded by the same
// ReadPostingList used for single-key reads. It mirrors ReadData's two-phase read
// (full chain at MaxUint64 to populate the cache, then a readTs-bounded re-read
// only for keys whose complete posting is newer than readTs). Returned lists are
// aligned with keys.
func (ml *MemoryLayer) ReadManyData(keys [][]byte, pstore *badger.DB, readTs uint64) ([]*List, error) {
if pstore.IsClosed() {
return nil, badger.ErrDBClosed
}
lists := make([]*List, len(keys))

// Phase 0: serve warm keys from the global cache.
var missIdx []int
var missKeys [][]byte
for i, key := range keys {
if l := ml.readFromCache(key, readTs); l != nil {
l.mutationMap.setTs(readTs)
lists[i] = l
continue
}
missIdx = append(missIdx, i)
missKeys = append(missKeys, key)
}
if len(missKeys) == 0 {
return lists, nil
}

// Phase 1: full version chains at MaxUint64 (so the cached list stays valid
// for a range of read timestamps, as ReadData relies on), populate cache.
readChains := func(ks [][]byte, ts uint64) ([]*List, error) {
txn := pstore.NewTransactionAt(ts, false)
defer txn.Discard()
// One AllVersions iterator reused across the batch amortizes iterator/txn
// construction over the whole frontier; ReadPostingList folds each key's
// version chain exactly as the single-key disk read (readFromDisk) does.
// PrefetchValues is off to match readFromDisk.
iterOpts := badger.DefaultIteratorOptions
iterOpts.AllVersions = true
iterOpts.PrefetchValues = false
it := txn.NewIterator(iterOpts)
defer it.Close()
out := make([]*List, len(ks))
for j, key := range ks {
it.Seek(key)
l, err := ReadPostingList(key, it)
if err != nil {
return nil, err
}
out[j] = l
}
return out, nil
}

full, err := readChains(missKeys, math.MaxUint64)
if err != nil {
return nil, err
}

// Phase 2: keys whose complete posting is newer than readTs need a
// readTs-bounded re-read (mirrors ReadData's second readFromDisk).
var reIdx []int
var reKeys [][]byte
for j, l := range full {
ml.saveInCache(missKeys[j], l)
if l.minTs == 0 || readTs >= l.minTs {
l.mutationMap.setTs(readTs)
lists[missIdx[j]] = l
continue
}
reIdx = append(reIdx, j)
reKeys = append(reKeys, missKeys[j])
}
if len(reKeys) > 0 {
bounded, err := readChains(reKeys, readTs)
if err != nil {
return nil, err
}
for k, l := range bounded {
l.mutationMap.setTs(readTs)
lists[missIdx[reIdx[k]]] = l
}
}
return lists, nil
}

func (ml *MemoryLayer) readFromDisk(key []byte, pstore *badger.DB, readTs uint64, readUids bool) (*List, error) {
txn := pstore.NewTransactionAt(readTs, false)
defer txn.Discard()
Expand Down
25 changes: 25 additions & 0 deletions posting/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ func (vt *viTxn) Get(key []byte) ([]byte, error) {
return vt.GetValueFromPostingList(pl)
}

// MultiGet resolves many keys to their values in one batched read (see
// LocalCache.MultiGet). vals and errs are aligned with keys; errs[i] is
// ErrNoValue when keys[i] has no value, matching Get's per-key semantics.
func (vt *viTxn) MultiGet(keys [][]byte) ([][]byte, []error) {
vals := make([][]byte, len(keys))
errs := make([]error, len(keys))
lists, err := vt.delegate.cache.MultiGet(keys)
if err != nil {
for i := range errs {
errs[i] = err
}
return vals, errs
}
for i, pl := range lists {
if pl == nil {
errs[i] = ErrNoValue
continue
}
pl.Lock()
vals[i], errs[i] = vt.GetValueFromPostingList(pl)
pl.Unlock()
}
return vals, errs
}

func (vt *viTxn) GetWithLockHeld(key []byte) ([]byte, error) {
pl, err := vt.delegate.cache.Get(key)
if err != nil {
Expand Down
Loading
Loading