From 95316e9d6b707ac4c8b8df2fd251ca2c3b750e20 Mon Sep 17 00:00:00 2001 From: GeneAI Date: Thu, 4 Jun 2026 01:17:43 -0400 Subject: [PATCH 1/2] feat(redis): configurable vector datatype with int8 quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a REDISVL_DATATYPE setting so the long-term-memory vector index can use int8 (and other RedisVL datatypes) instead of the hardcoded float32. int8 cuts index memory ~75% and speeds search ~30% with negligible recall loss (Redis 8 Query Engine required for TYPE INT8). - config: new redisvl_datatype setting (default "float32") - factory: _build_redis_schema uses settings.redisvl_datatype, and passes it to RedisVLMemoryVectorDatabase - vector db: encode/query honor the datatype. Float types go through RedisVL's array_to_buffer; int8 is quantized first (per-vector max-abs scaling — RedisVL validates the int8 range but does not quantize). Query vectors are quantized to match and the VectorQuery/RangeQuery dtype is set accordingly. Default behavior is unchanged (float32). Adds 6 tests; full test_memory_vector_db.py suite passes (33). --- agent_memory_server/config.py | 1 + agent_memory_server/memory_vector_db.py | 34 ++++++++++++- .../memory_vector_db_factory.py | 6 ++- tests/test_memory_vector_db.py | 50 +++++++++++++++++++ 4 files changed, 87 insertions(+), 4 deletions(-) diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 95464d01..45402043 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -402,6 +402,7 @@ class Settings(BaseSettings): redisvl_vector_dimensions: str = "1536" redisvl_index_prefix: str = "memory_idx" redisvl_indexing_algorithm: str = "HNSW" + redisvl_datatype: str = "float32" # Working Memory Index Settings # Used for listing sessions via Redis Search instead of sorted sets diff --git a/agent_memory_server/memory_vector_db.py b/agent_memory_server/memory_vector_db.py index f6c5cfdf..cd863848 100644 --- a/agent_memory_server/memory_vector_db.py +++ b/agent_memory_server/memory_vector_db.py @@ -20,6 +20,7 @@ VectorQuery, ) from redisvl.query.filter import FilterExpression +from redisvl.redis.utils import array_to_buffer from redisvl.utils.token_escaper import TokenEscaper from agent_memory_server.filters import ( @@ -398,7 +399,9 @@ class RedisVLMemoryVectorDatabase(MemoryVectorDatabase): "event_date", ] - def __init__(self, index: AsyncSearchIndex, embeddings: Any): + def __init__( + self, index: AsyncSearchIndex, embeddings: Any, datatype: str = "float32" + ): """Initialize the RedisVL memory vector database. Args: @@ -409,6 +412,25 @@ def __init__(self, index: AsyncSearchIndex, embeddings: Any): self._index = index self.embeddings = embeddings self._index_created = False + self._datatype = datatype + + def _maybe_quantize(self, embedding: Any) -> Any: + """Quantize a float embedding to int8 range for an int8 index. + + RedisVL validates the int8 range but does not quantize; float + datatypes pass through unchanged. Per-vector max-abs scaling is + used, which COSINE distance is invariant to. + """ + if self._datatype.lower() != "int8": + return embedding + arr = np.asarray(embedding, dtype=np.float32) + peak = float(np.max(np.abs(arr))) or 1.0 + scaled = np.clip(np.round(arr * (127.0 / peak)), -127, 127) + return scaled.astype(np.int8).tolist() + + def _encode_vector(self, embedding: Any) -> bytes: + """Encode an embedding to bytes for the configured datatype.""" + return array_to_buffer(self._maybe_quantize(embedding), dtype=self._datatype) @property def index(self) -> AsyncSearchIndex: @@ -676,12 +698,14 @@ async def _search_with_recency_aggregation( """ # Embed the query text to vector embedding_vector = await self.embeddings.aembed_query(query) + embedding_vector = self._maybe_quantize(embedding_vector) # Build base KNN or range query if distance_threshold is not None: knn = RangeQuery( vector=embedding_vector, vector_field_name="vector", + dtype=self._datatype, filter_expression=redis_filter, distance_threshold=float(distance_threshold), num_results=limit, @@ -690,6 +714,7 @@ async def _search_with_recency_aggregation( knn = VectorQuery( vector=embedding_vector, vector_field_name="vector", + dtype=self._datatype, filter_expression=redis_filter, num_results=limit, ) @@ -763,7 +788,7 @@ async def add_memories(self, memories: list[MemoryRecord]) -> list[str]: memory_ids = [] for memory, embedding in zip(memories, embeddings, strict=False): data = self._memory_to_data(memory) - data["vector"] = np.array(embedding, dtype=np.float32).tobytes() + data["vector"] = self._encode_vector(embedding) data_list.append(data) memory_ids.append(memory.id) @@ -884,11 +909,13 @@ async def search_memories( break elif search_mode == SearchModeEnum.HYBRID: embedding_vector = await self.embeddings.aembed_query(query) + embedding_vector = self._maybe_quantize(embedding_vector) hybrid_query = PhraseAwareAggregateHybridQuery( text=query, text_field_name="text", vector=embedding_vector, vector_field_name="vector", + dtype=self._datatype, text_scorer=text_scorer, filter_expression=redis_filter, alpha=hybrid_alpha, @@ -926,10 +953,12 @@ async def search_memories( break else: embedding_vector = await self.embeddings.aembed_query(query) + embedding_vector = self._maybe_quantize(embedding_vector) if distance_threshold is not None: vector_query = RangeQuery( vector=embedding_vector, vector_field_name="vector", + dtype=self._datatype, filter_expression=redis_filter, distance_threshold=float(distance_threshold), num_results=limit + offset, @@ -939,6 +968,7 @@ async def search_memories( vector_query = VectorQuery( vector=embedding_vector, vector_field_name="vector", + dtype=self._datatype, filter_expression=redis_filter, num_results=limit + offset, return_fields=self.RETURN_FIELDS, diff --git a/agent_memory_server/memory_vector_db_factory.py b/agent_memory_server/memory_vector_db_factory.py index 8541730e..e2840435 100644 --- a/agent_memory_server/memory_vector_db_factory.py +++ b/agent_memory_server/memory_vector_db_factory.py @@ -166,7 +166,7 @@ def _build_redis_schema() -> dict: "dims": embedding_dimensions, "distance_metric": settings.redisvl_distance_metric.lower(), "algorithm": settings.redisvl_indexing_algorithm.lower(), - "datatype": "float32", + "datatype": settings.redisvl_datatype, }, }, ], @@ -192,7 +192,9 @@ def create_redis_memory_vector_db( schema, redis_url=redis_url_for_redisvl(settings.redis_url), ) - return RedisVLMemoryVectorDatabase(index, embeddings) + return RedisVLMemoryVectorDatabase( + index, embeddings, datatype=settings.redisvl_datatype + ) except Exception as e: logger.error(f"Error creating Redis memory vector database: {e}") raise diff --git a/tests/test_memory_vector_db.py b/tests/test_memory_vector_db.py index 51b95ee9..1d3b0d98 100644 --- a/tests/test_memory_vector_db.py +++ b/tests/test_memory_vector_db.py @@ -738,3 +738,53 @@ def test_create_embeddings_anthropic_raises_error(self): ModelValidationError, match="Anthropic does not provide embedding" ): create_embeddings() + + +class TestConfigurableVectorDatatype: + """Tests for the configurable vector datatype (int8 quantization).""" + + def _db(self, datatype): + return RedisVLMemoryVectorDatabase( + MagicMock(), MockEmbeddings(), datatype=datatype + ) + + def test_default_datatype_is_float32(self): + db = RedisVLMemoryVectorDatabase(MagicMock(), MockEmbeddings()) + assert db._datatype == "float32" + + def test_maybe_quantize_passthrough_for_float(self): + db = self._db("float32") + v = [0.1, -0.2, 0.3] + assert db._maybe_quantize(v) == v + + def test_maybe_quantize_int8_range_and_scaling(self): + db = self._db("int8") + out = db._maybe_quantize([0.5, -1.0, 0.25]) + assert out == [64, -127, 32] + assert all(-127 <= x <= 127 for x in out) + + def test_encode_vector_int8_byte_width(self): + import numpy as np + + db = self._db("int8") + blob = db._encode_vector([0.5, -1.0, 0.25]) + assert len(blob) == 3 # int8 = 1 byte/component + assert list(np.frombuffer(blob, dtype=np.int8)) == [64, -127, 32] + + def test_config_default_datatype_is_float32(self): + from agent_memory_server.config import Settings + + assert Settings().redisvl_datatype == "float32" + + def test_build_schema_uses_configured_datatype(self): + from agent_memory_server.config import settings + from agent_memory_server.memory_vector_db_factory import _build_redis_schema + + original = settings.redisvl_datatype + try: + settings.redisvl_datatype = "int8" + schema = _build_redis_schema() + vec = next(f for f in schema["fields"] if f.get("type") == "vector") + assert vec["attrs"]["datatype"] == "int8" + finally: + settings.redisvl_datatype = original From d0955b8f177d944350a5177fc2b97675c2ab62d6 Mon Sep 17 00:00:00 2001 From: GeneAI Date: Thu, 4 Jun 2026 05:46:17 -0400 Subject: [PATCH 2/2] review: validate datatype, guard int8+cosine, drop list boxing, monkeypatch Addresses Copilot review feedback: - config: field_validator normalizes redisvl_datatype to lowercase and validates against RedisVL's VectorDataType set (rejects e.g. 'float'). - factory: raise a clear ValueError when a quantized datatype (int8/ uint8) is paired with a non-cosine distance metric, since per-vector max-abs scaling changes geometry for L2/IP. - vector db: _maybe_quantize returns an np.int8 array (no .tolist() boxing); array_to_buffer consumes it directly. - tests: use monkeypatch instead of mutating global settings; add tests for the validator (normalize + reject) and the int8/cosine guard. --- agent_memory_server/config.py | 15 ++++++++ agent_memory_server/memory_vector_db.py | 2 +- .../memory_vector_db_factory.py | 9 +++++ tests/test_memory_vector_db.py | 38 ++++++++++++++----- 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 45402043..cdfe8954 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -552,6 +552,21 @@ class Settings(BaseSettings): New summary: """ + @field_validator("redisvl_datatype") + @classmethod + def validate_redisvl_datatype(cls, v: str) -> str: + """Normalize and validate the vector datatype against RedisVL's set.""" + from redisvl.schema.fields import VectorDataType + + try: + VectorDataType(v.upper()) + except ValueError as e: + valid = sorted(t.lower() for t in VectorDataType) + raise ValueError( + f"redisvl_datatype must be one of {valid}, got {v!r}" + ) from e + return v.lower() + @field_validator("progressive_summarization_prompt") @classmethod def validate_progressive_summarization_prompt(cls, v: str) -> str: diff --git a/agent_memory_server/memory_vector_db.py b/agent_memory_server/memory_vector_db.py index cd863848..74016275 100644 --- a/agent_memory_server/memory_vector_db.py +++ b/agent_memory_server/memory_vector_db.py @@ -426,7 +426,7 @@ def _maybe_quantize(self, embedding: Any) -> Any: arr = np.asarray(embedding, dtype=np.float32) peak = float(np.max(np.abs(arr))) or 1.0 scaled = np.clip(np.round(arr * (127.0 / peak)), -127, 127) - return scaled.astype(np.int8).tolist() + return scaled.astype(np.int8) def _encode_vector(self, embedding: Any) -> bytes: """Encode an embedding to bytes for the configured datatype.""" diff --git a/agent_memory_server/memory_vector_db_factory.py b/agent_memory_server/memory_vector_db_factory.py index e2840435..94540ce5 100644 --- a/agent_memory_server/memory_vector_db_factory.py +++ b/agent_memory_server/memory_vector_db_factory.py @@ -134,6 +134,15 @@ def _build_redis_schema() -> dict: """ embedding_dimensions = _get_embedding_dimensions() + datatype = settings.redisvl_datatype.lower() + metric = settings.redisvl_distance_metric.lower() + if datatype in ("int8", "uint8") and metric != "cosine": + raise ValueError( + f"redisvl_datatype={datatype!r} (quantized) requires " + f"redisvl_distance_metric='cosine', got {metric!r}: per-vector " + f"quantization changes geometry for non-cosine metrics." + ) + return { "index": { "name": settings.redisvl_index_name, diff --git a/tests/test_memory_vector_db.py b/tests/test_memory_vector_db.py index 1d3b0d98..bd3c91b0 100644 --- a/tests/test_memory_vector_db.py +++ b/tests/test_memory_vector_db.py @@ -760,7 +760,7 @@ def test_maybe_quantize_passthrough_for_float(self): def test_maybe_quantize_int8_range_and_scaling(self): db = self._db("int8") out = db._maybe_quantize([0.5, -1.0, 0.25]) - assert out == [64, -127, 32] + assert list(out) == [64, -127, 32] assert all(-127 <= x <= 127 for x in out) def test_encode_vector_int8_byte_width(self): @@ -776,15 +776,33 @@ def test_config_default_datatype_is_float32(self): assert Settings().redisvl_datatype == "float32" - def test_build_schema_uses_configured_datatype(self): + def test_build_schema_uses_configured_datatype(self, monkeypatch): from agent_memory_server.config import settings from agent_memory_server.memory_vector_db_factory import _build_redis_schema - original = settings.redisvl_datatype - try: - settings.redisvl_datatype = "int8" - schema = _build_redis_schema() - vec = next(f for f in schema["fields"] if f.get("type") == "vector") - assert vec["attrs"]["datatype"] == "int8" - finally: - settings.redisvl_datatype = original + monkeypatch.setattr(settings, "redisvl_datatype", "int8") + schema = _build_redis_schema() + vec = next(f for f in schema["fields"] if f.get("type") == "vector") + assert vec["attrs"]["datatype"] == "int8" + + def test_datatype_validator_normalizes_case(self): + from agent_memory_server.config import Settings + + assert Settings(redisvl_datatype="INT8").redisvl_datatype == "int8" + + def test_datatype_validator_rejects_invalid(self): + from pydantic import ValidationError + + from agent_memory_server.config import Settings + + with pytest.raises(ValidationError): + Settings(redisvl_datatype="not-a-real-type") + + def test_int8_requires_cosine_distance_metric(self, monkeypatch): + from agent_memory_server.config import settings + from agent_memory_server.memory_vector_db_factory import _build_redis_schema + + monkeypatch.setattr(settings, "redisvl_datatype", "int8") + monkeypatch.setattr(settings, "redisvl_distance_metric", "L2") + with pytest.raises(ValueError, match="cosine"): + _build_redis_schema()