Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions agent_memory_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ class Settings(BaseSettings):
redisvl_vector_dimensions: str = "1536"
redisvl_index_prefix: str = "memory_idx"
redisvl_indexing_algorithm: str = "HNSW"
redisvl_datatype: str = "float32"

# Working Memory Index Settings
# Used for listing sessions via Redis Search instead of sorted sets
Expand Down Expand Up @@ -551,6 +552,21 @@ class Settings(BaseSettings):
New summary:
"""

@field_validator("redisvl_datatype")
@classmethod
def validate_redisvl_datatype(cls, v: str) -> str:
"""Normalize and validate the vector datatype against RedisVL's set."""
from redisvl.schema.fields import VectorDataType

try:
VectorDataType(v.upper())
except ValueError as e:
valid = sorted(t.lower() for t in VectorDataType)
raise ValueError(
f"redisvl_datatype must be one of {valid}, got {v!r}"
) from e
return v.lower()

@field_validator("progressive_summarization_prompt")
@classmethod
def validate_progressive_summarization_prompt(cls, v: str) -> str:
Expand Down
34 changes: 32 additions & 2 deletions agent_memory_server/memory_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
VectorQuery,
)
from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.token_escaper import TokenEscaper

from agent_memory_server.filters import (
Expand Down Expand Up @@ -398,7 +399,9 @@ class RedisVLMemoryVectorDatabase(MemoryVectorDatabase):
"event_date",
]

def __init__(self, index: AsyncSearchIndex, embeddings: Any):
def __init__(
self, index: AsyncSearchIndex, embeddings: Any, datatype: str = "float32"
):
"""Initialize the RedisVL memory vector database.

Args:
Expand All @@ -409,6 +412,25 @@ def __init__(self, index: AsyncSearchIndex, embeddings: Any):
self._index = index
self.embeddings = embeddings
self._index_created = False
self._datatype = datatype

def _maybe_quantize(self, embedding: Any) -> Any:
"""Quantize a float embedding to int8 range for an int8 index.

RedisVL validates the int8 range but does not quantize; float
datatypes pass through unchanged. Per-vector max-abs scaling is
used, which COSINE distance is invariant to.
"""
Comment on lines +418 to +423
if self._datatype.lower() != "int8":
return embedding
arr = np.asarray(embedding, dtype=np.float32)
peak = float(np.max(np.abs(arr))) or 1.0
scaled = np.clip(np.round(arr * (127.0 / peak)), -127, 127)
return scaled.astype(np.int8)

def _encode_vector(self, embedding: Any) -> bytes:
"""Encode an embedding to bytes for the configured datatype."""
return array_to_buffer(self._maybe_quantize(embedding), dtype=self._datatype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint8 config lacks quantization

Medium Severity

redisvl_datatype can be set to uint8 (validated like other RedisVL types), and _build_redis_schema treats uint8 as quantized, but _maybe_quantize only scales for int8. Indexing and search then pass raw float embeddings through array_to_buffer with dtype=uint8, so stored/query vectors won’t match a proper uint8 index.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit d0955b8. Configure here.


@property
def index(self) -> AsyncSearchIndex:
Expand Down Expand Up @@ -676,12 +698,14 @@ async def _search_with_recency_aggregation(
"""
# Embed the query text to vector
embedding_vector = await self.embeddings.aembed_query(query)
embedding_vector = self._maybe_quantize(embedding_vector)

# Build base KNN or range query
if distance_threshold is not None:
knn = RangeQuery(
vector=embedding_vector,
vector_field_name="vector",
dtype=self._datatype,
filter_expression=redis_filter,
distance_threshold=float(distance_threshold),
num_results=limit,
Expand All @@ -690,6 +714,7 @@ async def _search_with_recency_aggregation(
knn = VectorQuery(
vector=embedding_vector,
vector_field_name="vector",
dtype=self._datatype,
filter_expression=redis_filter,
num_results=limit,
)
Expand Down Expand Up @@ -763,7 +788,7 @@ async def add_memories(self, memories: list[MemoryRecord]) -> list[str]:
memory_ids = []
for memory, embedding in zip(memories, embeddings, strict=False):
data = self._memory_to_data(memory)
data["vector"] = np.array(embedding, dtype=np.float32).tobytes()
data["vector"] = self._encode_vector(embedding)
data_list.append(data)
memory_ids.append(memory.id)

Expand Down Expand Up @@ -884,11 +909,13 @@ async def search_memories(
break
elif search_mode == SearchModeEnum.HYBRID:
embedding_vector = await self.embeddings.aembed_query(query)
embedding_vector = self._maybe_quantize(embedding_vector)
hybrid_query = PhraseAwareAggregateHybridQuery(
text=query,
text_field_name="text",
vector=embedding_vector,
vector_field_name="vector",
dtype=self._datatype,
text_scorer=text_scorer,
filter_expression=redis_filter,
alpha=hybrid_alpha,
Expand Down Expand Up @@ -926,10 +953,12 @@ async def search_memories(
break
else:
embedding_vector = await self.embeddings.aembed_query(query)
embedding_vector = self._maybe_quantize(embedding_vector)
if distance_threshold is not None:
vector_query = RangeQuery(
vector=embedding_vector,
vector_field_name="vector",
dtype=self._datatype,
filter_expression=redis_filter,
distance_threshold=float(distance_threshold),
num_results=limit + offset,
Expand All @@ -939,6 +968,7 @@ async def search_memories(
vector_query = VectorQuery(
vector=embedding_vector,
vector_field_name="vector",
dtype=self._datatype,
filter_expression=redis_filter,
num_results=limit + offset,
return_fields=self.RETURN_FIELDS,
Expand Down
15 changes: 13 additions & 2 deletions agent_memory_server/memory_vector_db_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ def _build_redis_schema() -> dict:
"""
embedding_dimensions = _get_embedding_dimensions()

datatype = settings.redisvl_datatype.lower()
metric = settings.redisvl_distance_metric.lower()
if datatype in ("int8", "uint8") and metric != "cosine":
raise ValueError(
f"redisvl_datatype={datatype!r} (quantized) requires "
f"redisvl_distance_metric='cosine', got {metric!r}: per-vector "
f"quantization changes geometry for non-cosine metrics."
)

return {
"index": {
"name": settings.redisvl_index_name,
Expand Down Expand Up @@ -166,7 +175,7 @@ def _build_redis_schema() -> dict:
"dims": embedding_dimensions,
"distance_metric": settings.redisvl_distance_metric.lower(),
"algorithm": settings.redisvl_indexing_algorithm.lower(),
"datatype": "float32",
"datatype": settings.redisvl_datatype,
},
},
],
Expand All @@ -192,7 +201,9 @@ def create_redis_memory_vector_db(
schema,
redis_url=redis_url_for_redisvl(settings.redis_url),
)
return RedisVLMemoryVectorDatabase(index, embeddings)
return RedisVLMemoryVectorDatabase(
index, embeddings, datatype=settings.redisvl_datatype
)
except Exception as e:
logger.error(f"Error creating Redis memory vector database: {e}")
raise
Expand Down
68 changes: 68 additions & 0 deletions tests/test_memory_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,3 +738,71 @@ def test_create_embeddings_anthropic_raises_error(self):
ModelValidationError, match="Anthropic does not provide embedding"
):
create_embeddings()


class TestConfigurableVectorDatatype:
"""Tests for the configurable vector datatype (int8 quantization)."""

def _db(self, datatype):
return RedisVLMemoryVectorDatabase(
MagicMock(), MockEmbeddings(), datatype=datatype
)

def test_default_datatype_is_float32(self):
db = RedisVLMemoryVectorDatabase(MagicMock(), MockEmbeddings())
assert db._datatype == "float32"

def test_maybe_quantize_passthrough_for_float(self):
db = self._db("float32")
v = [0.1, -0.2, 0.3]
assert db._maybe_quantize(v) == v

def test_maybe_quantize_int8_range_and_scaling(self):
db = self._db("int8")
out = db._maybe_quantize([0.5, -1.0, 0.25])
assert list(out) == [64, -127, 32]
assert all(-127 <= x <= 127 for x in out)

def test_encode_vector_int8_byte_width(self):
import numpy as np

db = self._db("int8")
blob = db._encode_vector([0.5, -1.0, 0.25])
assert len(blob) == 3 # int8 = 1 byte/component
assert list(np.frombuffer(blob, dtype=np.int8)) == [64, -127, 32]

def test_config_default_datatype_is_float32(self):
from agent_memory_server.config import Settings

assert Settings().redisvl_datatype == "float32"

def test_build_schema_uses_configured_datatype(self, monkeypatch):
from agent_memory_server.config import settings
from agent_memory_server.memory_vector_db_factory import _build_redis_schema

monkeypatch.setattr(settings, "redisvl_datatype", "int8")
schema = _build_redis_schema()
vec = next(f for f in schema["fields"] if f.get("type") == "vector")
assert vec["attrs"]["datatype"] == "int8"

def test_datatype_validator_normalizes_case(self):
from agent_memory_server.config import Settings

assert Settings(redisvl_datatype="INT8").redisvl_datatype == "int8"

def test_datatype_validator_rejects_invalid(self):
from pydantic import ValidationError

from agent_memory_server.config import Settings

with pytest.raises(ValidationError):
Settings(redisvl_datatype="not-a-real-type")

def test_int8_requires_cosine_distance_metric(self, monkeypatch):
from agent_memory_server.config import settings
from agent_memory_server.memory_vector_db_factory import _build_redis_schema

monkeypatch.setattr(settings, "redisvl_datatype", "int8")
monkeypatch.setattr(settings, "redisvl_distance_metric", "L2")
with pytest.raises(ValueError, match="cosine"):
_build_redis_schema()