diff --git a/posting/index.go b/posting/index.go index ae6c3352a44..3743a21c283 100644 --- a/posting/index.go +++ b/posting/index.go @@ -1702,6 +1702,11 @@ func prefixesToDropVectorIndexEdges(ctx context.Context, rb *IndexRebuild) [][]b prefixes := append([][]byte{}, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecEntry))) prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecDead))) prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecKeyword))) + // VecQuant ("__vector_q") is a distinct predicate from VecKeyword + // ("__vector_"), so its keys are not covered by the VecKeyword prefix and + // must be dropped explicitly on rebuild to avoid leaving stale quantized + // blobs behind. + prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecQuant))) for i := range hnsw.VectorIndexMaxLevels { prefixes = append(prefixes, x.PredicatePrefix(hnsw.ConcatStrings(rb.Attr, hnsw.VecKeyword, fmt.Sprint(i)))) diff --git a/schema/parse_test.go b/schema/parse_test.go index adb64311f2d..e8e738929be 100644 --- a/schema/parse_test.go +++ b/schema/parse_test.go @@ -689,3 +689,26 @@ func TestMain(m *testing.M) { Init(ps) m.Run() } + +// TestParseVectorQuantizeOption verifies the int8 quantization option is +// accepted end-to-end through schema parsing (registered as an HNSW factory +// option) and that an invalid value is rejected at parse time. +func TestParseVectorQuantizeOption(t *testing.T) { + require.NoError(t, ParseBytes([]byte( + `vqi: float32vector @index(hnsw(metric:"euclidean", quantize:"int8")) .`+"\n"), 1)) + su, ok := State().predicate[x.AttrInRootNamespace("vqi")] + require.True(t, ok) + require.Len(t, su.IndexSpecs, 1) + found := false + for _, op := range su.IndexSpecs[0].Options { + if op.Key == "quantize" { + require.Equal(t, "int8", op.Value) + found = true + } + } + require.True(t, found, "quantize option must be captured in the vector index spec") + + // An unsupported quantize value must be rejected when the schema is parsed. + require.Error(t, ParseBytes([]byte( + `vqbad: float32vector @index(hnsw(quantize:"int4")) .`+"\n"), 1)) +} diff --git a/schema/schema.go b/schema/schema.go index 3ce0da8ea74..cef16c21a2b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -475,6 +475,7 @@ func (s *state) PredicatesToDelete(pred string) []string { preds = append(preds, pred+hnsw.VecEntry) preds = append(preds, pred+hnsw.VecKeyword) preds = append(preds, pred+hnsw.VecDead) + preds = append(preds, pred+hnsw.VecQuant) } } return preds diff --git a/tok/hnsw/helper.go b/tok/hnsw/helper.go index 39d72d8f5e7..54c53cccafd 100644 --- a/tok/hnsw/helper.go +++ b/tok/hnsw/helper.go @@ -30,6 +30,7 @@ const ( DotProd = "dotproduct" EmptyHNSWTreeError = "HNSW tree has no elements" VecKeyword = "__vector_" + VecQuant = "__vector_q" // per-node int8-quantized vector (opt-in) visitedVectorsLevel = "visited_vectors_level_" distanceComputations = "vector_distance_computations" searchTime = "vector_search_time" @@ -365,6 +366,32 @@ var emptyVec = []byte{} // adds the data corresponding to a uid to the given vec variable in the form of []T // this does not allocate memory for vec, so it must be allocated before calling this function func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[]T) error { + // Quantized index: read the int8 blob from vecQKey and dequantize into the + // caller's reused buffer. On a missing/undecodable blob, fall back to the + // raw vector (graceful degradation; the build writes vecQKey for every node + // so misses are rare and indicate a partial/corrupt index). + if ph.quantize { + data, err := getDataFromKeyWithCacheType(ph.vecQKey, uid, c) + if err != nil && !errors.Is(err, errFetchingPostingList) { + return err + } + if len(data) > 0 { + if derr := index.DequantizeInto(data, vec); derr == nil { + // Accept only when the decoded length matches the index + // dimension. ph.dim is established solely from full-precision + // vectors (the raw path below), so a corrupt blob can never + // poison it, and a wrong-length slice never reaches the SIMD + // distance kernels. When dim is not yet known (d == 0) we fall + // back to raw, which both returns a correct-length vector and + // sets dim from trusted data. + if d := ph.dim.Load(); d != 0 && int(d) == len(*vec) { + return nil + } + } + // fall through to the raw vector on decode failure or dim mismatch. + } + } + data, err := getDataFromKeyWithCacheType(ph.pred, uid, c) if err != nil { if errors.Is(err, errFetchingPostingList) { @@ -376,6 +403,7 @@ func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[ } if data != nil { index.BytesAsFloatArray(data, vec, ph.floatBits) + ph.noteDim(len(*vec)) return nil } else { index.BytesAsFloatArray(emptyVec, vec, ph.floatBits) @@ -383,6 +411,47 @@ func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[ } } +// noteDim records the index's vector dimension the first time a vector is +// materialized. Safe for concurrent use during a multi-goroutine build. +func (ph *persistentHNSW[T]) noteDim(n int) { + if n > 0 { + ph.dim.CompareAndSwap(0, int32(n)) + } +} + +// writeQuantizedVec stores the int8-quantized copy of inVec at vecQKey[uid]. +// No-op unless the index is quantized. It must be called before uid can be read +// as a neighbor by a later insertion; since insertHelper calls it up front for +// the node being inserted, and neighbors are always earlier insertions, the +// blob is always present by the time it is needed. +func (ph *persistentHNSW[T]) writeQuantizedVec( + ctx context.Context, tc *TxnCache, uid uint64, inVec []T) error { + if !ph.quantize || len(inVec) == 0 { + return nil + } + // Fast path: T is already float32 (the only width quantization supports), + // so avoid the per-insert copy. + f32, ok := any(inVec).([]float32) + if !ok { + f32 = make([]float32, len(inVec)) + for i, x := range inVec { + f32[i] = float32(x) + } + } + blob := index.QuantizeFloat32(f32) + if blob == nil { + return nil + } + key := DataKey(ph.vecQKey, uid) + tc.txn.LockKey(key) + defer tc.txn.UnlockKey(key) + return tc.txn.AddMutationWithLockHeld(ctx, key, &index.KeyValue{ + Entity: uid, + Attr: ph.vecQKey, + Value: blob, + }) +} + // chooses whether to create the entry and start nodes based on if it already // exists, and if it hasnt been created yet, it adds the startNode to all // levels. diff --git a/tok/hnsw/persistent_factory.go b/tok/hnsw/persistent_factory.go index b7485c83932..b62c1e19c19 100644 --- a/tok/hnsw/persistent_factory.go +++ b/tok/hnsw/persistent_factory.go @@ -21,6 +21,7 @@ const ( EfConstructionOpt string = "efConstruction" EfSearchOpt string = "efSearch" MetricOpt string = "metric" + QuantizeOpt string = "quantize" Hnsw string = "hnsw" ) @@ -73,8 +74,18 @@ func (hf *persistentIndexFactory[T]) AllowedOptions() opt.AllowedOptions { } return GetSimType[T](optValue, hf.floatBits), nil } - retVal.AddCustomOption(MetricOpt, getSimFunc) + + // quantize is validated at the option layer so a bad value (e.g. "int4") + // is rejected when the schema is altered, not at first index build. + getQuantFunc := func(optValue string) (any, error) { + if optValue != "int8" { + return nil, fmt.Errorf("unsupported %q value %q (only \"int8\" is supported)", + QuantizeOpt, optValue) + } + return optValue, nil + } + retVal.AddCustomOption(QuantizeOpt, getQuantFunc) return retVal } @@ -106,6 +117,7 @@ func (hf *persistentIndexFactory[T]) createWithLock( vecEntryKey: ConcatStrings(name, VecEntry), vecKey: ConcatStrings(name, VecKeyword), vecDead: ConcatStrings(name, VecDead), + vecQKey: ConcatStrings(name, VecQuant), floatBits: floatBits, nodeAllEdges: map[uint64][][]uint64{}, } diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 5658800e579..7d1a79cdfbf 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "time" c "github.com/dgraph-io/dgraph/v25/tok/constraints" @@ -26,8 +27,19 @@ type persistentHNSW[T c.Float] struct { vecEntryKey string vecKey string vecDead string + vecQKey string simType SimilarityType[T] floatBits int + // quantize is true when this index stores an int8-quantized copy of each + // vector in vecQKey and computes distances against the dequantized copy + // (opt-in via the "quantize":"int8" index option). The raw float vectors + // in pred are left untouched. + quantize bool + // dim is the vector dimension, learned lazily from the first materialized + // vector. Used to reject a quantized blob whose dimension disagrees + // (corruption / stale schema) before its wrong-length slice reaches the + // SIMD distance kernels. 0 means "not yet known". + dim atomic.Int32 // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first // layer for UUID 65443. The result will be a neighboring UUID. nodeAllEdges map[uint64][][]uint64 @@ -58,6 +70,10 @@ func GetPersistantOptions[T c.Float](o opt.Options) string { sb.WriteString(fmt.Sprintf(`"%s":"%s",`, MetricOpt, sim.indexType)) } + if val, ok, _ := opt.GetOpt(o, QuantizeOpt, ""); ok && val != "" { + sb.WriteString(fmt.Sprintf(`"%s":"%s",`, QuantizeOpt, val)) + } + final := sb.String() if len(final) > 0 { // Remove last , and cover with brackets @@ -109,6 +125,21 @@ func (ph *persistentHNSW[T]) applyOptions(o opt.Options) error { insortHeap: insortPersistentHeapAscending[T], isBetterScore: isBetterScoreForDistance[T], isSimilarityMetric: false} } + + qval, _, err := opt.GetOpt(o, QuantizeOpt, "") + if err != nil { + return err + } + if qval != "" { + if qval != "int8" { + return fmt.Errorf("unsupported %q value %q (only \"int8\" is supported)", QuantizeOpt, qval) + } + // int8 scalar quantization currently targets 32-bit float vectors. + if ph.floatBits != 32 { + return fmt.Errorf("%q=int8 requires 32-bit float vectors, got %d-bit", QuantizeOpt, ph.floatBits) + } + ph.quantize = true + } return nil } @@ -572,6 +603,12 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inUuid uint64, inVec []T) ([]persistentHeapElement[T], []*index.KeyValue, error) { + // Persist the quantized copy of this node's vector first (no-op unless the + // index is quantized), so later insertions can read it as a neighbor. + if err := ph.writeQuantizedVec(ctx, tc, inUuid, inVec); err != nil { + return []persistentHeapElement[T]{}, []*index.KeyValue{}, err + } + // return all the new edges created at all HNSW levels var startVec []T entry, edges, err := ph.createEntryAndStartNodes(ctx, tc, inUuid, &startVec) diff --git a/tok/hnsw/quantize_integration_test.go b/tok/hnsw/quantize_integration_test.go new file mode 100644 index 00000000000..dd859b47ff2 --- /dev/null +++ b/tok/hnsw/quantize_integration_test.go @@ -0,0 +1,198 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "context" + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/tok/index" + opt "github.com/dgraph-io/dgraph/v25/tok/options" + "github.com/dgraph-io/dgraph/v25/x" +) + +func float32ArrayAsBytes(v []float32) []byte { + b := make([]byte, 4*len(v)) + for i, f := range v { + binary.LittleEndian.PutUint32(b[i*4:], math.Float32bits(f)) + } + return b +} + +// TestQuantizedOptionParsing checks the opt-in plumbing and the float-width guard. +func TestQuantizedOptionParsing(t *testing.T) { + mk := func(bits int, q string) (*persistentHNSW[float32], error) { + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(MetricOpt, GetSimType[float32](Euclidean, bits)) + if q != "" { + options.SetOpt(QuantizeOpt, q) + } + idx, err := CreateFactory[float32](bits).Create( + x.NamespaceAttr(x.RootNamespace, "quant_opt_"+q), options, bits) + if err != nil { + return nil, err + } + return idx.(*persistentHNSW[float32]), nil + } + + ph, err := mk(32, "int8") + require.NoError(t, err) + require.True(t, ph.quantize) + require.Equal(t, ConcatStrings(ph.pred, VecQuant), ph.vecQKey) + + ph, err = mk(32, "") + require.NoError(t, err) + require.False(t, ph.quantize, "quantization must be off by default") + + _, err = mk(32, "int4") + require.Error(t, err, "unsupported quantize value must be rejected") + + // 64-bit float vectors are not supported for int8 quantization. + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + options.SetOpt(QuantizeOpt, "int8") + _, err = CreateFactory[float64](64).Create( + x.NamespaceAttr(x.RootNamespace, "quant_opt_64"), options, 64) + require.Error(t, err, "int8 quantization must require 32-bit vectors") +} + +// TestQuantizedSearchReadPath drives a real SearchWithOptions over a quantized +// index whose vectors live in __vector_q. It mirrors the non-quantized recall +// test (same graph/vectors) and must surface the true nearest neighbor (uid +// 100), proving the read+dequantize path feeds distances correctly. +func TestQuantizedSearchReadPath(t *testing.T) { + ctx := context.Background() + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfSearchOpt, 1) + options.SetOpt(MetricOpt, GetSimType[float32](Euclidean, 32)) + options.SetOpt(QuantizeOpt, "int8") + pred := x.NamespaceAttr(x.RootNamespace, "quant_read_pred") + idx, err := CreateFactory[float32](32).Create(pred, options, 32) + require.NoError(t, err) + ph := idx.(*persistentHNSW[float32]) + require.True(t, ph.quantize) + + vectors := map[uint64][]float32{ + 1: {0, 0, 10, 0}, + 100: {0, 0, 0.1, 0}, + 200: {0, 0, 3, 0}, + 201: {0, 0, 3.2, 0}, + } + data := make(map[string][]byte) + for uid, vec := range vectors { + // Stored as quantized blobs in __vector_q (what the index reads). + data[string(DataKey(ph.vecQKey, uid))] = index.QuantizeFloat32(vec) + } + // Provide the raw vector for the entry node only: the first read seeds the + // index dimension from trusted full-precision data. The neighbors (200, 201, + // 100) have NO raw vector, so search MUST use their quantized blobs — this + // proves the quantized read path drives traversal/distance. + data[string(DataKey(ph.pred, 1))] = float32ArrayAsBytes(vectors[1]) + data[string(DataKey(ph.vecEntryKey, 1))] = Uint64ToBytes(1) + ph.nodeAllEdges[1] = [][]uint64{{}, {200, 201}} + ph.nodeAllEdges[200] = [][]uint64{{1}, {1}} + ph.nodeAllEdges[201] = [][]uint64{{1}, {100}} + ph.nodeAllEdges[100] = [][]uint64{{201}, {201}} + + cache := &memoryCache{data: data} + query := []float32{0, 0, 0.12, 0} + + res, err := ph.SearchWithOptions(ctx, cache, query, 1, + index.VectorIndexOptions[float32]{EfOverride: 4}) + require.NoError(t, err) + require.Equal(t, []uint64{100}, res, "quantized search must find the true nearest neighbor") +} + +// TestQuantizedInsertWritesBlob exercises the write path: a real Insert on a +// quantized index must persist a __vector_q blob that round-trips back to ~the +// input vector. +func TestQuantizedInsertWritesBlob(t *testing.T) { + emptyTsDbs() + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfConstructionOpt, 5) + options.SetOpt(EfSearchOpt, 5) + options.SetOpt(MetricOpt, GetSimType[float32](Euclidean, 32)) + options.SetOpt(QuantizeOpt, "int8") + pred := x.NamespaceAttr(x.RootNamespace, "quant_insert_pred") + idx, err := CreateFactory[float32](32).Create(pred, options, 32) + require.NoError(t, err) + ph := idx.(*persistentHNSW[float32]) + + tc := NewTxnCache(&inMemTxn{startTs: 0, commitTs: 1}, 0) + vec := []float32{1, 2, 3, 4, 5, 6, 7, 8} + _, err = ph.Insert(context.TODO(), tc, 42, vec) + require.NoError(t, err) + + blob := tsDbs[99].inMemTestDb[string(DataKey(ph.vecQKey, 42))] + require.NotEmpty(t, blob, "Insert must persist a __vector_q blob") + require.Equal(t, len(vec), index.QuantizedDim(blob)) + + var got []float32 + require.NoError(t, index.DequantizeInto(blob, &got)) + require.Len(t, got, len(vec)) + for i := range vec { + require.InDelta(t, vec[i], got[i], 0.05, "dim %d", i) + } +} + +// TestQuantizedFallbackOnBadBlob verifies that a corrupt or wrong-dimension +// __vector_q blob does not crash search: getVecFromUid falls back to the raw +// vector, and search still returns the true nearest neighbor. +func TestQuantizedFallbackOnBadBlob(t *testing.T) { + ctx := context.Background() + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfSearchOpt, 1) + options.SetOpt(MetricOpt, GetSimType[float32](Euclidean, 32)) + options.SetOpt(QuantizeOpt, "int8") + pred := x.NamespaceAttr(x.RootNamespace, "quant_fallback_pred") + idx, err := CreateFactory[float32](32).Create(pred, options, 32) + require.NoError(t, err) + ph := idx.(*persistentHNSW[float32]) + + vectors := map[uint64][]float32{ + 1: {0, 0, 10, 0}, + 100: {0, 0, 0.1, 0}, + 200: {0, 0, 3, 0}, + 201: {0, 0, 3.2, 0}, + } + data := make(map[string][]byte) + for uid, vec := range vectors { + data[string(DataKey(ph.pred, uid))] = float32ArrayAsBytes(vec) // raw (fallback) + data[string(DataKey(ph.vecQKey, uid))] = index.QuantizeFloat32(vec) // good quant + } + // Corrupt uid 100's blob (valid header claiming dim 9 but no codes -> length + // mismatch -> decode error -> raw fallback). + bad := []byte{0x71, 1, 1, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0} + data[string(DataKey(ph.vecQKey, 100))] = bad + // Also corrupt the ENTRY node's blob (uid 1, read first): proves ph.dim is + // established from its raw fallback rather than poisoned by a bad first blob. + data[string(DataKey(ph.vecQKey, 1))] = bad + // Wrong-dimension (valid) blob for uid 201: dim 3 != index dim 4 -> dim + // guard rejects it -> raw fallback. + data[string(DataKey(ph.vecQKey, 201))] = index.QuantizeFloat32([]float32{1, 2, 3}) + + data[string(DataKey(ph.vecEntryKey, 1))] = Uint64ToBytes(1) + ph.nodeAllEdges[1] = [][]uint64{{}, {200, 201}} + ph.nodeAllEdges[200] = [][]uint64{{1}, {1}} + ph.nodeAllEdges[201] = [][]uint64{{1}, {100}} + ph.nodeAllEdges[100] = [][]uint64{{201}, {201}} + + cache := &memoryCache{data: data} + query := []float32{0, 0, 0.12, 0} + res, err := ph.SearchWithOptions(ctx, cache, query, 1, + index.VectorIndexOptions[float32]{EfOverride: 4}) + require.NoError(t, err) + require.Equal(t, []uint64{100}, res, "must fall back to raw and find the true NN without crashing") +} diff --git a/tok/index/quantize.go b/tok/index/quantize.go new file mode 100644 index 00000000000..171c0c0eb60 --- /dev/null +++ b/tok/index/quantize.go @@ -0,0 +1,248 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package index + +import ( + "encoding/binary" + "errors" + "math" + + c "github.com/dgraph-io/dgraph/v25/tok/constraints" +) + +// Scalar (uint8) quantization for dense float vectors. +// +// Each vector is quantized independently with an affine (per-vector min/step) +// scheme: every component is mapped to a uint8 code in [0,255]. The encoded +// blob is self-describing and versioned: +// +// [0] magic = 'q' (0x71) +// [1] version (quantVersion1) +// [2] codec (quantCodecAffineU8) +// [3] flags (reserved, 0) +// [4:6) uint16 dim (little-endian) +// [6:10) float32 min (little-endian) +// [10:14) float32 step (= (max-min)/255; 0 for a constant vector) +// [14:14+dim) dim uint8 codes +// +// so the on-disk size is dim+14 bytes vs dim*4 for raw float32 — ~3.8x smaller +// at 384 dims. The codec is stateless (no global training pass), so it is safe +// to apply per vector as they are written, and robust to distribution shifts. +// +// Distance is computed asymmetrically: the query stays full-precision and each +// stored component is dequantized (min + code*step). In dgraph's HNSW the +// integration dequantizes a blob into a reused float32 buffer and reuses the +// existing SIMD distance kernels; the Asym* helpers here are an allocation-free +// alternative kept for tests and a possible future fast path. +// +// NaN/Inf inputs are sanitized at encode time (NaN->0, +/-Inf->+/-MaxFloat32) +// because a NaN leaking into a distance would corrupt HNSW's ordering. + +const ( + quantMagic = 0x71 // 'q' + quantVersion1 = 1 + quantCodecAffineU8 = 1 + quantHeaderSize = 14 // magic+version+codec+flags + u16 dim + f32 min + f32 step +) + +// ErrInvalidQuantBlob is returned when a quantized blob is malformed, +// the wrong version/codec, or its declared dim does not match its length. +var ErrInvalidQuantBlob = errors.New("index: invalid quantized vector blob") + +// QuantizedLen returns the encoded byte length for a vector of dimension dim. +func QuantizedLen(dim int) int { return quantHeaderSize + dim } + +func sanitize(x float32) float32 { + switch { + case x != x: // NaN + return 0 + case math.IsInf(float64(x), 1): + return math.MaxFloat32 + case math.IsInf(float64(x), -1): + return -math.MaxFloat32 + default: + return x + } +} + +// QuantizeFloat32 affinely quantizes v into a self-describing, versioned blob. +// A nil/empty input yields a nil blob (mirrors an empty vector). Dimensions +// above the uint16 range are not supported and yield a nil blob. +func QuantizeFloat32(v []float32) []byte { + if len(v) == 0 || len(v) > math.MaxUint16 { + return nil + } + lo, hi := sanitize(v[0]), sanitize(v[0]) + for _, raw := range v[1:] { + x := sanitize(raw) + if x < lo { + lo = x + } + if x > hi { + hi = x + } + } + // Compute in float64 to avoid overflow when the (sanitized) range is huge. + step := float32((float64(hi) - float64(lo)) / 255.0) + + out := make([]byte, quantHeaderSize+len(v)) + out[0] = quantMagic + out[1] = quantVersion1 + out[2] = quantCodecAffineU8 + out[3] = 0 + binary.LittleEndian.PutUint16(out[4:6], uint16(len(v))) + binary.LittleEndian.PutUint32(out[6:10], math.Float32bits(lo)) + binary.LittleEndian.PutUint32(out[10:14], math.Float32bits(step)) + codes := out[quantHeaderSize:] + if step == 0 { + // Constant vector: all codes 0, dequantize back to lo. + return out + } + inv := 1.0 / step + for i, raw := range v { + c := int32(math.Round(float64((sanitize(raw) - lo) * inv))) + if c < 0 { + c = 0 + } else if c > 255 { + c = 255 + } + codes[i] = byte(c) + } + return out +} + +// quantParams validates a blob and extracts (min, step, codes). +func quantParams(blob []byte) (lo, step float32, codes []byte, err error) { + if len(blob) == 0 { + return 0, 0, nil, nil + } + if len(blob) < quantHeaderSize || + blob[0] != quantMagic || blob[1] != quantVersion1 || blob[2] != quantCodecAffineU8 { + return 0, 0, nil, ErrInvalidQuantBlob + } + dim := int(binary.LittleEndian.Uint16(blob[4:6])) + if len(blob) != quantHeaderSize+dim { + return 0, 0, nil, ErrInvalidQuantBlob + } + lo = math.Float32frombits(binary.LittleEndian.Uint32(blob[6:10])) + step = math.Float32frombits(binary.LittleEndian.Uint32(blob[10:14])) + codes = blob[quantHeaderSize:] + return lo, step, codes, nil +} + +// QuantizedDim returns the vector dimension declared in blob, or 0 if the blob +// is empty or malformed. +func QuantizedDim(blob []byte) int { + if len(blob) < quantHeaderSize || blob[0] != quantMagic { + return 0 + } + return int(binary.LittleEndian.Uint16(blob[4:6])) +} + +// DequantizeFloat32 reconstructs an approximate float32 vector from a blob, +// appending into *out (reusing its capacity when possible). The result slice is +// sliced to exactly the blob's dimension. +func DequantizeFloat32(blob []byte, out *[]float32) error { + lo, step, codes, err := quantParams(blob) + if err != nil { + return err + } + dst := (*out)[:0] + if cap(dst) < len(codes) { + dst = make([]float32, 0, len(codes)) + } + lo64, step64 := float64(lo), float64(step) + for _, c := range codes { + dst = append(dst, float32(lo64+float64(c)*step64)) + } + *out = dst + return nil +} + +// DequantizeInto reconstructs an approximate vector of element type T (float32 +// or float64) from a blob, appending into *out and reusing its capacity. This +// is the hot-path decoder used by HNSW: callers pass their reused per-vector +// buffer so no allocation happens once it is warm. The result is sliced to +// exactly the blob's dimension. +func DequantizeInto[T c.Float](blob []byte, out *[]T) error { + lo, step, codes, err := quantParams(blob) + if err != nil { + return err + } + dst := (*out)[:0] + if cap(dst) < len(codes) { + dst = make([]T, 0, len(codes)) + } + lo64, step64 := float64(lo), float64(step) + for _, code := range codes { + dst = append(dst, T(lo64+float64(code)*step64)) + } + *out = dst + return nil +} + +// AsymSquaredL2Float32 returns the squared L2 distance between a full-precision +// query and a quantized vector, dequantizing each stored component on the fly. +func AsymSquaredL2Float32(query []float32, blob []byte) (float32, error) { + lo, step, codes, err := quantParams(blob) + if err != nil { + return 0, err + } + if len(query) != len(codes) { + return 0, ErrInvalidQuantBlob + } + lo64, step64 := float64(lo), float64(step) + var sum float64 + for i, c := range codes { + d := float64(query[i]) - (lo64 + float64(c)*step64) + sum += d * d + } + return float32(sum), nil +} + +// AsymDotFloat32 returns the dot product of a full-precision query and a +// quantized vector. +func AsymDotFloat32(query []float32, blob []byte) (float32, error) { + lo, step, codes, err := quantParams(blob) + if err != nil { + return 0, err + } + if len(query) != len(codes) { + return 0, ErrInvalidQuantBlob + } + lo64, step64 := float64(lo), float64(step) + var sum float64 + for i, c := range codes { + sum += float64(query[i]) * (lo64 + float64(c)*step64) + } + return float32(sum), nil +} + +// AsymCosineSimilarityFloat32 returns the cosine similarity between a +// full-precision query and a quantized vector. Returns 0 if either side has +// zero magnitude. +func AsymCosineSimilarityFloat32(query []float32, blob []byte) (float32, error) { + lo, step, codes, err := quantParams(blob) + if err != nil { + return 0, err + } + if len(query) != len(codes) { + return 0, ErrInvalidQuantBlob + } + lo64, step64 := float64(lo), float64(step) + var dot, qn, vn float64 + for i, c := range codes { + q := float64(query[i]) + v := lo64 + float64(c)*step64 + dot += q * v + qn += q * q + vn += v * v + } + if qn == 0 || vn == 0 { + return 0, nil + } + return float32(dot / (math.Sqrt(qn) * math.Sqrt(vn))), nil +} diff --git a/tok/index/quantize_test.go b/tok/index/quantize_test.go new file mode 100644 index 00000000000..4186d83164d --- /dev/null +++ b/tok/index/quantize_test.go @@ -0,0 +1,244 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package index + +import ( + "math" + "math/rand" + "sort" + "testing" + + "github.com/stretchr/testify/require" +) + +func randVec(rng *rand.Rand, dim int) []float32 { + v := make([]float32, dim) + for i := range v { + v[i] = float32(rng.NormFloat64()) + } + return v +} + +func exactSqL2(a, b []float32) float64 { + var s float64 + for i := range a { + d := float64(a[i]) - float64(b[i]) + s += d * d + } + return s +} + +func TestQuantizeRoundTrip(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + for _, dim := range []int{1, 3, 8, 128, 384, 768} { + v := randVec(rng, dim) + blob := QuantizeFloat32(v) + require.Equal(t, QuantizedLen(dim), len(blob)) + require.Equal(t, dim, QuantizedDim(blob)) + + var got []float32 + require.NoError(t, DequantizeFloat32(blob, &got)) + require.Equal(t, dim, len(got)) + + // Per-component reconstruction error is bounded by step/2 = range/510. + lo, hi := v[0], v[0] + for _, x := range v { + lo, hi = float32(math.Min(float64(lo), float64(x))), float32(math.Max(float64(hi), float64(x))) + } + tol := (hi-lo)/510.0 + 1e-6 + for i := range v { + require.LessOrEqual(t, math.Abs(float64(v[i]-got[i])), float64(tol), + "dim=%d i=%d v=%f got=%f", dim, i, v[i], got[i]) + } + } +} + +func TestQuantizeEdgeCases(t *testing.T) { + // empty + require.Nil(t, QuantizeFloat32(nil)) + require.Nil(t, QuantizeFloat32([]float32{})) + var out []float32 + require.NoError(t, DequantizeFloat32(nil, &out)) + require.Empty(t, out) + + // constant vector -> step 0, dequant returns the constant + blob := QuantizeFloat32([]float32{2.5, 2.5, 2.5}) + require.NoError(t, DequantizeFloat32(blob, &out)) + require.Equal(t, []float32{2.5, 2.5, 2.5}, out) + + // malformed blob + _, _, _, err := quantParams([]byte{1, 2, 3}) + require.ErrorIs(t, err, ErrInvalidQuantBlob) + _, err = AsymSquaredL2Float32([]float32{1, 2}, blob) // dim mismatch (blob dim 3) + require.ErrorIs(t, err, ErrInvalidQuantBlob) +} + +func TestQuantizeSanitizeAndHeader(t *testing.T) { + // NaN/Inf are sanitized at encode time so they never reach a distance. + v := []float32{float32(math.NaN()), float32(math.Inf(1)), float32(math.Inf(-1)), 1.0} + blob := QuantizeFloat32(v) + require.Equal(t, 4, QuantizedDim(blob)) + var out []float32 + require.NoError(t, DequantizeFloat32(blob, &out)) + for _, x := range out { + require.False(t, math.IsNaN(float64(x)) || math.IsInf(float64(x), 0), "got non-finite %v", x) + } + // A distance may legitimately overflow to +Inf for absurd (~MaxFloat32) + // inputs, but it must never be NaN — NaN would break HNSW heap ordering, + // which is the whole reason we sanitize. + q := []float32{0, 0, 0, 0} + d, err := AsymSquaredL2Float32(q, blob) + require.NoError(t, err) + require.False(t, math.IsNaN(float64(d)), "distance must never be NaN") + + // Header validation: bad magic / version / codec / truncated all rejected. + good := QuantizeFloat32([]float32{1, 2, 3}) + bad := append([]byte{}, good...) + bad[0] = 0x00 // wrong magic + _, _, _, err = quantParams(bad) + require.ErrorIs(t, err, ErrInvalidQuantBlob) + bad = append([]byte{}, good...) + bad[1] = 9 // wrong version + _, _, _, err = quantParams(bad) + require.ErrorIs(t, err, ErrInvalidQuantBlob) + bad = append([]byte{}, good...) + bad[2] = 9 // wrong codec + _, _, _, err = quantParams(bad) + require.ErrorIs(t, err, ErrInvalidQuantBlob) + _, _, _, err = quantParams(good[:len(good)-1]) // length != header+dim + require.ErrorIs(t, err, ErrInvalidQuantBlob) +} + +func TestAsymDistanceApproxMatchesExact(t *testing.T) { + rng := rand.New(rand.NewSource(7)) + const dim = 384 + for trial := 0; trial < 200; trial++ { + q := randVec(rng, dim) + v := randVec(rng, dim) + blob := QuantizeFloat32(v) + + // Squared L2: asymmetric (q exact, v quantized) vs fully-exact. + gotL2, err := AsymSquaredL2Float32(q, blob) + require.NoError(t, err) + wantL2 := exactSqL2(q, v) + // Relative error should be small (quantization only perturbs v). + require.InEpsilon(t, wantL2, float64(gotL2), 0.05, "trial=%d", trial) + + // Dot product. + gotDot, err := AsymDotFloat32(q, blob) + require.NoError(t, err) + var wantDot float64 + for i := range q { + wantDot += float64(q[i]) * float64(v[i]) + } + require.InDelta(t, wantDot, float64(gotDot), 0.02*math.Abs(wantDot)+0.5, "trial=%d", trial) + + // Cosine. + gotCos, err := AsymCosineSimilarityFloat32(q, blob) + require.NoError(t, err) + var dot, qn, vn float64 + for i := range q { + dot += float64(q[i]) * float64(v[i]) + qn += float64(q[i]) * float64(q[i]) + vn += float64(v[i]) * float64(v[i]) + } + wantCos := dot / (math.Sqrt(qn) * math.Sqrt(vn)) + require.InDelta(t, wantCos, float64(gotCos), 0.02, "trial=%d", trial) + } +} + +// TestQuantizationRecall is the key quality test: over a corpus of random +// vectors, the top-k nearest neighbors ranked by asymmetric quantized L2 must +// largely agree with the exact top-k. This is the metric that matters for an +// ANN index using quantized distance. +func TestQuantizationRecall(t *testing.T) { + rng := rand.New(rand.NewSource(42)) + const ( + dim = 384 + corpus = 2000 + queries = 200 + k = 10 + ) + vecs := make([][]float32, corpus) + blobs := make([][]byte, corpus) + for i := range vecs { + vecs[i] = randVec(rng, dim) + blobs[i] = QuantizeFloat32(vecs[i]) + } + + topK := func(score func(i int) float64) []int { + idx := make([]int, corpus) + for i := range idx { + idx[i] = i + } + sort.Slice(idx, func(a, b int) bool { return score(idx[a]) < score(idx[b]) }) + return idx[:k] + } + + var hits, total int + for qi := 0; qi < queries; qi++ { + q := randVec(rng, dim) + exact := topK(func(i int) float64 { return exactSqL2(q, vecs[i]) }) + approx := topK(func(i int) float64 { + d, _ := AsymSquaredL2Float32(q, blobs[i]) + return float64(d) + }) + exactSet := map[int]bool{} + for _, i := range exact { + exactSet[i] = true + } + for _, i := range approx { + if exactSet[i] { + hits++ + } + } + total += k + } + recall := float64(hits) / float64(total) + t.Logf("recall@%d = %.4f (corpus=%d, dim=%d, queries=%d)", k, recall, corpus, dim, queries) + // int8 scalar quantization should preserve recall well above 0.90. + require.Greater(t, recall, 0.90, "recall@%d too low: %.4f", k, recall) +} + +func BenchmarkQuantizeFloat32(b *testing.B) { + rng := rand.New(rand.NewSource(1)) + v := randVec(rng, 384) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = QuantizeFloat32(v) + } +} + +func BenchmarkAsymSquaredL2(b *testing.B) { + rng := rand.New(rand.NewSource(1)) + q := randVec(rng, 384) + blob := QuantizeFloat32(randVec(rng, 384)) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = AsymSquaredL2Float32(q, blob) + } +} + +// BenchmarkExactSquaredL2 is the float32 baseline for comparison. +func BenchmarkExactSquaredL2(b *testing.B) { + rng := rand.New(rand.NewSource(1)) + q := randVec(rng, 384) + v := randVec(rng, 384) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = exactSqL2(q, v) + } +} + +// BenchmarkMemoryFootprint documents the size reduction. +func BenchmarkMemoryFootprint(b *testing.B) { + const dim = 384 + raw := dim * 4 + quant := QuantizedLen(dim) + b.ReportMetric(float64(raw)/float64(quant), "x-smaller") + b.ReportMetric(float64(raw), "raw-bytes") + b.ReportMetric(float64(quant), "quant-bytes") +} diff --git a/worker/backup.go b/worker/backup.go index 31077854dd3..110ef76eb96 100644 --- a/worker/backup.go +++ b/worker/backup.go @@ -351,7 +351,8 @@ func ProcessBackupRequest(ctx context.Context, req *pb.BackupRequest) error { for _, pred := range schema { if pred.Type == "float32vector" && len(pred.IndexSpecs) != 0 { vecPredMap[gid] = append(vecPredMap[gid], pred.Predicate+hnsw.VecEntry, - pred.Predicate+hnsw.VecKeyword, pred.Predicate+hnsw.VecDead) + pred.Predicate+hnsw.VecKeyword, pred.Predicate+hnsw.VecDead, + pred.Predicate+hnsw.VecQuant) } } } diff --git a/worker/online_restore.go b/worker/online_restore.go index 9709f41e6c8..29d1f3a1d1a 100644 --- a/worker/online_restore.go +++ b/worker/online_restore.go @@ -321,7 +321,8 @@ func handleRestoreProposal(ctx context.Context, req *pb.RestoreRequest, pidx uin // still be restored to the correct Alpha's Badger store. if strings.HasSuffix(pred, hnsw.VecEntry) || strings.HasSuffix(pred, hnsw.VecKeyword) || - strings.HasSuffix(pred, hnsw.VecDead) { + strings.HasSuffix(pred, hnsw.VecDead) || + strings.HasSuffix(pred, hnsw.VecQuant) { continue } diff --git a/worker/restore_map.go b/worker/restore_map.go index de17eeb4d45..62230402528 100644 --- a/worker/restore_map.go +++ b/worker/restore_map.go @@ -488,7 +488,7 @@ func (m *mapper) processReqCh(ctx context.Context) error { // If the predicate is a vector indexing predicate, skip further processing. // currently we don't store vector supporting predicates in the schema. if strings.HasSuffix(parsedKey.Attr, hnsw.VecEntry) || strings.HasSuffix(parsedKey.Attr, hnsw.VecKeyword) || - strings.HasSuffix(parsedKey.Attr, hnsw.VecDead) { + strings.HasSuffix(parsedKey.Attr, hnsw.VecDead) || strings.HasSuffix(parsedKey.Attr, hnsw.VecQuant) { return nil } // Reset the StreamId to prevent ordering issues while writing to stream writer.