diff --git a/config.toml b/config.toml index 2752ff5264..19fd3bf8aa 100644 --- a/config.toml +++ b/config.toml @@ -12,6 +12,23 @@ contentDir = "content/" # ignoreErrors = ["error-remote-getcsv"] +# Stop Hugo's filesystem watcher from opening fds for build artifacts +# and language-package directories that ship inside use-case demos but +# are never referenced from a page as a resource. Without these the +# Rust `target/` and PHP `vendor/` trees alone account for ~8,500 of +# the files under `content/`, which pushes the watcher past macOS's +# default `kern.maxfilesperproc` ceiling on `hugo serve`. +ignoreFiles = [ + "/vendor/", # composer / bundler deps + "/node_modules/", # npm deps + "/target/", # rust + maven build output + "/bin/", # .NET build output + "/obj/", # .NET intermediate output + "/__pycache__/", # Python bytecode cache + "/models/", # Hugot model cache + "/\\.transformers-cache/", # TransformersPHP cache +] + [related] [[related.indices]] name = 'group' diff --git a/content/develop/use-cases/_index.md b/content/develop/use-cases/_index.md index 085389f8eb..c5f31764a7 100644 --- a/content/develop/use-cases/_index.md +++ b/content/develop/use-cases/_index.md @@ -29,3 +29,4 @@ This section provides practical examples and reference implementations for commo * [Recommendation engine]({{< relref "/develop/use-cases/recommendation-engine" >}}) - Serve personalized recommendations under tight latency budgets by combining vector similarity with structured filters in a single Redis call * [Feature store]({{< relref "/develop/use-cases/feature-store" >}}) - Serve pre-computed ML features on the request path with mixed batch-and-streaming freshness using per-field TTL * [Semantic cache]({{< relref "/develop/use-cases/semantic-cache" >}}) - Reuse LLM responses for semantically similar queries to cut token costs and skip multi-second model calls on near-duplicate prompts +* [Agent memory]({{< relref "/develop/use-cases/agent-memory" >}}) - Give AI agents persistent memory that spans sessions and tasks — working memory per thread, long-term semantic recall, and a time-ordered event log on one Redis instance diff --git a/content/develop/use-cases/agent-memory/_index.md b/content/develop/use-cases/agent-memory/_index.md new file mode 100644 index 0000000000..e793f1d4ea --- /dev/null +++ b/content/develop/use-cases/agent-memory/_index.md @@ -0,0 +1,77 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Give AI agents persistent memory that spans sessions and tasks — working memory per thread, long-term semantic recall, and a time-ordered event log — on a single Redis instance, with sub-millisecond reads on the agent loop's hot path. +hideListLinks: true +linkTitle: Agent memory +title: Redis as agent memory +weight: 8 +--- + +## When to use Redis as agent memory + +Use Redis as the memory layer for an AI agent when each reasoning step needs to recall both *what just happened in this session* and *what the agent has learned over time* under a strict per-step latency budget — without standing up a separate vector database, message broker, and session store for each tier. + +## Why the problem is hard + +LLMs are stateless. Every API call starts from zero unless the application supplies the relevant context. Without a memory layer, agents re-derive information through extra LLM calls, lose personalization between sessions, and cannot coordinate state in multi-agent deployments. Some of the obvious workarounds have real drawbacks: + +- **A standalone vector database** can index long-term semantic memories, but doesn't cover working session state or an ordered action log, and putting a separate service on the agent's hot path adds latency that compounds across multi-step reasoning loops. +- **In-process or app-server session storage** keeps working memory close to the agent, but disappears on process restart and can't be shared across multi-agent or load-balanced deployments — exactly the topology most production agents end up in. +- **Stuffing everything into the LLM context window** shifts the cost of memory onto every API call, hits the model's context limit on long-running sessions, and reliably degrades reasoning quality as the context grows. + +The core difficulty is that an agent needs *several kinds* of memory at once — short-lived working state per thread, durable semantic recall by meaning, and an audit trail of recent actions — each with its own retention rule and access pattern. Mapping all three onto a single primitive (only a vector index, only a key-value store, only an append log) forces compromises that show up as either lost context or extra LLM calls. Memory must also stay bounded; without deduplication, summarization, and background consolidation, stale context piles up and degrades downstream accuracy. + +This pattern is distinct from generic [session storage]({{< relref "/develop/use-cases/session-store" >}}) (spans a single user session, no semantic recall), from [semantic caching]({{< relref "/develop/use-cases/semantic-cache" >}}) (deduplicates LLM calls, not accumulated agent knowledge), and from RAG retrieval against an external document corpus (static reference material, not the agent's own experience). + +## What you can expect from a Redis solution + +You can: + +- Persist and resume agent sessions by thread ID across restarts and across load-balanced workers. +- Recall long-term memories by semantic similarity instead of exact key, scoped per user, namespace, or memory kind. +- Prevent memory bloat by deduplicating near-identical memories at write time with the same vector index that powers recall. +- Run semantic caching, RAG retrieval, and agent memory together on a single Redis deployment, sharing the same vector index infrastructure. +- Keep each step in the agent reasoning loop under budget — Redis reads and writes are sub-millisecond, so the memory layer doesn't dominate per-step latency. + +## How Redis supports the solution + +In practice, each tier of agent memory maps onto a Redis primitive that's already in the cluster. **Working memory** for an active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at a deterministic key such as `agent:session:{thread_id}`, holding the running scratchpad, current goal, and recent turns — written with [`HSET`]({{< relref "/commands/hset" >}}) and read in one round trip with [`HGETALL`]({{< relref "/commands/hgetall" >}}). **Long-term memory** — both episodic ("what happened in past sessions") and semantic ("what the agent has learned about this user or domain") — lives as [JSON]({{< relref "/develop/data-types/json" >}}) documents that carry an embedding vector, indexed by [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) on a [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) together with tag fields (user, namespace, kind, source thread). The agent recalls memories with one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call that combines vector similarity with metadata filtering, and the same similarity check runs at write time to deduplicate near-identical memories before they enter the store. **A time-ordered event log** of the agent's recent actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) appended with [`XADD`]({{< relref "/commands/xadd" >}}), replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}), and bounded with [`XTRIM`]({{< relref "/commands/xtrim" >}}). + +Redis provides the following features that make it a good fit for agent memory: + +- [Hashes]({{< relref "/develop/data-types/hashes" >}}) hold per-session working memory under one key, so loading or persisting a thread's state takes a single round trip. +- [JSON]({{< relref "/develop/data-types/json" >}}) documents store each long-term memory together with its embedding vector and metadata, so a similarity search returns everything the agent needs without a second lookup. +- [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) with [HNSW vector indexes]({{< relref "/develop/ai/search-and-query/vectors" >}}) recalls memories by meaning in sub-millisecond time, and the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call applies TAG and NUMERIC filters so user, namespace, and kind scoping happen inside the query rather than in application code. +- [Streams]({{< relref "/develop/data-types/streams" >}}) keep an ordered log of agent actions and observations, [`XTRIM`]({{< relref "/commands/xtrim" >}}) bounds retention without manual cleanup, and consumer groups let downstream workers — summarizers, consolidators — replay the log without losing position. +- [`EXPIRE`]({{< relref "/commands/expire" >}}) automates memory decay per tier — short TTLs on working memory, longer on episodic long-term memories, no TTL on semantic ones — so stale context falls off without a separate cleanup job. (The event log is bounded separately, by [`XADD MAXLEN`]({{< relref "/commands/xadd" >}}) on the Stream, not by `EXPIRE`.) +- Sub-millisecond reads and writes from memory keep each turn of the agent loop under budget, and a single Redis instance can carry working memory, long-term recall, the event log, semantic caching, and RAG retrieval at zero marginal infrastructure cost. + +## Ecosystem + +The following libraries, frameworks, and managed services build on Redis for agent memory: + +- **Python**: [RedisVL]({{< relref "/develop/ai/redisvl" >}}) provides vector-index, session-manager, and semantic-memory helpers you can compose into an agent memory layer. +- **Frameworks**: [LangChain]({{< relref "/integrate/langchain-redis" >}}) supports Redis as a chat history and memory backend, and [LangGraph & Redis](https://redis.io/blog/langgraph-redis-build-smarter-ai-agents-with-memory-persistence/) ships a Redis checkpointer for persisting graph state across runs. +- **AWS**: [Amazon Bedrock]({{< relref "/integrate/amazon-bedrock" >}}) agent runtimes integrate with Redis for memory persistence and vector search. +- **Any language**: standard Redis client libraries cover the pattern below for custom agent loops. +- **Managed**: [Redis Agent Memory Server]({{< relref "/develop/ai/context-engine/agent-memory" >}}) is a managed agent memory service with REST and MCP interfaces, working and long-term memory tiers, deduplication, summarization, and background consolidation — useful when you'd rather not build and operate the pattern below yourself. + +## Code examples to build your own Redis agent memory + +The following guides show how to build a small Redis-backed agent memory layer using only standard Redis commands — working memory in a hash per thread, long-term memory as JSON documents with a vector index, an event log in a stream, and per-tier TTLs for decay. Each guide includes a runnable interactive demo where you can send turns, watch working memory update, see semantic recall against past memories, and inspect the event log. + +* [redis-py (Python)]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) +* [node-redis (Node.js)]({{< relref "/develop/use-cases/agent-memory/nodejs" >}}) +* [NRedisStack (C#)]({{< relref "/develop/use-cases/agent-memory/dotnet" >}}) +* [redis-rs (Rust)]({{< relref "/develop/use-cases/agent-memory/rust" >}}) +* [go-redis (Go)]({{< relref "/develop/use-cases/agent-memory/go" >}}) +* [Jedis (Java)]({{< relref "/develop/use-cases/agent-memory/java-jedis" >}}) +* [Lettuce (Java)]({{< relref "/develop/use-cases/agent-memory/java-lettuce" >}}) +* [Predis (PHP)]({{< relref "/develop/use-cases/agent-memory/php" >}}) +* [redis-rb (Ruby)]({{< relref "/develop/use-cases/agent-memory/ruby" >}}) diff --git a/content/develop/use-cases/agent-memory/dotnet/.gitignore b/content/develop/use-cases/agent-memory/dotnet/.gitignore new file mode 100644 index 0000000000..8f5917312f --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/.gitignore @@ -0,0 +1,7 @@ +bin/ +obj/ +model_cache/ +*.user +*.suo +.vs/ +.idea/ diff --git a/content/develop/use-cases/agent-memory/dotnet/AgentEventLog.cs b/content/develop/use-cases/agent-memory/dotnet/AgentEventLog.cs new file mode 100644 index 0000000000..2826deef40 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/AgentEventLog.cs @@ -0,0 +1,119 @@ +using System.Globalization; +using StackExchange.Redis; + +namespace AgentMemoryDemo; + +/// +/// Append-only event log for an agent thread, backed by a Redis +/// Stream. +/// +/// +/// Each thread gets a stream at agent:events:{threadId}. +/// Every action the agent takes (a user turn arriving, a memory being +/// recalled, a memory being written, a tool being called) is one +/// XADD to that stream. Replay with XREVRANGE for the +/// most recent N events; bound retention with XTRIM MAXLEN ~ +/// so the log stays cheap regardless of how long the thread has been +/// running. +/// +/// The stream is independent of the session hash and the +/// long-term memory store: it answers the "what just happened" +/// question without competing with either of those for indexing or +/// memory budget. Consumer groups (not used in this demo) would let +/// downstream workers — summarisers, consolidators, audit pipelines — +/// replay the log without losing position. +/// +public sealed class AgentEventLog +{ + /// + /// Approximate cap on stream length. MAXLEN ~ lets Redis + /// trim in whole-node units instead of exactly-N units, which is + /// much cheaper at the cost of overshooting the bound by up to a + /// node's worth. + /// + public const int DefaultMaxLen = 1000; + + private readonly IDatabase _db; + public string KeyPrefix { get; } + public int MaxLen { get; } + + public AgentEventLog( + IDatabase db, + string keyPrefix = "agent:events:", + int maxLen = DefaultMaxLen) + { + _db = db; + KeyPrefix = keyPrefix; + MaxLen = maxLen; + } + + public string StreamKey(string threadId) => KeyPrefix + threadId; + + /// + /// Append one event and return its stream id. + /// + /// + /// MAXLEN ~ N keeps the stream bounded with near-zero + /// overhead; an exact bound (MAXLEN N without the tilde) + /// forces a scan and is rarely worth the cost. + /// + public string Record(string threadId, string action, string detail = "") + { + var fields = new NameValueEntry[] + { + new("action", action), + new("detail", detail), + new("ts", UnixSeconds().ToString("F6", CultureInfo.InvariantCulture)), + }; + // StreamAdd's `useApproximateMaxLength: true` issues + // `MAXLEN ~ N` rather than the exact form. + RedisValue id = _db.StreamAdd( + StreamKey(threadId), + fields, + messageId: null, + maxLength: MaxLen, + useApproximateMaxLength: true); + return (string)id!; + } + + /// Return the most recent events, newest first. + /// + /// StackExchange.Redis swaps the minId / maxId + /// arguments when it issues XREVRANGE under + /// , so the caller still passes + /// "low, high" in natural order (- / +). Passing + /// them the other way around — + / - — would issue + /// XREVRANGE key - +, which Redis interprets as an empty + /// range and returns nothing. + /// + public List Recent(string threadId, int count = 20) + { + var entries = _db.StreamRange( + StreamKey(threadId), "-", "+", count: count, messageOrder: Order.Descending); + var out_ = new List(entries.Length); + foreach (var entry in entries) + { + var fields = entry.Values.ToDictionary(v => (string)v.Name!, v => (string)v.Value!); + out_.Add(new AgentEvent( + EventId: (string)entry.Id!, + ThreadId: threadId, + Action: fields.GetValueOrDefault("action") ?? "", + Detail: fields.GetValueOrDefault("detail") ?? "", + Ts: ParseDouble(fields.GetValueOrDefault("ts"), 0))); + } + return out_; + } + + /// Current stream length. + public long Length(string threadId) => _db.StreamLength(StreamKey(threadId)); + + /// Drop the entire stream for a thread. + public bool Clear(string threadId) => _db.KeyDelete(StreamKey(threadId)); + + private static double UnixSeconds() + => DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000.0; + + private static double ParseDouble(string? value, double fallback) + => double.TryParse(value, NumberStyles.Float, CultureInfo.InvariantCulture, out var d) + ? d : fallback; +} diff --git a/content/develop/use-cases/agent-memory/dotnet/AgentMemoryDemo.csproj b/content/develop/use-cases/agent-memory/dotnet/AgentMemoryDemo.csproj new file mode 100644 index 0000000000..74d7378315 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/AgentMemoryDemo.csproj @@ -0,0 +1,40 @@ + + + + Exe + net8.0 + AgentMemoryDemo + AgentMemoryDemo + enable + enable + latest + false + + + + + + + + + + + + + + + + + PreserveNewest + + + + diff --git a/content/develop/use-cases/agent-memory/dotnet/AgentSession.cs b/content/develop/use-cases/agent-memory/dotnet/AgentSession.cs new file mode 100644 index 0000000000..1b5bd6f595 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/AgentSession.cs @@ -0,0 +1,305 @@ +using System.Globalization; +using System.Text.Json; +using StackExchange.Redis; + +namespace AgentMemoryDemo; + +/// +/// Working-memory store for an agent session, backed by a Redis Hash. +/// +/// +/// Each session is one Hash document at +/// agent:session:{threadId}. The hash holds the running +/// scratchpad, the current goal, a rolling window of recent turns +/// (serialised as a JSON list to fit in one field), and a few audit +/// fields. One HGETALL returns the whole session in a single +/// round trip on every step of the agent loop. +/// +/// Every write refreshes the key's TTL with EXPIRE, so +/// idle sessions fall off without a separate cleanup job and active +/// sessions stay alive as long as the agent keeps touching them. A +/// separate is what survives beyond a +/// session's TTL. +/// +/// The turn window is bounded to in +/// application code; the hash itself doesn't grow, so the working set +/// per thread stays constant regardless of how long the agent has +/// been running. +/// +public sealed class AgentSession +{ + // How many recent turns to keep inline on the session hash. Older + // turns flow through the event log (AgentEventLog) and the + // long-term memory store (LongTermMemory). + public const int DefaultMaxTurns = 20; + + private readonly IDatabase _db; + public string KeyPrefix { get; } + public long DefaultTtlSeconds { get; } + public int MaxTurns { get; } + + public AgentSession( + IDatabase db, + string keyPrefix = "agent:session:", + long defaultTtlSeconds = 3600, + int maxTurns = DefaultMaxTurns) + { + _db = db; + KeyPrefix = keyPrefix; + DefaultTtlSeconds = defaultTtlSeconds; + MaxTurns = maxTurns; + } + + public string SessionKey(string threadId) => KeyPrefix + threadId; + + public string NewThreadId() => Guid.NewGuid().ToString("N").Substring(0, 12); + + /// + /// Create a fresh working memory for a thread. Overwrites any + /// existing session at the same key. The agent normally calls + /// this once per thread at the first turn and relies on + /// / for subsequent + /// steps. + /// + public SessionState Start( + string threadId, + string user = "default", + string agentName = "default", + string goal = "", + long? ttlSeconds = null) + { + long ttl = ttlSeconds ?? DefaultTtlSeconds; + double now = UnixSeconds(); + var state = new SessionState( + ThreadId: threadId, + User: user, + Agent: agentName, + Goal: goal, + Scratchpad: "", + TurnCount: 0, + CreatedTs: now, + LastActiveTs: now, + RecentTurns: Array.Empty(), + TtlSeconds: ttl); + Write(state, ttl); + return state; + } + + /// + /// Return the session state, or null if it has expired. + /// + public SessionState? Load(string threadId) + { + string key = SessionKey(threadId); + var raw = _db.HashGetAll(key); + if (raw is null || raw.Length == 0) return null; + var fields = raw.ToDictionary(e => (string)e.Name!, e => (string)e.Value!); + TimeSpan? ttl = _db.KeyTimeToLive(key); + long ttlSeconds = ttl is { TotalSeconds: > 0 } v ? (long)v.TotalSeconds : 0L; + string turnsBlob = fields.GetValueOrDefault("recent_turns") ?? "[]"; + var turns = TryDeserializeTurns(turnsBlob); + return new SessionState( + ThreadId: threadId, + User: fields.GetValueOrDefault("user") ?? "default", + Agent: fields.GetValueOrDefault("agent") ?? "default", + Goal: fields.GetValueOrDefault("goal") ?? "", + Scratchpad: fields.GetValueOrDefault("scratchpad") ?? "", + TurnCount: ParseLong(fields.GetValueOrDefault("turn_count"), 0), + CreatedTs: ParseDouble(fields.GetValueOrDefault("created_ts"), 0), + LastActiveTs: ParseDouble(fields.GetValueOrDefault("last_active_ts"), 0), + RecentTurns: turns, + TtlSeconds: ttlSeconds); + } + + /// + /// Append a turn, bound the rolling window, refresh the TTL. + /// + /// + /// and + /// are only consulted when the session does not yet exist — they + /// seed the auto-created session so the working-memory hash + /// matches the user the caller is operating against. On an + /// existing session they're ignored; the original Start + /// values stand. + /// + /// Read-modify-write here is last-writer-wins on the turn + /// list if two concurrent turns reach the same thread; the demo + /// never triggers that race in practice (one browser, one turn at + /// a time) but a multi-worker agent that shares a thread id would + /// wrap this in WATCH / MULTI / EXEC or a + /// Lua script that does the append atomically server-side. + /// + public SessionState AppendTurn( + string threadId, + string role, + string content, + string? user = null, + string? agentName = null, + long? ttlSeconds = null) + { + var state = Load(threadId) + ?? Start( + threadId, + user: user ?? "default", + agentName: agentName ?? "default", + ttlSeconds: ttlSeconds); + + var newTurns = state.RecentTurns.ToList(); + newTurns.Add(new SessionTurn(Role: role, Content: content, Ts: UnixSeconds())); + if (newTurns.Count > MaxTurns) + { + newTurns = newTurns.GetRange(newTurns.Count - MaxTurns, MaxTurns); + } + + long ttl = ttlSeconds ?? DefaultTtlSeconds; + var next = state with + { + TurnCount = state.TurnCount + 1, + LastActiveTs = UnixSeconds(), + RecentTurns = newTurns, + TtlSeconds = ttl, + }; + Write(next, ttl); + return next; + } + + /// + /// Update the agent's running scratchpad and refresh the TTL. + /// Returns null when the session does not exist. + /// + public SessionState? SetScratchpad(string threadId, string text, long? ttlSeconds = null) + { + var state = Load(threadId); + if (state is null) return null; + long ttl = ttlSeconds ?? DefaultTtlSeconds; + var next = state with + { + Scratchpad = text, + LastActiveTs = UnixSeconds(), + TtlSeconds = ttl, + }; + Write(next, ttl); + return next; + } + + /// + /// Update the goal field without touching turns or the scratchpad. + /// Creates the session if it doesn't exist yet — setting a goal + /// on a fresh thread is a sensible first step in the agent loop, + /// so this method covers both the "rename the goal mid-session" + /// and the "start a thread with this goal" cases. + /// + public SessionState SetGoal( + string threadId, + string text, + string? user = null, + string? agentName = null, + long? ttlSeconds = null) + { + var state = Load(threadId); + if (state is null) + { + return Start( + threadId, + user: user ?? "default", + agentName: agentName ?? "default", + goal: text, + ttlSeconds: ttlSeconds); + } + long ttl = ttlSeconds ?? DefaultTtlSeconds; + var next = state with + { + Goal = text, + LastActiveTs = UnixSeconds(), + TtlSeconds = ttl, + }; + Write(next, ttl); + return next; + } + + /// Drop the session immediately. Returns true if it existed. + public bool Delete(string threadId) => _db.KeyDelete(SessionKey(threadId)); + + /// Return active thread ids (for the demo's thread switcher). + public List ListThreads(int limit = 100) + { + var out_ = new List(); + // SCAN via the server-set option; this stays incremental even + // on a database with many session keys. + var server = _db.Multiplexer.GetServer(_db.Multiplexer.GetEndPoints().First()); + foreach (var key in server.Keys(database: _db.Database, pattern: KeyPrefix + "*", pageSize: 200)) + { + string raw = (string)key!; + out_.Add(raw.StartsWith(KeyPrefix) ? raw.Substring(KeyPrefix.Length) : raw); + if (out_.Count >= limit) break; + } + return out_; + } + + private void Write(SessionState state, long ttl) + { + string key = SessionKey(state.ThreadId); + var entries = new HashEntry[] + { + new("thread_id", state.ThreadId), + new("user", state.User), + new("agent", state.Agent), + new("goal", state.Goal), + new("scratchpad", state.Scratchpad), + new("turn_count", state.TurnCount.ToString(CultureInfo.InvariantCulture)), + new("created_ts", state.CreatedTs.ToString("F6", CultureInfo.InvariantCulture)), + new("last_active_ts", state.LastActiveTs.ToString("F6", CultureInfo.InvariantCulture)), + new("recent_turns", JsonSerializer.Serialize(state.RecentTurns.Select(t => new + { + role = t.Role, + content = t.Content, + ts = t.Ts, + }))), + }; + + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would + // otherwise leave the session without a TTL. We check the + // return value of Execute() — there's no WATCH on this + // transaction so a false here means the server rejected the + // batch (out of memory, OOM script kill, etc.); surface it + // rather than letting the in-memory state drift from Redis. + var tx = _db.CreateTransaction(); + _ = tx.HashSetAsync(key, entries); + _ = tx.KeyExpireAsync(key, TimeSpan.FromSeconds(ttl)); + if (!tx.Execute()) + { + throw new RedisServerException("session write MULTI/EXEC was discarded"); + } + } + + private static IReadOnlyList TryDeserializeTurns(string blob) + { + try + { + using var doc = JsonDocument.Parse(blob); + var result = new List(doc.RootElement.GetArrayLength()); + foreach (var el in doc.RootElement.EnumerateArray()) + { + string role = el.TryGetProperty("role", out var r) ? r.GetString() ?? "" : ""; + string content = el.TryGetProperty("content", out var c) ? c.GetString() ?? "" : ""; + double ts = el.TryGetProperty("ts", out var t) && t.TryGetDouble(out var d) ? d : 0.0; + result.Add(new SessionTurn(role, content, ts)); + } + return result; + } + catch (JsonException) + { + return Array.Empty(); + } + } + + private static double UnixSeconds() + => DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000.0; + + private static double ParseDouble(string? value, double fallback) + => double.TryParse(value, NumberStyles.Float, CultureInfo.InvariantCulture, out var d) ? d : fallback; + + private static long ParseLong(string? value, long fallback) + => long.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out var l) ? l : fallback; +} diff --git a/content/develop/use-cases/agent-memory/dotnet/LocalEmbedder.cs b/content/develop/use-cases/agent-memory/dotnet/LocalEmbedder.cs new file mode 100644 index 0000000000..9778bccdb3 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/LocalEmbedder.cs @@ -0,0 +1,276 @@ +using System.Buffers.Binary; +using System.Net.Http; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; +using Microsoft.ML.Tokenizers; + +namespace AgentMemoryDemo; + +/// +/// Local text-embedding helper backed by ONNX Runtime + a Bert +/// WordPiece tokenizer. +/// +/// +/// This is a thin wrapper around the +/// sentence-transformers/all-MiniLM-L6-v2 model loaded as an +/// ONNX export from the Xenova/all-MiniLM-L6-v2 Hugging Face +/// mirror: a 384-dimensional encoder that runs in-process on CPU +/// through ONNX Runtime, needs no API key, and produces vectors +/// numerically very close to the equivalent Python and Node ports +/// (close enough that paraphrase distances differ only at the second +/// or third decimal place). +/// +/// The class downloads model.onnx and the +/// vocab.txt WordPiece dictionary into a local cache directory +/// on the first call; every later run is offline. Vectors are mean- +/// pooled over the token positions (weighted by the attention mask) +/// and then L2-normalised explicitly so a Redis Search index declared +/// with DISTANCE_METRIC COSINE returns scores that are +/// directly comparable across entries. +/// +public sealed class LocalEmbedder : IDisposable +{ + public const string DefaultModelName = "sentence-transformers/all-MiniLM-L6-v2"; + public const int DefaultVectorDim = 384; + + // The Xenova mirror is the Node demo's source; the ONNX export + // and vocab there match the original sentence-transformers + // checkpoint and give us a single dependency-free download URL. + private const string ModelUrl = + "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx"; + private const string VocabUrl = + "https://huggingface.co/Xenova/all-MiniLM-L6-v2/resolve/main/vocab.txt"; + + private readonly InferenceSession _session; + private readonly BertTokenizer _tokenizer; + + public string ModelName { get; } + public int Dim { get; } + + private LocalEmbedder( + string modelName, + InferenceSession session, + BertTokenizer tokenizer, + int dim) + { + ModelName = modelName; + _session = session; + _tokenizer = tokenizer; + Dim = dim; + } + + /// + /// Load the default model. Blocks while ONNX Runtime initialises + /// and the model + tokenizer files are downloaded on the first + /// run. The single is shared + /// across handler threads — ONNX Runtime documents + /// InferenceSession.Run as thread-safe. + /// + /// + /// Directory the model and tokenizer files are cached in. Created + /// if it doesn't exist. Defaults to ./model_cache next to + /// the running binary, so a fresh checkout doesn't re-download on + /// every dotnet run. + /// + public static async Task CreateAsync(string? cacheDir = null) + { + cacheDir ??= Path.Combine(AppContext.BaseDirectory, "model_cache"); + Directory.CreateDirectory(cacheDir); + + string modelPath = Path.Combine(cacheDir, "model.onnx"); + string vocabPath = Path.Combine(cacheDir, "vocab.txt"); + + await DownloadIfMissingAsync(ModelUrl, modelPath); + await DownloadIfMissingAsync(VocabUrl, vocabPath); + + // The Xenova / sentence-transformers MiniLM tokenizer config + // says lower_case=true, do_basic_tokenize=true, + // tokenize_chinese_chars=true; surface those flags here so + // the tokens match the ones produced by the Python / + // Node.js sibling demos. + var options = new BertOptions + { + LowerCaseBeforeTokenization = true, + ApplyBasicTokenization = true, + IndividuallyTokenizeCjk = true, + }; + var tokenizer = BertTokenizer.Create(vocabPath, options); + + // One session per process; ONNX Runtime explicitly documents + // it as thread-safe for inference, so we can share it across + // every HttpListener handler thread without further + // synchronisation. + var session = new InferenceSession(modelPath); + + // Probe the output shape once so we fail loudly if a different + // model is ever wired up against the 384-dim Redis Search + // field. + var probe = EncodeInternal(session, tokenizer, "dimension probe"); + return new LocalEmbedder(DefaultModelName, session, tokenizer, probe.Length); + } + + private static async Task DownloadIfMissingAsync(string url, string path) + { + if (File.Exists(path)) return; + Console.WriteLine($"Downloading {url}"); + using var http = new HttpClient + { + Timeout = TimeSpan.FromMinutes(5), + }; + using var stream = await http.GetStreamAsync(url); + // Write to a temp path and rename so a Ctrl-C during the + // download doesn't leave a half-written file the next run + // would happily skip. + string tmp = path + ".part"; + using (var file = File.Create(tmp)) + { + await stream.CopyToAsync(file); + } + File.Move(tmp, path, overwrite: true); + } + + /// + /// Encode a single string. Returns a float[] of length + /// . + /// + public float[] EncodeOne(string text) => EncodeInternal(_session, _tokenizer, text); + + /// + /// Encode several strings sequentially and return one vector per + /// input. Throws when the underlying session returns a different + /// number of vectors than inputs. + /// + public List EncodeMany(IReadOnlyList texts) + { + var results = new List(texts.Count); + foreach (var text in texts) + { + results.Add(EncodeInternal(_session, _tokenizer, text)); + } + if (results.Count != texts.Count) + { + // Belt-and-braces. The loop above guarantees one vector + // per input on the happy path, but surfacing this as an + // explicit check matches the contract the seed loader + // relies on and avoids an index-out-of-range later if a + // future refactor batches into a single Run() call. + throw new InvalidOperationException( + $"embedder produced {results.Count} vectors for {texts.Count} inputs"); + } + return results; + } + + private static float[] EncodeInternal( + InferenceSession session, BertTokenizer tokenizer, string text) + { + // BertTokenizer.EncodeToIds adds the [CLS] / [SEP] sentinels + // that the MiniLM ONNX export expects. considerPreTokenization + // splits on whitespace + punctuation before WordPiece, which + // matches the do_basic_tokenize=true in the upstream + // tokenizer config. + var ids = tokenizer + .EncodeToIds(text, addSpecialTokens: true, considerPreTokenization: true) + .ToArray(); + int seqLen = ids.Length; + // Empty strings still need at least [CLS] [SEP] so the model + // has something to attend to. EncodeToIds gives us that for + // the empty string already; the guard above is just defensive. + + var idsLong = new long[seqLen]; + var mask = new long[seqLen]; + var tokenType = new long[seqLen]; + for (int i = 0; i < seqLen; i++) + { + idsLong[i] = ids[i]; + mask[i] = 1; + tokenType[i] = 0; + } + + var inputIds = new DenseTensor(idsLong, new[] { 1, seqLen }); + var attentionMask = new DenseTensor(mask, new[] { 1, seqLen }); + var tokenTypes = new DenseTensor(tokenType, new[] { 1, seqLen }); + + var inputs = new List + { + NamedOnnxValue.CreateFromTensor("input_ids", inputIds), + NamedOnnxValue.CreateFromTensor("attention_mask", attentionMask), + NamedOnnxValue.CreateFromTensor("token_type_ids", tokenTypes), + }; + + using var results = session.Run(inputs); + // The MiniLM ONNX export exposes a single output named + // last_hidden_state of shape [batch, seq, dim]. Pick it by + // position so we don't depend on a specific name across + // future re-exports. + var output = results[0].AsTensor(); + int dim = output.Dimensions[2]; + var pooled = new float[dim]; + + // Attention-masked mean pooling — the standard + // sentence-transformers recipe. The mask is all 1s here + // because we never pad, but write the masked sum so the + // code stays correct under a future batched implementation. + double maskTotal = 0; + for (int s = 0; s < seqLen; s++) + { + double w = mask[s]; + maskTotal += w; + for (int d = 0; d < dim; d++) + { + pooled[d] += (float)(output[0, s, d] * w); + } + } + if (maskTotal > 0) + { + float inv = (float)(1.0 / maskTotal); + for (int d = 0; d < dim; d++) pooled[d] *= inv; + } + + // L2-normalise explicitly. The MiniLM ONNX export does not + // ship the normalisation step the Python sentence-transformers + // pipeline applies by default with normalize_embeddings=True; + // doing it here keeps the cosine distances comparable across + // the Python, Node, Go, Java, and .NET demos. + double sq = 0; + foreach (var v in pooled) sq += (double)v * v; + if (sq > 0) + { + float inv = (float)(1.0 / Math.Sqrt(sq)); + for (int d = 0; d < dim; d++) pooled[d] *= inv; + } + return pooled; + } + + /// + /// Pack a float[] into the bytes Redis Search expects for + /// a FLOAT32 vector field — raw little-endian float32 + /// values, no header, no padding. Matches the encoding the + /// Python, Node, Go, and Java ports write. + /// + /// + /// We use + /// rather than because + /// the latter follows host endianness; explicit little-endian + /// here means the docs example is portable even on a hypothetical + /// big-endian .NET host. + /// is checked once at process start in to + /// catch any future surprise — every supported .NET runtime + /// today is little-endian, but the assertion documents the + /// assumption. + /// + public static byte[] ToBytes(float[] vector) + { + var bytes = new byte[vector.Length * sizeof(float)]; + var span = bytes.AsSpan(); + for (int i = 0; i < vector.Length; i++) + { + BinaryPrimitives.WriteSingleLittleEndian(span.Slice(i * sizeof(float)), vector[i]); + } + return bytes; + } + + public void Dispose() + { + _session.Dispose(); + } +} diff --git a/content/develop/use-cases/agent-memory/dotnet/LongTermMemory.cs b/content/develop/use-cases/agent-memory/dotnet/LongTermMemory.cs new file mode 100644 index 0000000000..161864efc1 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/LongTermMemory.cs @@ -0,0 +1,498 @@ +using System.Globalization; +using System.Text.Json; +using NRedisStack; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; +using NRedisStack.Search.Literals.Enums; +using StackExchange.Redis; + +namespace AgentMemoryDemo; + +/// +/// Long-term memory store for an agent, backed by Redis JSON and +/// Search. +/// +/// +/// Each memory lives as one JSON document at +/// agent:mem:<id>. The document holds the memory text, +/// its embedding vector, and a small metadata block — user, +/// namespace, kind, source thread, timestamps — that lets the recall +/// query scope results without falling back to application-side +/// filtering. +/// +/// A single Redis Search index covers the embedding plus every +/// metadata field, so one FT.SEARCH call performs approximate- +/// nearest-neighbour over the in-scope subset and returns the top-k +/// memories ranked by cosine distance. The same KNN check runs at +/// write time to deduplicate near-identical memories before +/// they enter the store, which keeps the index from filling with +/// paraphrases of the same fact as the agent reasons over similar +/// topics across sessions. +/// +/// Memories carry one of two kinds: episodic snapshots +/// from a specific thread, written with a medium TTL so old session +/// detail decays naturally; semantic distilled facts and +/// preferences the agent should carry forward indefinitely, written +/// with no TTL by default. The split is enforced as a TAG on the +/// index, so the recall query can ask for one kind or both with a +/// filter — no separate keyspaces. +/// +public sealed class LongTermMemory +{ + public const int VectorDimDefault = 384; + + /// + /// How close (cosine distance) a candidate must be to an existing + /// memory to count as a duplicate at write time. Smaller = + /// stricter. 0.20 is calibrated to the + /// sentence-transformers/all-MiniLM-L6-v2 embedding model + /// used in the demo, where a paraphrase of an existing memory + /// lands in the 0.10 – 0.20 range and a distinct memory lands + /// above 0.50. + /// + public const double DefaultDedupThreshold = 0.20; + + /// + /// How close (cosine distance) a candidate must be to count as a + /// relevant recall result. Larger than the dedup threshold so the + /// agent gets a wider net at read time than at write time. + /// + public const double DefaultRecallThreshold = 0.55; + + /// + /// TTL tiers, in seconds. null means "no TTL" — the memory + /// persists until explicitly deleted or evicted under memory + /// pressure. + /// + public static readonly IReadOnlyDictionary DefaultTtlByKind = + new Dictionary + { + ["episodic"] = 7L * 24 * 3600, + ["semantic"] = null, + }; + + // Characters Redis Search treats as syntax inside a TAG value; + // any of them in a user-supplied filter must be backslash-escaped + // or the surrounding `{...}` block won't parse correctly. + private static readonly HashSet TagSpecial = new( + "\\,.<>{}[]\"':;!@#$%^&*()-+=~| "); + + private readonly IDatabase _db; + private readonly ISearchCommands _ft; + private readonly IJsonCommands _json; + private readonly IReadOnlyDictionary _ttlByKind; + + public string IndexName { get; } + public string KeyPrefix { get; } + public int VectorDim { get; } + public double DedupThreshold { get; } + public double RecallThreshold { get; } + + public LongTermMemory( + IDatabase db, + string indexName = "agentmem:idx", + string keyPrefix = "agent:mem:", + int vectorDim = VectorDimDefault, + double dedupThreshold = DefaultDedupThreshold, + double recallThreshold = DefaultRecallThreshold, + IReadOnlyDictionary? ttlByKind = null) + { + _db = db; + _ft = db.FT(); + _json = db.JSON(); + IndexName = indexName; + KeyPrefix = keyPrefix; + VectorDim = vectorDim; + DedupThreshold = dedupThreshold; + RecallThreshold = recallThreshold; + _ttlByKind = ttlByKind ?? DefaultTtlByKind; + } + + public string MemoryKey(string memoryId) => KeyPrefix + memoryId; + + // ------------------------------------------------------------------ + // Index management + // ------------------------------------------------------------------ + + /// + /// Create the Redis Search index if it doesn't already exist. + /// + /// + /// The index is declared on the JSON document type with alias + /// names on each path; the same FT.SEARCH filter clause + /// works here as on a HASH-backed index, and the field paths + /// ($.user, $.embedding, ...) only show up in + /// FT.CREATE. + /// + public void CreateIndex() + { + var schema = new Schema() + .AddTextField(new FieldName("$.text", "text")) + .AddTagField(new FieldName("$.user", "user")) + .AddTagField(new FieldName("$.namespace", "namespace")) + .AddTagField(new FieldName("$.kind", "kind")) + .AddTagField(new FieldName("$.source_thread", "source_thread")) + .AddNumericField(new FieldName("$.created_ts", "created_ts"), sortable: true) + .AddNumericField(new FieldName("$.hit_count", "hit_count"), sortable: true) + .AddVectorField( + new FieldName("$.embedding", "embedding"), + Schema.VectorField.VectorAlgo.HNSW, + new Dictionary + { + ["TYPE"] = "FLOAT32", + ["DIM"] = VectorDim, + ["DISTANCE_METRIC"] = "COSINE", + }); + try + { + _ft.Create( + IndexName, + new FTCreateParams() + .On(IndexDataType.JSON) + .Prefix(KeyPrefix), + schema); + } + catch (RedisServerException ex) + when (ex.Message.Contains("Index already exists", StringComparison.OrdinalIgnoreCase)) + { + // Idempotent. + } + } + + /// Drop the search index. Optionally also delete the JSON docs. + public void DropIndex(bool deleteDocuments = false) + { + try + { + _ft.DropIndex(IndexName, deleteDocuments); + } + catch (RedisServerException ex) + { + string msg = ex.Message ?? ""; + if (!msg.Contains("no such index", StringComparison.OrdinalIgnoreCase) + && !msg.Contains("unknown index name", StringComparison.OrdinalIgnoreCase)) + { + throw; + } + } + } + + // ------------------------------------------------------------------ + // Write + // ------------------------------------------------------------------ + + /// + /// Write a new memory, deduplicating against existing entries. + /// + /// + /// Runs one in-scope KNN(1) against the index first. + /// If the nearest existing memory is within + /// , the new memory is skipped (its + /// content is already represented) and the existing memory's + /// hit_count is bumped via JSON.NUMINCRBY. Otherwise + /// a fresh JSON document is written under a new id with a TTL + /// derived from the memory's kind. + /// + /// The KNN-then-write sequence is not atomic; two workers + /// that remember the same fact at the same time can both miss + /// each other's in-flight write and insert duplicate memories. + /// See the walkthrough's "Concurrency caveats" section for the + /// production fix (periodic background consolidator that merges + /// near-duplicates). + /// + public WriteResult Remember( + string text, + float[] embedding, + string user = "default", + string @namespace = "default", + string kind = "episodic", + string sourceThread = "", + long? ttlSeconds = null) + { + if (embedding is null) throw new ArgumentNullException(nameof(embedding)); + if (embedding.Length != VectorDim) + { + throw new ArgumentException( + $"embedding length is {embedding.Length}; index expects {VectorDim}", + nameof(embedding)); + } + + var nearest = Nearest(embedding, user, @namespace, kind, k: 1); + double? nearestDistance = nearest.Count > 0 ? nearest[0].Distance : null; + if (nearest.Count > 0 + && nearest[0].Distance is double d + && d <= DedupThreshold) + { + BumpHitCount(nearest[0].Id); + return new WriteResult( + Id: nearest[0].Id, Deduped: true, ExistingDistance: nearestDistance); + } + + string id = Guid.NewGuid().ToString("N").Substring(0, 12); + string key = MemoryKey(id); + double now = UnixSeconds(); + + // Build the JSON doc as a Dictionary so + // System.Text.Json serialises the float[] embedding as a + // bare JSON array — the encoding RediSearch expects when + // indexing a JSON path as a vector field. + var doc = new Dictionary + { + ["id"] = id, + ["user"] = user, + ["namespace"] = @namespace, + ["kind"] = kind, + ["source_thread"] = sourceThread, + ["text"] = text, + ["embedding"] = embedding, + ["created_ts"] = now, + ["hit_count"] = 0, + }; + long? ttl = ttlSeconds ?? ResolveTtl(kind); + + // MULTI/EXEC so JSON.SET and EXPIRE either both apply or + // neither does. A connection drop between the two writes + // would otherwise leave the memory without an expiry — the + // index entry would still be there, but an `episodic` doc + // would outlive its intended seven-day TTL. + // + // NRedisStack exposes its JSON helpers only on `IDatabase`, + // not on `ITransaction`, so we drop down to the raw + // `ExecuteAsync("JSON.SET", ...)` command for the transactional + // path. The document body is JSON text either way. + string docJson = JsonSerializer.Serialize(doc); + var tx = _db.CreateTransaction(); + _ = tx.ExecuteAsync("JSON.SET", key, "$", docJson); + if (ttl is long t) + { + _ = tx.KeyExpireAsync(key, TimeSpan.FromSeconds(t)); + } + if (!tx.Execute()) + { + throw new RedisServerException("remember MULTI/EXEC was discarded"); + } + return new WriteResult(Id: id, Deduped: false, ExistingDistance: nearestDistance); + } + + // ------------------------------------------------------------------ + // Recall + // ------------------------------------------------------------------ + + /// + /// Return the top-k in-scope memories ranked by similarity. + /// Memories beyond (or the + /// instance default) are dropped — the index always returns + /// something for KNN, so a recall result on an unrelated + /// query would otherwise be a confidently-wrong false positive. + /// + public List Recall( + float[] queryEmbedding, + string user = "default", + string? @namespace = "default", + string? kind = null, + int k = 5, + double? distanceThreshold = null) + { + double threshold = distanceThreshold ?? RecallThreshold; + var candidates = Nearest(queryEmbedding, user, @namespace, kind, k); + return candidates + .Where(c => c.Distance is double d && d <= threshold) + .ToList(); + } + + // ------------------------------------------------------------------ + // Admin / inspection + // ------------------------------------------------------------------ + + public IndexSnapshot IndexInfo() + { + try + { + var info = _ft.Info(IndexName); + return new IndexSnapshot( + NumDocs: info.NumDocs, + IndexingFailures: info.HashIndexingFailures); + } + catch (RedisServerException) + { + return new IndexSnapshot(0, 0); + } + } + + public List ListMemories( + string? user = "default", + string? @namespace = "default", + string? kind = null, + int limit = 100) + { + // Match `Recall`'s defaults so listing and KNN recall agree + // on which memories are in scope for the same caller inputs. + // Pass `null` (or `""`) on either argument to opt out of the + // TAG filter and list across every scope. + string filterClause = BuildFilterClause(user, @namespace, kind); + var query = new Query(filterClause) + .ReturnFields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count") + .Limit(0, limit) + .SetSortBy("created_ts", ascending: false) + .Dialect(2); + SearchResult result; + try + { + result = _ft.Search(IndexName, query); + } + catch (RedisServerException) + { + return new List(); + } + var out_ = new List(result.Documents.Count); + foreach (var doc in result.Documents) + { + string memoryId = StripPrefix(doc.Id ?? ""); + var props = doc.GetProperties().ToDictionary(p => p.Key, p => p.Value); + TimeSpan? ttl = _db.KeyTimeToLive(MemoryKey(memoryId)); + long? ttlSeconds = ttl is { TotalSeconds: > 0 } v ? (long)v.TotalSeconds : null; + out_.Add(BuildRecord(memoryId, props, ttlSeconds, distance: null)); + } + return out_; + } + + public bool DeleteMemory(string memoryId) => _db.KeyDelete(MemoryKey(memoryId)); + + /// + /// Drop the index and every memory document. Returns the count + /// of documents that were removed. In production the equivalent + /// is FLUSHDB on a dedicated memory database, or letting + /// TTLs and eviction expire entries naturally. + /// + public long Clear() + { + long before = IndexInfo().NumDocs; + DropIndex(deleteDocuments: true); + CreateIndex(); + return before; + } + + // ------------------------------------------------------------------ + // Internals + // ------------------------------------------------------------------ + + private List Nearest( + float[] embedding, string? user, string? @namespace, string? kind, int k) + { + if (embedding.Length != VectorDim) + { + throw new ArgumentException( + $"embedding length is {embedding.Length}; index expects {VectorDim}", + nameof(embedding)); + } + string filterClause = BuildFilterClause(user, @namespace, kind); + string knnQuery = $"{filterClause}=>[KNN {k} @embedding $vec AS distance]"; + byte[] vecBytes = LocalEmbedder.ToBytes(embedding); + + var query = new Query(knnQuery) + .ReturnFields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count", "distance") + .SetSortBy("distance", ascending: true) + .Limit(0, k) + .AddParam("vec", vecBytes) + .Dialect(2); + var result = _ft.Search(IndexName, query); + if (result.Documents is null || result.Documents.Count == 0) + { + return new List(); + } + var out_ = new List(result.Documents.Count); + foreach (var doc in result.Documents) + { + // `doc.Id` is the full Redis key (e.g. + // `agent:mem:abc123`). Strip the prefix so the returned + // record exposes only the opaque id the UI and + // `DeleteMemory` work with. + string memoryId = StripPrefix(doc.Id ?? ""); + var props = doc.GetProperties().ToDictionary(p => p.Key, p => p.Value); + double distance = ParseDouble(props.GetValueOrDefault("distance"), 0.0); + TimeSpan? ttl = _db.KeyTimeToLive(MemoryKey(memoryId)); + long? ttlSeconds = ttl is { TotalSeconds: > 0 } v ? (long)v.TotalSeconds : null; + out_.Add(BuildRecord(memoryId, props, ttlSeconds, distance)); + } + return out_; + } + + private void BumpHitCount(string memoryId) + { + try + { + _json.NumIncrby(MemoryKey(memoryId), "$.hit_count", 1); + } + catch (RedisServerException) + { + // The doc may have expired between recall and bump — + // fine, we just lose the hit count update. + } + } + + private string StripPrefix(string rawKey) + => rawKey.StartsWith(KeyPrefix) ? rawKey.Substring(KeyPrefix.Length) : rawKey; + + private long? ResolveTtl(string kind) + => _ttlByKind.TryGetValue(kind, out var ttl) ? ttl : null; + + private static MemoryRecord BuildRecord( + string memoryId, + Dictionary props, + long? ttlSeconds, + double? distance) + => new( + Id: memoryId, + User: ToStringSafe(props, "user"), + Namespace: ToStringSafe(props, "namespace"), + Kind: ToStringSafe(props, "kind"), + SourceThread: ToStringSafe(props, "source_thread"), + Text: ToStringSafe(props, "text"), + CreatedTs: ParseDouble(props.GetValueOrDefault("created_ts"), 0), + HitCount: (long)ParseDouble(props.GetValueOrDefault("hit_count"), 0), + Distance: distance, + TtlSeconds: ttlSeconds); + + internal static string EscapeTagValue(string value) + { + if (string.IsNullOrEmpty(value)) return ""; + var sb = new System.Text.StringBuilder(value.Length); + foreach (var ch in value) + { + if (TagSpecial.Contains(ch)) sb.Append('\\'); + sb.Append(ch); + } + return sb.ToString(); + } + + internal static string BuildFilterClause(string? user, string? @namespace, string? kind) + { + var clauses = new List(3); + if (!string.IsNullOrEmpty(user)) + clauses.Add($"@user:{{{EscapeTagValue(user!)}}}"); + if (!string.IsNullOrEmpty(@namespace)) + clauses.Add($"@namespace:{{{EscapeTagValue(@namespace!)}}}"); + if (!string.IsNullOrEmpty(kind)) + clauses.Add($"@kind:{{{EscapeTagValue(kind!)}}}"); + return clauses.Count == 0 + ? "(*)" + : "(" + string.Join(" ", clauses) + ")"; + } + + private static string ToStringSafe(Dictionary props, string key) + => props.GetValueOrDefault(key).ToString() ?? ""; + + private static double ParseDouble(RedisValue value, double fallback) + { + if (value.IsNullOrEmpty) return fallback; + if (value.TryParse(out double d)) return d; + return fallback; + } + + private static double UnixSeconds() + => DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000.0; +} diff --git a/content/develop/use-cases/agent-memory/dotnet/Program.cs b/content/develop/use-cases/agent-memory/dotnet/Program.cs new file mode 100644 index 0000000000..050f7b3315 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/Program.cs @@ -0,0 +1,759 @@ +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Text; +using System.Text.Json; +using System.Web; +using StackExchange.Redis; + +namespace AgentMemoryDemo; + +/// +/// Redis agent-memory demo server (.NET 8 + NRedisStack + ONNX +/// Runtime). +/// +/// +/// Run this and visit http://localhost:8093 to drive a +/// small agent-memory demo backed by Redis Hashes, JSON, Search, and +/// Streams. The UI lets you type a turn, watch working memory update, +/// see semantically similar long-term memories recalled, watch the +/// write-time deduplication skip near-duplicates, and inspect the +/// per-thread event log. +/// +/// The server holds a single , one +/// , one , and +/// one for the lifetime of the process. +/// The first run downloads the embedding model into +/// ./model_cache; everything after is local. +/// +public static class Program +{ + private const string StackLabel = + "NRedisStack + ONNX Runtime + .NET HttpListener"; + + // 1 MiB cap on POST bodies so a runaway client (or a `curl + // --data-binary @big-file` by mistake) can't accumulate + // unbounded memory before the handler runs. The demo's largest + // legitimate body is a few hundred bytes of form-encoded query + // fields; 1 MiB is a generous ceiling. + private const int MaxBodyBytes = 1 * 1024 * 1024; + + public static int Main(string[] argv) + { + // The embedding is stored inside JSON documents as a JSON + // array — host endianness doesn't matter there — but the + // *query* vector is sent to Redis as raw little-endian + // float32 bytes via the $vec param. The packer in + // LocalEmbedder writes little-endian explicitly through + // BinaryPrimitives, so a hypothetical big-endian .NET host + // would still produce the correct bytes; every supported + // runtime today is little-endian and a surprise here would + // silently corrupt every recall query, so assert it loudly + // at startup. + Debug.Assert(BitConverter.IsLittleEndian, + "this demo assumes a little-endian host"); + + Args args; + try + { + args = Args.Parse(argv); + } + catch (ArgumentException ex) + { + Console.Error.WriteLine($"Error: {ex.Message}"); + PrintHelp(); + return 2; + } + + // Resources held for the lifetime of the process. Declared + // here so the try/finally below disposes them on every exit + // path, including the early `return 1` branches that follow + // a successful Redis connect or model load. + ConnectionMultiplexer? mux = null; + LocalEmbedder? embedder = null; + HttpListener? listener = null; + try + { + try + { + mux = ConnectionMultiplexer.Connect(new ConfigurationOptions + { + EndPoints = { { args.RedisHost, args.RedisPort } }, + AbortOnConnectFail = false, + ConnectTimeout = 2000, + SyncTimeout = 5000, + }); + mux.GetDatabase().Ping(); + } + catch (Exception ex) + { + Console.Error.WriteLine( + $"Error: cannot reach Redis at {args.RedisHost}:{args.RedisPort}"); + Console.Error.WriteLine($" ({ex.Message})"); + return 1; + } + + var db = mux.GetDatabase(); + var session = new AgentSession( + db, + keyPrefix: args.SessionKeyPrefix, + defaultTtlSeconds: args.SessionTtlSeconds); + var memory = new LongTermMemory( + db, + indexName: args.MemIndexName, + keyPrefix: args.MemKeyPrefix, + dedupThreshold: args.DedupThreshold, + recallThreshold: args.RecallThreshold); + memory.CreateIndex(); + var events = new AgentEventLog(db, keyPrefix: args.EventKeyPrefix); + + Console.WriteLine( + "Loading embedding model (first run downloads ~90 MB of ONNX weights)..."); + try + { + embedder = LocalEmbedder.CreateAsync().GetAwaiter().GetResult(); + } + catch (Exception ex) + { + Console.Error.WriteLine($"Error loading embedder: {ex.Message}"); + return 1; + } + + var demo = new AgentMemoryDemo(session, memory, events, embedder); + + if (args.ResetOnStart) + { + Console.WriteLine( + $"Dropping any existing memories under '{args.MemKeyPrefix}*' and " + + "re-seeding from the sample memory list (pass --no-reset to keep)."); + int seeded = demo.SeedAll("default", "default"); + Console.WriteLine($"Seeded {seeded} memories."); + } + + // Load index.html once and substitute the template tokens so + // the docs panel shows the actual values in use rather than + // the default copies. The file ships next to the binary via + // the entry in the .csproj. + string htmlPath = Path.Combine(AppContext.BaseDirectory, "index.html"); + if (!File.Exists(htmlPath)) + { + Console.Error.WriteLine( + $"index.html not found next to the binary at {htmlPath}."); + return 1; + } + string rawHtml = File.ReadAllText(htmlPath); + string htmlPage = rawHtml + .Replace("__SESSION_PREFIX__", args.SessionKeyPrefix) + .Replace("__MEM_PREFIX__", args.MemKeyPrefix) + .Replace("__MEM_INDEX__", args.MemIndexName) + .Replace("__EVENT_PREFIX__", args.EventKeyPrefix); + + listener = new HttpListener(); + // HttpListener prefixes need a trailing slash; '+' wildcard + // would require admin rights on macOS/Linux, so we bind to + // the literal host string. 127.0.0.1 keeps the demo off the + // network by default. + string prefix = $"http://{args.Host}:{args.Port}/"; + listener.Prefixes.Add(prefix); + try + { + listener.Start(); + } + catch (Exception ex) + { + Console.Error.WriteLine($"Failed to bind {prefix}: {ex.Message}"); + return 1; + } + + Console.WriteLine( + $"Redis agent memory demo listening on http://{args.Host}:{args.Port}"); + Console.WriteLine( + $"Using Redis at {args.RedisHost}:{args.RedisPort}" + + $" with memory index '{args.MemIndexName}'"); + + var cts = new CancellationTokenSource(); + Console.CancelKeyPress += (_, e) => + { + e.Cancel = true; + Console.WriteLine("\nShutting down..."); + cts.Cancel(); + try { listener.Stop(); } catch { /* best-effort */ } + }; + + // One handler thread per request out of the ThreadPool. The + // ONNX session, the Redis multiplexer, the session helper, + // the memory helper, and the event log are all thread-safe; + // ``currentThreadId`` on the demo is mutable shared state + // but reads and writes race only in the corner case the + // walkthrough's Concurrency caveats section documents. + while (!cts.IsCancellationRequested) + { + HttpListenerContext ctx; + try + { + ctx = listener.GetContext(); + } + catch (HttpListenerException) { break; } + catch (ObjectDisposedException) { break; } + ThreadPool.QueueUserWorkItem(_ => + { + try + { + HandleRequest(ctx, demo, session, memory, events, embedder, htmlPage); + } + catch (Exception ex) + { + Console.Error.WriteLine( + $"[demo] handler error: {ex.GetType().Name}: {ex.Message}"); + TrySendError(ctx, ex); + } + }); + } + + return 0; + } + finally + { + try { listener?.Close(); } catch { /* best-effort */ } + embedder?.Dispose(); + mux?.Dispose(); + } + } + + // ------------------------------------------------------------------ + // Request dispatch + // ------------------------------------------------------------------ + + private static void HandleRequest( + HttpListenerContext ctx, + AgentMemoryDemo demo, + AgentSession session, + LongTermMemory memory, + AgentEventLog events, + LocalEmbedder embedder, + string htmlPage) + { + var req = ctx.Request; + string path = req.Url?.AbsolutePath ?? "/"; + + if (string.Equals(req.HttpMethod, "GET", StringComparison.OrdinalIgnoreCase)) + { + if (path == "/" || path == "/index.html") + { + SendHtml(ctx, htmlPage); + return; + } + if (path == "/state") + { + var qs = HttpUtility.ParseQueryString(req.Url?.Query ?? ""); + string user = OrDefault(qs["user"], demo.DefaultUser); + string @namespace = OrDefault(qs["namespace"], demo.DefaultNamespace); + SendJson(ctx, BuildState(demo, session, memory, events, embedder, user, @namespace)); + return; + } + SendJson(ctx, new { error = "not found" }, 404); + return; + } + + if (string.Equals(req.HttpMethod, "POST", StringComparison.OrdinalIgnoreCase)) + { + string body = ReadBody(req); + var form = HttpUtility.ParseQueryString(body); + + if (path == "/turn") + { + string text = (form["text"] ?? "").Trim(); + if (text.Length == 0) + { + SendJson(ctx, new { error = "text is required" }, 400); + return; + } + double threshold = ClampThreshold(form["threshold"], memory.RecallThreshold); + var payload = demo.HandleTurn( + text: text, + user: OrDefault(form["user"], "default"), + @namespace: OrDefault(form["namespace"], "default"), + kind: OrDefault(form["kind"], "episodic"), + role: OrDefault(form["role"], "user"), + threshold: threshold, + action: OrDefault(form["action"], "turn")); + SendJson(ctx, payload); + return; + } + + if (path == "/new_thread") + { + string threadId = demo.NewThread( + user: OrDefault(form["user"], "default"), + @namespace: OrDefault(form["namespace"], "default")); + SendJson(ctx, new { thread_id = threadId }); + return; + } + + if (path == "/reset") + { + int seeded = demo.SeedAll( + user: OrDefault(form["user"], "default"), + @namespace: OrDefault(form["namespace"], "default")); + SendJson(ctx, new { seeded }); + return; + } + + if (path == "/drop_memory") + { + string memoryId = (form["memory_id"] ?? "").Trim(); + if (memoryId.Length == 0) + { + SendJson(ctx, new { error = "memory_id is required" }, 400); + return; + } + bool deleted = memory.DeleteMemory(memoryId); + SendJson(ctx, new { deleted, memory_id = memoryId }); + return; + } + + SendJson(ctx, new { error = "not found" }, 404); + return; + } + + SendJson(ctx, new { error = "method not allowed" }, 405); + } + + private static object BuildState( + AgentMemoryDemo demo, + AgentSession session, + LongTermMemory memory, + AgentEventLog events, + LocalEmbedder embedder, + string user, + string @namespace) + { + var info = memory.IndexInfo(); + string threadId = demo.CurrentThreadId; + var state = session.Load(threadId); + var memories = memory.ListMemories(user: user, @namespace: @namespace, limit: 200); + var recentEvents = events.Recent(threadId, count: 20); + return new + { + index = new + { + num_docs = info.NumDocs, + indexing_failures = info.IndexingFailures, + index_name = memory.IndexName, + model = embedder.ModelName, + session_ttl_seconds = session.DefaultTtlSeconds, + dedup_threshold = memory.DedupThreshold, + default_recall_threshold = memory.RecallThreshold, + stack_label = StackLabel, + }, + thread_id = threadId, + session = state is null ? null : SerializeSession(state), + memories = memories.Select(SerializeMemory).ToArray(), + events = recentEvents.Select(SerializeEvent).ToArray(), + // `recalled` is populated by /turn; on plain /state reads + // the UI keeps showing the last turn's result, which is + // the useful behaviour for an "agent" panel. + recalled = Array.Empty(), + }; + } + + // ------------------------------------------------------------------ + // Serialisation helpers (match the Python/Node demo payloads + // exactly so the same index.html JS works) + // ------------------------------------------------------------------ + + internal static object SerializeSession(SessionState s) => new + { + thread_id = s.ThreadId, + user = s.User, + agent = s.Agent, + goal = s.Goal, + scratchpad = s.Scratchpad, + turn_count = s.TurnCount, + created_ts = s.CreatedTs, + last_active_ts = s.LastActiveTs, + recent_turns = s.RecentTurns.Select(t => new + { + role = t.Role, + content = t.Content, + ts = t.Ts, + }).ToArray(), + ttl_seconds = s.TtlSeconds, + }; + + internal static object SerializeMemory(MemoryRecord m) => new + { + id = m.Id, + user = m.User, + @namespace = m.Namespace, + kind = m.Kind, + source_thread = m.SourceThread, + text = m.Text, + created_ts = m.CreatedTs, + hit_count = m.HitCount, + distance = m.Distance, + ttl_seconds = m.TtlSeconds, + }; + + internal static object SerializeEvent(AgentEvent e) => new + { + event_id = e.EventId, + thread_id = e.ThreadId, + action = e.Action, + detail = e.Detail, + ts = e.Ts, + }; + + // ------------------------------------------------------------------ + // HTTP plumbing + // ------------------------------------------------------------------ + + private static string ReadBody(HttpListenerRequest req) + { + using var ms = new MemoryStream(); + var buffer = new byte[8192]; + int total = 0; + int read; + while ((read = req.InputStream.Read(buffer, 0, buffer.Length)) > 0) + { + total += read; + if (total > MaxBodyBytes) + { + throw new InvalidOperationException( + $"request body exceeds {MaxBodyBytes} bytes"); + } + ms.Write(buffer, 0, read); + } + return req.ContentEncoding.GetString(ms.ToArray()); + } + + private static void SendHtml(HttpListenerContext ctx, string html, int status = 200) + { + var bytes = Encoding.UTF8.GetBytes(html); + ctx.Response.StatusCode = status; + ctx.Response.ContentType = "text/html; charset=utf-8"; + ctx.Response.ContentLength64 = bytes.LongLength; + ctx.Response.OutputStream.Write(bytes, 0, bytes.Length); + ctx.Response.OutputStream.Close(); + } + + private static readonly JsonSerializerOptions JsonOpts = new() + { + DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.Never, + }; + + private static void SendJson(HttpListenerContext ctx, object payload, int status = 200) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(payload, JsonOpts); + ctx.Response.StatusCode = status; + ctx.Response.ContentType = "application/json"; + ctx.Response.ContentLength64 = bytes.LongLength; + ctx.Response.OutputStream.Write(bytes, 0, bytes.Length); + ctx.Response.OutputStream.Close(); + } + + private static void TrySendError(HttpListenerContext ctx, Exception ex) + { + try + { + SendJson(ctx, new { error = ex.Message, type = ex.GetType().Name }, 500); + } + catch + { + // Headers may already be partially flushed; nothing left to do. + } + } + + private static double ClampThreshold(string? raw, double fallback) + { + if (!double.TryParse(raw, NumberStyles.Float, CultureInfo.InvariantCulture, out double d)) + return fallback; + // `double.TryParse` accepts NaN/Infinity on some inputs. + // Either would silently turn recall into "every memory" or + // "nothing"; clamp to the meaningful cosine-distance range + // so a malformed POST can't override the threshold semantics. + if (!double.IsFinite(d)) return fallback; + return Math.Max(0.0, Math.Min(2.0, d)); + } + + // Empty form / query values are common when the JS client posts + // an unedited input. `??` only catches `null`, not `""`, so + // without this helper an empty `user` would flow into the TAG + // filter as `""` and the recall query would silently drop the + // scope — Python / Node / Go all normalize via the same shape. + private static string OrDefault(string? value, string fallback) + => string.IsNullOrEmpty(value) ? fallback : value; + + // ------------------------------------------------------------------ + // Help text + // ------------------------------------------------------------------ + + private static void PrintHelp() + { + Console.Error.WriteLine(@"Usage: dotnet run -- [flags] + + --host HTTP bind host (default 127.0.0.1) + --port HTTP bind port (default 8093) + --redis-host Redis host (default localhost) + --redis-port Redis port (default 6379) + --mem-index-name Memory index name (default agentmem:idx) + --mem-key-prefix JSON memory key prefix (default agent:mem:) + --session-key-prefix Session hash key prefix (default agent:session:) + --event-key-prefix Event stream key prefix (default agent:events:) + --session-ttl-seconds Working memory TTL (default 3600) + --dedup-threshold Cosine distance for dedup (default 0.20) + --recall-threshold Cosine distance for recall (default 0.55) + --no-reset Skip clearing and re-seeding on startup +"); + } +} + +// ---------------------------------------------------------------------- +// Demo orchestrator +// ---------------------------------------------------------------------- + +/// +/// Demo state: working memory, long-term memory, event log. +/// +/// +/// / / +/// all touch +/// without coordination — see the walkthrough's "Concurrency +/// caveats" section. The demo is single-user in practice, so the +/// race never triggers; a multi-user agent would carry the thread id +/// on each request instead of holding it as shared server state. +/// +public sealed class AgentMemoryDemo +{ + private readonly AgentSession _session; + private readonly LongTermMemory _memory; + private readonly AgentEventLog _events; + private readonly LocalEmbedder _embedder; + public string DefaultUser { get; } + public string DefaultNamespace { get; } + public string CurrentThreadId { get; private set; } + + public AgentMemoryDemo( + AgentSession session, + LongTermMemory memory, + AgentEventLog events, + LocalEmbedder embedder, + string defaultUser = "default", + string defaultNamespace = "default") + { + _session = session; + _memory = memory; + _events = events; + _embedder = embedder; + DefaultUser = defaultUser; + DefaultNamespace = defaultNamespace; + CurrentThreadId = session.NewThreadId(); + } + + /// Drop everything in scope and pre-populate with seed memories. + public int SeedAll(string user, string @namespace) + { + _memory.Clear(); + _session.Delete(CurrentThreadId); + _events.Clear(CurrentThreadId); + int written = SeedMemory.Seed(_memory, _embedder, user: user, @namespace: @namespace); + CurrentThreadId = _session.NewThreadId(); + return written; + } + + /// Start a fresh thread. Long-term memory is unaffected. + public string NewThread(string user, string @namespace) + { + _events.Clear(CurrentThreadId); + CurrentThreadId = _session.NewThreadId(); + _session.Start(CurrentThreadId, user: user, agentName: "demo-agent", goal: ""); + _events.Record( + CurrentThreadId, + "thread_started", + $"user={user} namespace={@namespace}"); + return CurrentThreadId; + } + + /// + /// One pass through the agent loop: append, recall, remember, log. + /// + /// + /// The order matters. We embed once and reuse the vector + /// for both the recall and (if asked) the remember step — no + /// point encoding the same text twice. Recall runs before + /// the remember write so the agent doesn't see its own just- + /// written turn as a recalled memory. + /// + public object HandleTurn( + string text, + string user, + string @namespace, + string kind, + string role, + double threshold, + string action) + { + string threadId = CurrentThreadId; + + var t0 = System.Diagnostics.Stopwatch.GetTimestamp(); + float[] vec = _embedder.EncodeOne(text); + double embedMs = ElapsedMs(t0); + + // `SetGoal` only touches the goal field so existing turns + // aren't wiped; `AppendTurn` carries the request `user` + // through to the auto-create path so a first turn for a new + // thread doesn't land under the default user. + string sessionAction; + if (action == "goal") + { + _session.SetGoal(threadId, text, user: user, agentName: "demo-agent"); + sessionAction = "goal_set"; + } + else + { + _session.AppendTurn( + threadId, + role: role, + content: text, + user: user, + agentName: "demo-agent"); + sessionAction = $"turn_appended:{role}"; + } + + var t1 = System.Diagnostics.Stopwatch.GetTimestamp(); + var recalled = _memory.Recall( + queryEmbedding: vec, + user: user, + @namespace: @namespace, + k: 5, + distanceThreshold: threshold); + double recallMs = ElapsedMs(t1); + + bool writeSkipped = kind == "skip" || action == "goal"; + WriteResult? writeResult = null; + double writeMs = 0; + if (!writeSkipped) + { + var t2 = System.Diagnostics.Stopwatch.GetTimestamp(); + writeResult = _memory.Remember( + text: text, + embedding: vec, + user: user, + @namespace: @namespace, + kind: kind, + sourceThread: threadId); + writeMs = ElapsedMs(t2); + } + + // Append to event log so the audit trail shows what happened. + if (writeResult is not null) + { + string detail = writeResult.Deduped + ? $"deduped onto {writeResult.Id}" + : $"wrote {writeResult.Id} as {kind}"; + _events.Record(threadId, sessionAction, detail); + } + else + { + _events.Record(threadId, sessionAction, ""); + } + + return new + { + thread_id = threadId, + write_skipped = writeSkipped, + memory_id = writeResult?.Id, + deduped = writeResult?.Deduped ?? false, + existing_distance = writeResult?.ExistingDistance, + kind = writeSkipped ? null : kind, + recalled = recalled.Select(Program.SerializeMemory).ToArray(), + embed_ms = embedMs, + recall_ms = recallMs, + write_ms = writeMs, + }; + } + + private static double ElapsedMs(long start) + => (System.Diagnostics.Stopwatch.GetTimestamp() - start) + * 1000.0 / System.Diagnostics.Stopwatch.Frequency; +} + +// ---------------------------------------------------------------------- +// Arg parsing +// ---------------------------------------------------------------------- + +internal sealed record Args( + string Host, + int Port, + string RedisHost, + int RedisPort, + string MemIndexName, + string MemKeyPrefix, + string SessionKeyPrefix, + string EventKeyPrefix, + long SessionTtlSeconds, + double DedupThreshold, + double RecallThreshold, + bool ResetOnStart) +{ + public static Args Parse(string[] argv) + { + string host = "127.0.0.1"; + int port = 8093; + string redisHost = "localhost"; + int redisPort = 6379; + string memIndex = "agentmem:idx"; + string memPrefix = "agent:mem:"; + string sessionPrefix = "agent:session:"; + string eventPrefix = "agent:events:"; + long sessionTtl = 3600; + double dedup = 0.20; + double recall = 0.55; + bool reset = true; + + for (int i = 0; i < argv.Length; i++) + { + string a = argv[i]; + string? Take() => i + 1 < argv.Length ? argv[++i] : null; + + switch (a) + { + case "--host": host = Take() ?? host; break; + case "--port": port = int.Parse(Take() ?? "8093"); break; + case "--redis-host": redisHost = Take() ?? redisHost; break; + case "--redis-port": redisPort = int.Parse(Take() ?? "6379"); break; + case "--mem-index-name": memIndex = Take() ?? memIndex; break; + case "--mem-key-prefix": memPrefix = Take() ?? memPrefix; break; + case "--session-key-prefix": sessionPrefix = Take() ?? sessionPrefix; break; + case "--event-key-prefix": eventPrefix = Take() ?? eventPrefix; break; + case "--session-ttl-seconds": sessionTtl = long.Parse(Take() ?? "3600"); break; + case "--dedup-threshold": + dedup = double.Parse(Take() ?? "0.20", CultureInfo.InvariantCulture); + break; + case "--recall-threshold": + recall = double.Parse(Take() ?? "0.55", CultureInfo.InvariantCulture); + break; + case "--no-reset": reset = false; break; + case "--help": + case "-h": + throw new ArgumentException("help requested"); + default: + throw new ArgumentException($"unknown flag: {a}"); + } + } + + return new Args( + Host: host, + Port: port, + RedisHost: redisHost, + RedisPort: redisPort, + MemIndexName: memIndex, + MemKeyPrefix: memPrefix, + SessionKeyPrefix: sessionPrefix, + EventKeyPrefix: eventPrefix, + SessionTtlSeconds: sessionTtl, + DedupThreshold: dedup, + RecallThreshold: recall, + ResetOnStart: reset); + } +} diff --git a/content/develop/use-cases/agent-memory/dotnet/Records.cs b/content/develop/use-cases/agent-memory/dotnet/Records.cs new file mode 100644 index 0000000000..503bf68571 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/Records.cs @@ -0,0 +1,57 @@ +namespace AgentMemoryDemo; + +/// One turn inside the rolling session window. +public sealed record SessionTurn(string Role, string Content, double Ts); + +/// The full per-thread working-memory state. +public sealed record SessionState( + string ThreadId, + string User, + string Agent, + string Goal, + string Scratchpad, + long TurnCount, + double CreatedTs, + double LastActiveTs, + IReadOnlyList RecentTurns, + long TtlSeconds); + +/// A single long-term memory document. +/// +/// is populated only when the record +/// comes back from a KNN query; is +/// null for memories with no TTL (e.g. kind=semantic +/// under the default tier map). +/// +public sealed record MemoryRecord( + string Id, + string User, + string Namespace, + string Kind, + string SourceThread, + string Text, + double CreatedTs, + long HitCount, + double? Distance = null, + long? TtlSeconds = null); + +/// Outcome of a LongTermMemory.Remember call. +/// +/// is true when the write skipped +/// because a similar memory already existed; is then +/// the existing memory's id. is the +/// cosine distance to that nearest memory regardless of which branch +/// was taken — useful for tracing. +/// +public sealed record WriteResult(string Id, bool Deduped, double? ExistingDistance); + +/// One entry from the per-thread event Stream. +public sealed record AgentEvent( + string EventId, + string ThreadId, + string Action, + string Detail, + double Ts); + +/// Subset of FT.INFO useful for the demo UI. +public sealed record IndexSnapshot(long NumDocs, long IndexingFailures); diff --git a/content/develop/use-cases/agent-memory/dotnet/SeedMemory.cs b/content/develop/use-cases/agent-memory/dotnet/SeedMemory.cs new file mode 100644 index 0000000000..0d5d031704 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/SeedMemory.cs @@ -0,0 +1,89 @@ +namespace AgentMemoryDemo; + +/// +/// Pre-seed the long-term memory store with sample memories. +/// +/// +/// In a real deployment the memory store fills up organically +/// as the agent reasons over user turns: each turn produces zero or +/// more memories (preferences, facts, episodic summaries) that flow +/// into the store with deduplication. To make the demo immediately +/// useful — so the first recall query lands on relevant results +/// instead of an empty list — we seed a small set of canonical +/// memories for a default user at startup. +/// +/// The seed list mixes semantic memories (long-lived +/// preferences and facts) with episodic memories (snapshots +/// of past sessions), matching what the Python and Node demos seed +/// so the three implementations behave identically. +/// +public static class SeedMemory +{ + public sealed record SeedEntry(string Text, string Kind); + + public static readonly IReadOnlyList SeedMemories = new[] + { + new SeedEntry( + "The user prefers concise answers without filler phrases.", + "semantic"), + new SeedEntry( + "The user is a Python developer working on a logistics platform.", + "semantic"), + new SeedEntry( + "The user lives in Berlin and works in the Europe/Berlin time zone.", + "semantic"), + new SeedEntry( + "The user dislikes dark mode and prefers a high-contrast light " + + "theme in editors and dashboards.", + "semantic"), + new SeedEntry( + "The user is allergic to peanuts; any restaurant suggestion must " + + "avoid dishes that commonly contain them.", + "semantic"), + new SeedEntry( + "Last Tuesday the user asked the agent to draft a postmortem for " + + "the order-routing outage. The agent produced a five-section " + + "draft and the user approved sections 1, 2, and 4 with minor " + + "edits.", + "episodic"), + new SeedEntry( + "In a previous session the user asked for help debugging a flaky " + + "test in the inventory service. The fix turned out to be a race " + + "condition in the warehouse webhook handler.", + "episodic"), + new SeedEntry( + "Two weeks ago the user mentioned they were planning to migrate " + + "the analytics warehouse from Snowflake to BigQuery in Q3.", + "episodic"), + }; + + /// + /// Embed and write the seed memories. Returns the count actually + /// written (entries that dedup against existing memories don't + /// count). + /// + public static int Seed( + LongTermMemory memory, + LocalEmbedder embedder, + string user = "default", + string @namespace = "default", + string sourceThread = "seed") + { + var texts = SeedMemories.Select(m => m.Text).ToList(); + var vectors = embedder.EncodeMany(texts); + int written = 0; + for (int i = 0; i < SeedMemories.Count; i++) + { + var entry = SeedMemories[i]; + var result = memory.Remember( + text: entry.Text, + embedding: vectors[i], + user: user, + @namespace: @namespace, + kind: entry.Kind, + sourceThread: sourceThread); + if (!result.Deduped) written++; + } + return written; + } +} diff --git a/content/develop/use-cases/agent-memory/dotnet/_index.md b/content/develop/use-cases/agent-memory/dotnet/_index.md new file mode 100644 index 0000000000..ce1fcb7637 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/_index.md @@ -0,0 +1,334 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in C# with NRedisStack, ONNX Runtime, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: NRedisStack example (C#) +title: Redis agent memory with NRedisStack +weight: 3 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in C# (.NET 8) with [NRedisStack]({{< relref "/develop/clients/dotnet" >}}) and the ONNX Runtime, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with the .NET [`HttpListener`](https://learn.microsoft.com/en-us/dotnet/api/system.net.httplistener) so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder runs the ONNX-exported [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) model — the same encoder the [Python]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) and [Node.js]({{< relref "/develop/use-cases/agent-memory/nodejs" >}}) examples use. .NET ONNX Runtime is the same C++ kernel that powers Python's `onnxruntime`, so the vectors produced here are numerically identical to the Python ones to within rounding noise. The distance bands the Python walkthrough quotes carry over to this demo without recalibration, and a memory written by one demo can be recalled by the other against the same Redis instance. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.EncodeOne(text)` to turn the incoming turn into a 384-dimensional `float[]`. +2. `session.AppendTurn(threadId, role, content, user, agentName)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI/EXEC`]({{< relref "/commands/multi" >}}). The session TTL refreshes on every write so an active thread stays alive. +3. `memory.Recall(vec, user, @namespace, k: 5)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.Remember(text, vec, user, @namespace, kind)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `events.Record(threadId, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/dotnet/AgentSession.cs)): + +```csharp +using StackExchange.Redis; +using AgentMemoryDemo; + +var mux = ConnectionMultiplexer.Connect("localhost:6379"); +var db = mux.GetDatabase(); + +var session = new AgentSession( + db, + keyPrefix: "agent:session:", + defaultTtlSeconds: 3600, // one hour + maxTurns: 20); // rolling window per thread + +string threadId = session.NewThreadId(); +session.Start(threadId, user: "alice", agentName: "demo-agent", + goal: "Plan next week's meetings."); +session.AppendTurn(threadId, role: "user", + content: "Schedule a budget review with finance.", + user: "alice", agentName: "demo-agent"); +SessionState? state = session.Load(threadId); +Console.WriteLine($"{state?.TurnCount} {state?.RecentTurns.Count} {state?.TtlSeconds}"); +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `Start`, `AppendTurn`, `SetGoal`, `SetScratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) (via NRedisStack's `CreateTransaction`) so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/dotnet/LongTermMemory.cs)): + +```csharp +using AgentMemoryDemo; + +var memory = new LongTermMemory( + db, + indexName: "agentmem:idx", + keyPrefix: "agent:mem:", + dedupThreshold: 0.20, // cosine distance — tight at write time + recallThreshold: 0.55); // looser at read time +var embedder = await LocalEmbedder.CreateAsync(); +memory.CreateIndex(); // idempotent + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +float[] vec = embedder.EncodeOne("The user prefers light mode in editors."); +WriteResult result = memory.Remember( + text: "The user prefers light mode in editors.", + embedding: vec, + user: "alice", + @namespace: "default", + kind: "semantic", + sourceThread: "9f3d2a4b8c61"); +Console.WriteLine($"deduped={result.Deduped} id={result.Id} dist={result.ExistingDistance}"); + +// Recall against a later question. +float[] q = embedder.EncodeOne("Which theme does this user like?"); +var hits = memory.Recall( + queryEmbedding: q, + user: "alice", + @namespace: "default", + k: 5); +foreach (var h in hits) +{ + Console.WriteLine($"{h.Distance:F3} [{h.Kind}] {h.Text}"); +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is stored as a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (packed by `LocalEmbedder.ToBytes` with explicit little-endian writes), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with alias names so the query syntax stays compact (`FieldName("$.user", "user")` writes `$.user AS user` into `FT.CREATE`): + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`Remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write with `ttlSeconds: ...` on `Remember`, or pass a different `ttlByKind: ...` map to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/dotnet/AgentEventLog.cs)): + +```csharp +var events = new AgentEventLog(db, maxLen: 1000); +events.Record(threadId, "turn_appended:user", + "Schedule a budget review with finance."); +events.Record(threadId, "memory_written", + "wrote 7c3f8a1b9e02 as semantic"); + +foreach (var e in events.Recent(threadId, count: 20)) +{ + Console.WriteLine($"{e.Action} {e.Detail}"); +} +``` + +`Record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000` (StackExchange.Redis's `useApproximateMaxLength: true`). The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.AppendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.Remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `CurrentThreadId` that `/new_thread` and `/reset` mutate. `HandleTurn` reads it without coordination, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `SeedMemory.cs` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/dotnet/SeedMemory.cs)): + +```csharp +using AgentMemoryDemo; + +var memory = new LongTermMemory(db); +var embedder = await LocalEmbedder.CreateAsync(); +memory.CreateIndex(); +SeedMemory.Seed(memory, embedder, user: "default", @namespace: "default"); +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`Program.cs` runs a .NET [`HttpListener`](https://learn.microsoft.com/en-us/dotnet/api/system.net.httplistener) on port 8093, dispatching to handlers on the thread pool. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a single property on the demo object that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/dotnet + ``` + +2. Restore and build the project. You'll need the [.NET 8 SDK](https://dotnet.microsoft.com/download) + or later: + + ```bash + dotnet build + ``` + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo. The first run downloads the + [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) ONNX + weights (around 90 MB) and `vocab.txt` into a local `model_cache/` directory next + to the binary: + + ```bash + dotnet run + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + The .NET ONNX Runtime is the same C++ kernel that powers Python's + `onnxruntime`, so distances here match the Python demo to four decimal + places. `Xenova/all-MiniLM-L6-v2` puts a faithful paraphrase in the + 0.15 – 0.50 cosine-distance range, a loose paraphrase or related topic in + the 0.50 – 0.80 range, and unrelated queries above 0.8 — which is what + motivates the 0.55 default recall threshold and the 0.20 default dedup + threshold. A stricter embedding model (or a domain-tuned one) would let + you tighten both; a noisier one would push them up. The right thresholds + are always a function of the model, the corpus, and how conservative the + agent needs to be about accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags (pass them after `--`, for example `dotnet run -- --no-reset`): + +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/dotnet/index.html b/content/develop/use-cases/agent-memory/dotnet/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/dotnet/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/go/.gitignore b/content/develop/use-cases/agent-memory/go/.gitignore new file mode 100644 index 0000000000..bd6d073685 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/.gitignore @@ -0,0 +1,7 @@ +# Hugot downloads the ONNX model into ./models on first run; the +# weights are ~87 MB and we don't want them in the repo. +models/ + +# `go build` in this directory produces a `go` binary that we +# also don't want to commit. +/go diff --git a/content/develop/use-cases/agent-memory/go/_index.md b/content/develop/use-cases/agent-memory/go/_index.md new file mode 100644 index 0000000000..0816237d5a --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/_index.md @@ -0,0 +1,359 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Go with go-redis, Hugot, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: go-redis example (Go) +title: Redis agent memory with go-redis +weight: 5 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Go with [`go-redis`]({{< relref "/develop/clients/go" >}}) and the [Hugot](https://pkg.go.dev/github.com/knights-analytics/hugot) library, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with Go's standard [`net/http`](https://pkg.go.dev/net/http) so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [Hugot](https://pkg.go.dev/github.com/knights-analytics/hugot) running the ONNX-exported `sentence-transformers/all-MiniLM-L6-v2` model — the same encoder the [vector search guide]({{< relref "/develop/clients/go/vecsearch" >}}) uses for Go and the same one the [Python]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) example loads. Hugot drives the same ONNX Runtime kernel under the hood as Python's `onnxruntime`, so the vectors produced here are numerically identical to the Python ones to within rounding noise, and the distance bands the Python walkthrough quotes carry over to this demo without recalibration. A memory written by one demo can be recalled by the other against the same Redis instance. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* One Redis Search call per recall: [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) does the KNN + TAG pre-filter in a single round trip (a per-row [`TTL`]({{< relref "/commands/ttl" >}}) follow-up is the only other read the helper issues, just to populate the `ttl_seconds` field for the admin panel). Working memory is one [`HGETALL`]({{< relref "/commands/hgetall" >}}); the event log is one [`XADD`]({{< relref "/commands/xadd" >}}). +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.EncodeOne(ctx, text)` to turn the incoming turn into a 384-element `[]float32`. +2. `session.AppendTurn(ctx, threadID, AppendTurnParams{...})` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) (go-redis's `TxPipelined`). The session TTL refreshes on every write so an active thread stays alive. +3. `memory.Recall(ctx, RecallParams{...})` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.Remember(ctx, RememberParams{...})` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `events.Record(ctx, threadID, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/go/session_store.go)): + +```go +import ( + "context" + "fmt" + "github.com/redis/go-redis/v9" +) + +client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) +session := NewAgentSession(client, "agent:session:", 3600, 20) + +ctx := context.Background() +threadID := session.NewThreadID() +state, err := session.Start(ctx, threadID, StartParams{ + User: "alice", + Agent: "demo-agent", + Goal: "Plan next week's meetings.", +}) +if err != nil { + panic(err) +} +state, err = session.AppendTurn(ctx, threadID, AppendTurnParams{ + Role: "user", + Content: "Schedule a budget review with finance.", + User: "alice", + Agent: "demo-agent", +}) +if err != nil { + panic(err) +} +fmt.Println(state.TurnCount, len(state.RecentTurns), state.TTLSeconds) +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `Start`, `AppendTurn`, `SetGoal` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a `client.TxPipelined` block so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/go/long_term_memory.go)): + +```go +memory := NewLongTermMemory( + client, + "agentmem:idx", + "agent:mem:", + 384, + 0.20, // dedup threshold — tight at write time + 0.55, // recall threshold — looser at read time + nil, // use the default per-kind TTL map +) +embedder, err := NewLocalEmbedder(ctx, "", "") +if err != nil { + panic(err) +} +defer embedder.Close() +if err := memory.CreateIndex(ctx); err != nil { // idempotent + panic(err) +} + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +vec, err := embedder.EncodeOne(ctx, "The user prefers light mode in editors.") +if err != nil { + panic(err) +} +result, err := memory.Remember(ctx, RememberParams{ + Text: "The user prefers light mode in editors.", + Embedding: vec, + User: "alice", + Namespace: "default", + Kind: "semantic", + SourceThread: "9f3d2a4b8c61", +}) +if err != nil { + panic(err) +} +fmt.Printf("deduped=%v id=%s dist=%v\n", result.Deduped, result.ID, result.ExistingDistance) + +// Recall against a later question. +q, _ := embedder.EncodeOne(ctx, "Which theme does this user like?") +hits, _ := memory.Recall(ctx, RecallParams{ + QueryEmbedding: q, + User: "alice", + Namespace: "default", + K: 5, +}) +for _, h := range hits { + fmt.Printf("%.3f [%s] %s\n", *h.Distance, h.Kind, h.Text) +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is stored as a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (`embeddings.go`'s `FloatsToBytes` packs them in little-endian order), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `As` aliases on each path so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`Remember` resolves the entry's TTL from the memory's `Kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write by setting `TTLSeconds` on `RememberParams`, or pass a different `ttlByKind` map to `NewLongTermMemory` — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/go/event_log.go)): + +```go +events := NewAgentEventLog(client, "agent:events:", 1000) +_, _ = events.Record(ctx, threadID, "turn_appended:user", + "Schedule a budget review with finance.") +_, _ = events.Record(ctx, threadID, "memory_written", + "wrote 7c3f8a1b9e02 as semantic") + +list, _ := events.Recent(ctx, threadID, 20) +for _, e := range list { + fmt.Println(e.Action, e.Detail) +} +``` + +`Record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000` (go-redis's `XAddArgs.Approx: true`). The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.AppendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `RecentTurns` slice in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `RecentTurns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.Remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `currentThreadID` (a `sync.Mutex`-protected string field on `AgentMemoryDemo`) that `/new_thread` and `/reset` mutate. `HandleTurn` reads it under the mutex but then drops the lock immediately, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `seed_memory.go` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/go/seed_memory.go)): + +```go +memory := NewLongTermMemory(client, "agentmem:idx", "agent:mem:", + 384, 0.20, 0.55, nil) +embedder, _ := NewLocalEmbedder(ctx, "", "") +defer embedder.Close() +_ = memory.CreateIndex(ctx) +written, _ := Seed(ctx, memory, embedder, "default", "default", "seed") +fmt.Printf("seeded %d memories\n", written) +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`main.go` runs an [`http.Server`](https://pkg.go.dev/net/http#Server) on port 8090, dispatching to handlers on the standard library's request-per-goroutine pool. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a `sync.Mutex`-protected string field on `AgentMemoryDemo` that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/go + ``` + +2. Resolve the dependencies. You'll need [Go 1.26](https://go.dev/dl/) or later (the version `go.mod` declares — Hugot tracks recent toolchain releases): + + ```bash + go mod tidy + ``` + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo. The first run downloads the ONNX-exported + `sentence-transformers/all-MiniLM-L6-v2` weights into the local `./models` + directory: + + ```bash + go run . + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + Hugot's ONNX Runtime is the same C++ kernel that powers Python's + `onnxruntime`, so distances here match the Python demo to four decimal + places. `sentence-transformers/all-MiniLM-L6-v2` puts a faithful paraphrase + in the 0.15 – 0.50 cosine-distance range, a loose paraphrase or related + topic in the 0.50 – 0.80 range, and unrelated queries above 0.8 — which + is what motivates the 0.55 default recall threshold and the 0.20 default + dedup threshold. A stricter embedding model (or a domain-tuned one) would + let you tighten both; a noisier one would push them up. The right + thresholds are always a function of the model, the corpus, and how + conservative the agent needs to be about accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags: + +* `--host` / `--port` — change the HTTP bind address (default `127.0.0.1:8090`). +* `--redis-host` / `--redis-port` — point at a non-local Redis (default `localhost:6379`). +* `--mem-index-name` / `--mem-key-prefix` / `--session-key-prefix` / `--event-key-prefix` — relocate the index name and the three key prefixes (to run several demos against one Redis without colliding, for example). +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/go/embeddings.go b/content/develop/use-cases/agent-memory/go/embeddings.go new file mode 100644 index 0000000000..877975b031 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/embeddings.go @@ -0,0 +1,203 @@ +// Local text-embedding helper backed by Hugot. +// +// This is a thin wrapper around the sentence-transformers model +// `sentence-transformers/all-MiniLM-L6-v2`: a 384-dimensional encoder +// that runs in-process on CPU through Hugot's pure-Go inference +// backend (`hugot.NewGoSession`), needs no API key, and produces +// vectors numerically equivalent to the equivalent PyTorch model +// from sentence-transformers. +// +// Vectors are explicitly L2-normalised after extraction so cosine +// distance against another normalised vector reduces to `1 - dot +// product` — matching the behaviour of `sentence-transformers`' +// `normalize_embeddings=True` flag in the Python example and +// `@xenova/transformers`' `normalize: true` option in the Node.js +// example. The model is downloaded into the local `./models` cache +// on the first call; every later call runs offline. + +package main + +import ( + "context" + "encoding/binary" + "fmt" + "math" + "os" + "path/filepath" + + "github.com/knights-analytics/hugot" + "github.com/knights-analytics/hugot/pipelines" +) + +const defaultEmbedModel = "sentence-transformers/all-MiniLM-L6-v2" + +// LocalEmbedder wraps a Hugot feature-extraction pipeline. +// +// Use `NewLocalEmbedder` instead of constructing the struct directly +// because the pipeline load is asynchronous in spirit (it downloads +// the model on first call) and we want one place that owns the wait +// and the dimension probe. +type LocalEmbedder struct { + ModelName string + Dim int + session *hugot.Session + pipeline *pipelines.FeatureExtractionPipeline +} + +// NewLocalEmbedder loads the ONNX model (downloading on first run) +// and returns a ready-to-use embedder. The dimension is probed once +// from a synthetic input so we can fail loudly if a different model +// is wired up against the 384-dim Redis Search field. +func NewLocalEmbedder(ctx context.Context, modelName, modelsDir string) (*LocalEmbedder, error) { + if modelName == "" { + modelName = defaultEmbedModel + } + if modelsDir == "" { + modelsDir = "./models" + } + if err := os.MkdirAll(modelsDir, 0o755); err != nil { + return nil, fmt.Errorf("creating models dir %q: %w", modelsDir, err) + } + + session, err := hugot.NewGoSession(ctx) + if err != nil { + return nil, fmt.Errorf("starting Hugot session: %w", err) + } + + downloadOpts := hugot.NewDownloadOptions() + downloadOpts.OnnxFilePath = "onnx/model.onnx" + modelPath, err := hugot.DownloadModel(ctx, modelName, modelsDir, downloadOpts) + if err != nil { + _ = session.Destroy() + return nil, fmt.Errorf("downloading model %q: %w", modelName, err) + } + + cfg := hugot.FeatureExtractionConfig{ + ModelPath: modelPath, + Name: filepath.Base(modelPath), + } + pipe, err := hugot.NewPipeline(session, cfg) + if err != nil { + _ = session.Destroy() + return nil, fmt.Errorf("creating feature-extraction pipeline: %w", err) + } + + // Probe the output shape once so we can fail loudly if a different + // model is wired up against the 384-dim Redis Search field. + probe, err := pipe.RunPipeline(ctx, []string{"dimension probe"}) + if err != nil { + _ = session.Destroy() + return nil, fmt.Errorf("probing embedding pipeline: %w", err) + } + if len(probe.Embeddings) == 0 || len(probe.Embeddings[0]) == 0 { + _ = session.Destroy() + return nil, fmt.Errorf("embedding probe returned empty result") + } + + return &LocalEmbedder{ + ModelName: modelName, + Dim: len(probe.Embeddings[0]), + session: session, + pipeline: pipe, + }, nil +} + +// Close tears down the underlying Hugot session. Safe to call more +// than once; subsequent calls are no-ops. +func (e *LocalEmbedder) Close() error { + if e == nil || e.session == nil { + return nil + } + err := e.session.Destroy() + e.session = nil + return err +} + +// EncodeOne returns a 384-element float32 vector for the input string. +// The vector is L2-normalised so cosine distance against another +// normalised vector reduces to 1 - dot product. +// +// Hugot reuses the same backing slice across calls on a given +// pipeline, so the returned vector is copied before normalising — +// otherwise overlapping `/turn` handlers on the shared embedder +// would race on the same buffer and either corrupt the stored +// embedding or hand the recall path a vector that's already been +// overwritten by the next request. +func (e *LocalEmbedder) EncodeOne(ctx context.Context, text string) ([]float32, error) { + out, err := e.pipeline.RunPipeline(ctx, []string{text}) + if err != nil { + return nil, fmt.Errorf("encoding text: %w", err) + } + if len(out.Embeddings) == 0 { + return nil, fmt.Errorf("pipeline returned no embeddings") + } + vec := append([]float32(nil), out.Embeddings[0]...) + normalizeInPlace(vec) + return vec, nil +} + +// EncodeMany batches several strings in a single pipeline call so the +// model only pays the setup cost once. Returns one float32 slice per +// input, in the same order as the input. +func (e *LocalEmbedder) EncodeMany(ctx context.Context, texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, nil + } + out, err := e.pipeline.RunPipeline(ctx, texts) + if err != nil { + return nil, fmt.Errorf("encoding texts: %w", err) + } + // Hugot guarantees one vector per input on success, but defensive + // callers (seed loaders, batch ingest) assume that contract; + // surfacing it as an explicit check avoids an index-out-of-range + // panic later if the backend ever returns a short batch. + if len(out.Embeddings) != len(texts) { + return nil, fmt.Errorf( + "pipeline returned %d vectors for %d inputs", + len(out.Embeddings), len(texts), + ) + } + // Copy each row off the pipeline's reusable backing slice — see + // the comment on `EncodeOne` for why. The seed loader is the + // usual caller here and doesn't itself race, but the contract + // has to hold for any future caller that does. + vecs := make([][]float32, len(out.Embeddings)) + for i := range out.Embeddings { + vecs[i] = append([]float32(nil), out.Embeddings[i]...) + normalizeInPlace(vecs[i]) + } + return vecs, nil +} + +// normalizeInPlace L2-normalises a vector so it has unit length. +// A zero vector is left untouched (its cosine distance to anything +// is undefined, but at least Redis won't reject the bytes). +func normalizeInPlace(v []float32) { + var sumSq float64 + for _, x := range v { + sumSq += float64(x) * float64(x) + } + if sumSq == 0 { + return + } + inv := float32(1.0 / math.Sqrt(sumSq)) + for i := range v { + v[i] *= inv + } +} + +// FloatsToBytes packs a []float32 into the raw little-endian byte +// sequence Redis Search expects for a FLOAT32 vector field. The +// `binary.LittleEndian` here matters: Redis Search reads the bytes +// in little-endian order regardless of the host architecture, so we +// can't use `binary.NativeEndian` if the docs example ever needs to +// run on a big-endian box. Every supported Go target is little-endian +// today, so the practical difference is zero — but explicit is +// cheaper than mysterious off-by-everything vector mismatches. +func FloatsToBytes(fs []float32) []byte { + buf := make([]byte, len(fs)*4) + for i, f := range fs { + binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(f)) + } + return buf +} diff --git a/content/develop/use-cases/agent-memory/go/event_log.go b/content/develop/use-cases/agent-memory/go/event_log.go new file mode 100644 index 0000000000..22d76f5c21 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/event_log.go @@ -0,0 +1,125 @@ +// Append-only event log for an agent thread, backed by a Redis +// Stream. +// +// Each thread gets a stream at `agent:events:{ThreadID}`. Every +// action the agent takes (a user turn arriving, a memory being +// recalled, a memory being written, a tool being called) is one +// `XADD` to that stream. Replay with `XREVRANGE` for the most recent +// N events; bound retention with `XTRIM MAXLEN ~` so the log stays +// cheap regardless of how long the thread has been running. +// +// The stream is independent of the session hash and the long-term +// memory store: it answers the "what just happened" question without +// competing with either of those for indexing or memory budget. +// Consumer groups (not used in this demo) would let downstream +// workers — summarisers, consolidators, audit pipelines — replay the +// log without losing position. + +package main + +import ( + "context" + "fmt" + "strconv" + + "github.com/redis/go-redis/v9" +) + +// DefaultMaxLen is the approximate cap on stream length. `MAXLEN ~` +// lets Redis trim in whole-node units instead of exactly-N units, +// which is much cheaper at the cost of overshooting the bound by up +// to a node's worth. +const DefaultMaxLen = 1000 + +// AgentEvent is a single entry from the per-thread event stream. +type AgentEvent struct { + EventID string `json:"event_id"` + ThreadID string `json:"thread_id"` + Action string `json:"action"` + Detail string `json:"detail"` + TS float64 `json:"ts"` +} + +// AgentEventLog appends, replays, and bounds the per-thread event +// stream. +type AgentEventLog struct { + Client *redis.Client + KeyPrefix string + MaxLen int64 +} + +// NewAgentEventLog returns an event log helper with the supplied +// client. Pass zero values for any field to use the defaults +// (agent:events: / 1000). +func NewAgentEventLog(client *redis.Client, keyPrefix string, maxLen int64) *AgentEventLog { + if keyPrefix == "" { + keyPrefix = "agent:events:" + } + if maxLen <= 0 { + maxLen = DefaultMaxLen + } + return &AgentEventLog{ + Client: client, + KeyPrefix: keyPrefix, + MaxLen: maxLen, + } +} + +// StreamKey returns the Redis key for a thread id. +func (l *AgentEventLog) StreamKey(threadID string) string { + return l.KeyPrefix + threadID +} + +// Record appends one event and returns its stream id. +// +// `MAXLEN ~ N` (`Approx: true` on `XAddArgs`) keeps the stream +// bounded with near-zero overhead; an exact bound forces a scan and +// is rarely worth the cost. +func (l *AgentEventLog) Record(ctx context.Context, threadID, action, detail string) (string, error) { + id, err := l.Client.XAdd(ctx, &redis.XAddArgs{ + Stream: l.StreamKey(threadID), + MaxLen: l.MaxLen, + Approx: true, + Values: map[string]any{ + "action": action, + "detail": detail, + "ts": strconv.FormatFloat(unixSecs(), 'f', -1, 64), + }, + }).Result() + if err != nil { + return "", fmt.Errorf("XADD: %w", err) + } + return id, nil +} + +// Recent returns the most recent events, newest first. +func (l *AgentEventLog) Recent(ctx context.Context, threadID string, count int64) ([]AgentEvent, error) { + rows, err := l.Client.XRevRangeN(ctx, l.StreamKey(threadID), "+", "-", count).Result() + if err != nil { + return nil, fmt.Errorf("XREVRANGE: %w", err) + } + out := make([]AgentEvent, 0, len(rows)) + for _, row := range rows { + action, _ := row.Values["action"].(string) + detail, _ := row.Values["detail"].(string) + tsStr, _ := row.Values["ts"].(string) + ts, _ := strconv.ParseFloat(tsStr, 64) + out = append(out, AgentEvent{ + EventID: row.ID, + ThreadID: threadID, + Action: action, + Detail: detail, + TS: ts, + }) + } + return out, nil +} + +// Clear drops the entire stream for a thread. +func (l *AgentEventLog) Clear(ctx context.Context, threadID string) (bool, error) { + n, err := l.Client.Del(ctx, l.StreamKey(threadID)).Result() + if err != nil { + return false, fmt.Errorf("DEL: %w", err) + } + return n > 0, nil +} diff --git a/content/develop/use-cases/agent-memory/go/go.mod b/content/develop/use-cases/agent-memory/go/go.mod new file mode 100644 index 0000000000..cabe73f32b --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/go.mod @@ -0,0 +1,38 @@ +module github.com/redis/docs/content/develop/use-cases/agent-memory/go + +go 1.26.0 + +require ( + github.com/knights-analytics/hugot v0.7.3 + github.com/redis/go-redis/v9 v9.19.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/daulet/tokenizers v1.27.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-errors/errors v1.5.1 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/gofrs/flock v0.13.0 // indirect + github.com/gomlx/exceptions v0.0.3 // indirect + github.com/gomlx/go-huggingface v0.3.5 // indirect + github.com/gomlx/go-xla v0.2.2 // indirect + github.com/gomlx/gomlx v0.27.3 // indirect + github.com/gomlx/onnx-gomlx v0.4.2 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/knights-analytics/ortgenai v0.3.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/viant/afs v1.30.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/yalue/onnxruntime_go v1.30.1 // indirect + go.uber.org/atomic v1.11.0 // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect + golang.org/x/image v0.40.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/term v0.43.0 // indirect + golang.org/x/text v0.37.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect + k8s.io/klog/v2 v2.140.0 // indirect +) diff --git a/content/develop/use-cases/agent-memory/go/go.sum b/content/develop/use-cases/agent-memory/go/go.sum new file mode 100644 index 0000000000..cad2867528 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/go.sum @@ -0,0 +1,130 @@ +codeberg.org/go-fonts/liberation v0.5.0 h1:SsKoMO1v1OZmzkG2DY+7ZkCL9U+rrWI09niOLfQ5Bo0= +codeberg.org/go-fonts/liberation v0.5.0/go.mod h1:zS/2e1354/mJ4pGzIIaEtm/59VFCFnYC7YV6YdGl5GU= +codeberg.org/go-latex/latex v0.1.0 h1:hoGO86rIbWVyjtlDLzCqZPjNykpWQ9YuTZqAzPcfL3c= +codeberg.org/go-latex/latex v0.1.0/go.mod h1:LA0q/AyWIYrqVd+A9Upkgsb+IqPcmSTKc9Dny04MHMw= +codeberg.org/go-pdf/fpdf v0.10.0 h1:u+w669foDDx5Ds43mpiiayp40Ov6sZalgcPMDBcZRd4= +codeberg.org/go-pdf/fpdf v0.10.0/go.mod h1:Y0DGRAdZ0OmnZPvjbMp/1bYxmIPxm0ws4tfoPOc4LjU= +git.sr.ht/~sbinet/gg v0.6.0 h1:RIzgkizAk+9r7uPzf/VfbJHBMKUr0F5hRFxTUGMnt38= +git.sr.ht/~sbinet/gg v0.6.0/go.mod h1:uucygbfC9wVPQIfrmwM2et0imr8L7KQWywX0xpFMm94= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY= +github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex9t5KX76i20Q= +github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= +github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ= +github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI= +github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8= +github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= +github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= +github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/daulet/tokenizers v1.27.0 h1:MmFYAEDFz69s/nNQfHg59DWqHz3v94m99kEZ/JbL+s4= +github.com/daulet/tokenizers v1.27.0/go.mod h1:YjFY1o1HGMyWkQgbXJDghhvke/yFDp2vGdIO2hYs4MQ= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= +github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/gomlx/exceptions v0.0.3 h1:HKnTgEjj4jlmhr8zVFkTP9qmV1ey7ypYYosQ8GzXWuM= +github.com/gomlx/exceptions v0.0.3/go.mod h1:uHL0TQwJ0xaV2/snJOJV6hSE4yRmhhfymuYgNredGxU= +github.com/gomlx/go-huggingface v0.3.5 h1:eZz1huOvfr0TW30e11TkGAUZY4Jj5Oh/g0Thz4cvu0I= +github.com/gomlx/go-huggingface v0.3.5/go.mod h1:r/Z6JQTPm2nd6zHYKp6ig8ofQZK16+Rj9iqZpWq8OTQ= +github.com/gomlx/go-xla v0.2.2 h1:2YMzXAcmK8BvqFjRnUHHtE2QwKDEts2tRglcFcKhZj8= +github.com/gomlx/go-xla v0.2.2/go.mod h1:T2CsL/E90te3k4qpuzlXv2uQU2FmLMLfUsRlAGqKSuI= +github.com/gomlx/gomlx v0.27.3 h1:4cCcVi2m3lvMzDyZtepIl3+6cBGMTXhrYvQtOdtU5Z4= +github.com/gomlx/gomlx v0.27.3/go.mod h1:gqqTny0q1kcxml72T313SZy5U9pfX9c54NmzcYtzg5k= +github.com/gomlx/onnx-gomlx v0.4.2 h1:nBDbjzZOVMkCudk0AKMREHMdm54xNcp34dAte9aNwqQ= +github.com/gomlx/onnx-gomlx v0.4.2/go.mod h1:jh/oy07gw7aloPO3R8A2tHIVF7sVVXE2erp5IQCqlPY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/janpfeifer/go-benchmarks v0.1.1 h1:gLLy07/JrOKSnMWeUxSnjTdhkglgmrNR2IBDnR4kRqw= +github.com/janpfeifer/go-benchmarks v0.1.1/go.mod h1:5AagXCOUzevvmYFQalcgoa4oWPyH1IkZNckolGWfiSM= +github.com/janpfeifer/must v0.2.0 h1:yWy1CE5gtk1i2ICBvqAcMMXrCMqil9CJPkc7x81fRdQ= +github.com/janpfeifer/must v0.2.0/go.mod h1:S6c5Yg/YSMR43cJw4zhIq7HFMci90a7kPY9XA4c8UIs= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/knights-analytics/hugot v0.7.3 h1:39UqU52s4nAmNIE4JG5ViASCvd8dhue7XGtt5RhK3T4= +github.com/knights-analytics/hugot v0.7.3/go.mod h1:86tRz/GzyoNFHuUUzgiYnALQNZU8Vzd5F0pApYizwrs= +github.com/knights-analytics/ortgenai v0.3.1 h1:0Awe43Zu+giDxzlpoNvx9ekbez/zxc8XMzKU++sOUB8= +github.com/knights-analytics/ortgenai v0.3.1/go.mod h1:lSbQsRP5wY5NS+4W5CUGhdxjTzERQkR7WprAFxrBSt4= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w= +github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/schollz/progressbar/v3 v3.19.0 h1:Ea18xuIRQXLAUidVDox3AbwfUhD0/1IvohyTutOIFoc= +github.com/schollz/progressbar/v3 v3.19.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= +github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs= +github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/viant/afs v1.30.0 h1:dbgVVSCPwGHUgpgkWJ5gdjKBqssT7OV7Z2M81CjwZEY= +github.com/viant/afs v1.30.0/go.mod h1:rScbFd9LJPGTM8HOI8Kjwee0AZ+MZMupAvFpPg+Qdj4= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yalue/onnxruntime_go v1.30.1 h1:NaEng5lWbsHZ/8X1dtaw1mIj7eV1ozyjbFo//g0ktl4= +github.com/yalue/onnxruntime_go v1.30.1/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw= +golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= +golang.org/x/image v0.40.0 h1:Tw4GyDXMo+daZN1znreBRC3VayR1aLFUyUEOLUdW1a8= +golang.org/x/image v0.40.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +gonum.org/v1/plot v0.15.2 h1:Tlfh/jBk2tqjLZ4/P8ZIwGrLEWQSPDLRm/SNWKNXiGI= +gonum.org/v1/plot v0.15.2/go.mod h1:DX+x+DWso3LTha+AdkJEv5Txvi+Tql3KAGkehP0/Ubg= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= +k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= diff --git a/content/develop/use-cases/agent-memory/go/index.html b/content/develop/use-cases/agent-memory/go/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/go/long_term_memory.go b/content/develop/use-cases/agent-memory/go/long_term_memory.go new file mode 100644 index 0000000000..a52f07e814 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/long_term_memory.go @@ -0,0 +1,616 @@ +// Long-term memory store for an agent, backed by Redis JSON and +// Search. +// +// Each memory lives as one JSON document at `agent:mem:`. The +// document holds the memory text, its embedding vector, and a small +// metadata block — user, namespace, kind, source thread, timestamps +// — that lets the recall query scope results without falling back to +// application-side filtering. +// +// A single Redis Search index covers the embedding plus every +// metadata field, so one `FT.SEARCH` call performs approximate- +// nearest-neighbour over the in-scope subset and returns the top-k +// memories ranked by cosine distance. The same KNN check runs at +// *write* time to deduplicate near-identical memories before they +// enter the store, which keeps the index from filling with +// paraphrases of the same fact as the agent reasons over similar +// topics across sessions. +// +// Memories carry one of two kinds: +// +// - `episodic` — "what happened" snapshots from a specific thread, +// written with a medium TTL so old session detail decays +// naturally. +// - `semantic` — distilled facts and preferences the agent should +// carry forward indefinitely. Written with no TTL by default. +// +// The split is enforced as a TAG on the index, so the recall query +// can ask for one kind or both with a filter — no separate +// keyspaces. + +package main + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +// VectorDim is the embedding dimension produced by the +// `sentence-transformers/all-MiniLM-L6-v2` model the demo uses. +const VectorDim = 384 + +// DefaultDedupThreshold is how close (cosine distance) a candidate +// must be to an existing memory to count as a duplicate at write +// time. Smaller = stricter. 0.20 is calibrated to MiniLM, where a +// paraphrase of an existing memory lands in the 0.10 – 0.20 range +// and a distinct memory lands above 0.50. +const DefaultDedupThreshold = 0.20 + +// DefaultRecallThreshold is how close (cosine distance) a candidate +// must be to count as a relevant recall result. Larger than the +// dedup threshold so the agent gets a wider net at read time than at +// write time. +const DefaultRecallThreshold = 0.55 + +// defaultTTLByKind holds the per-kind TTLs in seconds. A nil pointer +// means "no TTL" — the memory persists until explicitly deleted or +// evicted under memory pressure. +var defaultTTLByKind = map[string]*int{ + "episodic": intPtr(7 * 24 * 3600), + "semantic": nil, +} + +func intPtr(v int) *int { return &v } + +// MemoryRecord is the public shape of a memory document. +// +// `Distance` is set only when the record comes back from a KNN +// query; `TTLSeconds` is nil for memories with no TTL (e.g. +// `kind=semantic` under the default tier map). +type MemoryRecord struct { + ID string `json:"id"` + User string `json:"user"` + Namespace string `json:"namespace"` + Kind string `json:"kind"` + SourceThread string `json:"source_thread"` + Text string `json:"text"` + CreatedTS float64 `json:"created_ts"` + HitCount int `json:"hit_count"` + Distance *float64 `json:"distance,omitempty"` + TTLSeconds *int `json:"ttl_seconds"` +} + +// WriteResult is what `Remember` returns. +// +// `Deduped` is true when the write skipped because a similar memory +// already existed; `ID` is then the existing memory's id. +// `ExistingDistance` is the cosine distance to that nearest memory +// regardless of which branch was taken — useful for tracing. +type WriteResult struct { + ID string `json:"id"` + Deduped bool `json:"deduped"` + ExistingDistance *float64 `json:"existing_distance"` +} + +// IndexSnapshot is a small subset of FT.INFO useful for the demo UI. +type IndexSnapshot struct { + NumDocs int `json:"num_docs"` + IndexingFailures int `json:"indexing_failures"` +} + +// LongTermMemory owns the JSON documents, the vector index, the +// recall query, and the write-time deduplication. +type LongTermMemory struct { + Client *redis.Client + IndexName string + KeyPrefix string + VectorDim int + DedupThreshold float64 + RecallThreshold float64 + TTLByKind map[string]*int +} + +// NewLongTermMemory returns a memory helper with the supplied client. +// Pass zero values for any field to use the defaults. +func NewLongTermMemory( + client *redis.Client, + indexName, keyPrefix string, + vectorDim int, + dedupThreshold, recallThreshold float64, + ttlByKind map[string]*int, +) *LongTermMemory { + if indexName == "" { + indexName = "agentmem:idx" + } + if keyPrefix == "" { + keyPrefix = "agent:mem:" + } + if vectorDim <= 0 { + vectorDim = VectorDim + } + // The thresholds are honoured as-is. Zero is a legitimate value + // ("exact matches only" for dedup, "nothing recalls" for recall) + // and negative numbers are clamped by the HTTP boundary anyway. + // Silently rewriting `0` to a default would make + // `--dedup-threshold 0` and `--recall-threshold 0` uncallable. + if dedupThreshold < 0 { + dedupThreshold = DefaultDedupThreshold + } + if recallThreshold < 0 { + recallThreshold = DefaultRecallThreshold + } + if ttlByKind == nil { + ttlByKind = defaultTTLByKind + } + return &LongTermMemory{ + Client: client, + IndexName: indexName, + KeyPrefix: keyPrefix, + VectorDim: vectorDim, + DedupThreshold: dedupThreshold, + RecallThreshold: recallThreshold, + TTLByKind: ttlByKind, + } +} + +// MemoryKey returns the Redis key for a memory id. +func (m *LongTermMemory) MemoryKey(memoryID string) string { + return m.KeyPrefix + memoryID +} + +// CreateIndex creates the Redis Search index if it doesn't already +// exist. The index is declared on the JSON document type with alias +// names on each path; the same `FT.SEARCH` filter clause works here +// as on a HASH-backed index, and the field paths (`$.user`, +// `$.embedding`, ...) only show up in `FT.CREATE`. +func (m *LongTermMemory) CreateIndex(ctx context.Context) error { + _, err := m.Client.FTCreate(ctx, + m.IndexName, + &redis.FTCreateOptions{ + OnJSON: true, + Prefix: []any{m.KeyPrefix}, + }, + &redis.FieldSchema{FieldName: "$.text", As: "text", FieldType: redis.SearchFieldTypeText}, + &redis.FieldSchema{FieldName: "$.user", As: "user", FieldType: redis.SearchFieldTypeTag}, + &redis.FieldSchema{FieldName: "$.namespace", As: "namespace", FieldType: redis.SearchFieldTypeTag}, + &redis.FieldSchema{FieldName: "$.kind", As: "kind", FieldType: redis.SearchFieldTypeTag}, + &redis.FieldSchema{FieldName: "$.source_thread", As: "source_thread", FieldType: redis.SearchFieldTypeTag}, + &redis.FieldSchema{FieldName: "$.created_ts", As: "created_ts", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}, + &redis.FieldSchema{FieldName: "$.hit_count", As: "hit_count", FieldType: redis.SearchFieldTypeNumeric, Sortable: true}, + &redis.FieldSchema{ + FieldName: "$.embedding", + As: "embedding", + FieldType: redis.SearchFieldTypeVector, + VectorArgs: &redis.FTVectorArgs{ + HNSWOptions: &redis.FTHNSWOptions{ + Type: "FLOAT32", + Dim: m.VectorDim, + DistanceMetric: "COSINE", + }, + }, + }, + ).Result() + if err != nil && !strings.Contains(err.Error(), "Index already exists") { + return err + } + return nil +} + +// DropIndex drops the Redis Search index. If `deleteDocuments` is +// true the JSON memory documents are deleted alongside the index. +func (m *LongTermMemory) DropIndex(ctx context.Context, deleteDocuments bool) error { + _, err := m.Client.FTDropIndexWithArgs(ctx, + m.IndexName, + &redis.FTDropIndexOptions{DeleteDocs: deleteDocuments}, + ).Result() + if err == nil { + return nil + } + msg := strings.ToLower(err.Error()) + if strings.Contains(msg, "no such index") || strings.Contains(msg, "unknown index name") { + return nil + } + return err +} + +// RememberParams collects the fields of a new memory. +type RememberParams struct { + Text string + Embedding []float32 + User string // default "default" + Namespace string // default "default" + Kind string // "episodic" or "semantic" + SourceThread string + // TTLSeconds: nil => resolve from kind; non-nil pointer to 0 => no TTL. + TTLSeconds *int +} + +// Remember writes a new memory, deduplicating against existing +// entries. +// +// Runs one in-scope KNN(1) against the index first. If the nearest +// existing memory is within `DedupThreshold`, the new memory is +// skipped (its content is already represented) and the existing +// memory's `hit_count` is bumped via `JSON.NUMINCRBY`. Otherwise a +// fresh JSON document is written under a new id with a TTL derived +// from the memory's `kind`. +// +// The KNN-then-write sequence is not atomic; two workers that +// remember the same fact at the same time can both miss each other's +// in-flight write and insert duplicate memories. See the +// walkthrough's "Concurrency caveats" section for the production fix. +func (m *LongTermMemory) Remember(ctx context.Context, p RememberParams) (WriteResult, error) { + if len(p.Embedding) != m.VectorDim { + return WriteResult{}, fmt.Errorf( + "embedding length is %d; index expects %d", + len(p.Embedding), m.VectorDim, + ) + } + user := p.User + if user == "" { + user = "default" + } + ns := p.Namespace + if ns == "" { + ns = "default" + } + kind := p.Kind + if kind == "" { + kind = "episodic" + } + + nearest, err := m.nearest(ctx, p.Embedding, user, ns, kind, 1) + if err != nil { + return WriteResult{}, err + } + var existingDistance *float64 + if len(nearest) > 0 && nearest[0].Distance != nil { + existingDistance = nearest[0].Distance + if *existingDistance <= m.DedupThreshold { + m.bumpHitCount(ctx, nearest[0].ID) + return WriteResult{ + ID: nearest[0].ID, + Deduped: true, + ExistingDistance: existingDistance, + }, nil + } + } + + id, err := newMemoryID() + if err != nil { + return WriteResult{}, fmt.Errorf("generating memory id: %w", err) + } + key := m.MemoryKey(id) + now := unixSecs() + doc := map[string]any{ + "id": id, + "user": user, + "namespace": ns, + "kind": kind, + "source_thread": p.SourceThread, + "text": p.Text, + "embedding": p.Embedding, + "created_ts": now, + "hit_count": 0, + } + + ttl := p.TTLSeconds + if ttl == nil { + ttl = m.TTLByKind[kind] + } + + // MULTI/EXEC so JSON.SET and EXPIRE either both apply or neither + // does. A connection drop (or context cancellation) between the + // two writes would otherwise leave the memory without an expiry + // — the index entry would still be there, but the document would + // outlive its intended `kind`-derived TTL. + if _, err := m.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.JSONSet(ctx, key, "$", doc) + // `nil` is the "no TTL" sentinel; an explicit zero falls + // through to `EXPIRE key 0`, which Redis treats as immediate + // expiry — same contract as the Python, Node, Rust, .NET, + // Java, PHP, and Ruby ports. + if ttl != nil { + pipe.Expire(ctx, key, time.Duration(*ttl)*time.Second) + } + return nil + }); err != nil { + return WriteResult{}, fmt.Errorf("remember MULTI/EXEC: %w", err) + } + return WriteResult{ID: id, Deduped: false, ExistingDistance: existingDistance}, nil +} + +// RecallParams collects the optional filters for a recall query. +type RecallParams struct { + QueryEmbedding []float32 + User string + Namespace string + Kind string // empty => any kind + K int // 0 => 5 + DistanceThreshold *float64 +} + +// Recall returns the top-k in-scope memories ranked by similarity. +// Memories beyond `DistanceThreshold` (or the instance default) are +// dropped — the index always returns *something* for KNN, so a +// recall result on an unrelated query would otherwise be a +// confidently-wrong false positive. +func (m *LongTermMemory) Recall(ctx context.Context, p RecallParams) ([]MemoryRecord, error) { + k := p.K + if k <= 0 { + k = 5 + } + threshold := m.RecallThreshold + if p.DistanceThreshold != nil { + threshold = *p.DistanceThreshold + } + // Match `Remember` defaulting so a Recall called with empty + // user/namespace stays scoped to the same `"default"` keys writes + // land under, not the unscoped `(*)` shape `buildMemoryFilter` + // would otherwise produce. + user := p.User + if user == "" { + user = "default" + } + ns := p.Namespace + if ns == "" { + ns = "default" + } + candidates, err := m.nearest(ctx, p.QueryEmbedding, user, ns, p.Kind, k) + if err != nil { + return nil, err + } + out := make([]MemoryRecord, 0, len(candidates)) + for _, c := range candidates { + if c.Distance != nil && *c.Distance <= threshold { + out = append(out, c) + } + } + return out, nil +} + +func (m *LongTermMemory) nearest( + ctx context.Context, + embedding []float32, + user, namespace, kind string, + k int, +) ([]MemoryRecord, error) { + if len(embedding) != m.VectorDim { + return nil, fmt.Errorf( + "embedding length is %d; index expects %d", + len(embedding), m.VectorDim, + ) + } + filterClause := buildMemoryFilter(user, namespace, kind) + queryStr := fmt.Sprintf("%s=>[KNN %d @embedding $vec AS distance]", filterClause, k) + + res, err := m.Client.FTSearchWithArgs(ctx, + m.IndexName, + queryStr, + &redis.FTSearchOptions{ + DialectVersion: 2, + Params: map[string]any{"vec": FloatsToBytes(embedding)}, + SortBy: []redis.FTSearchSortBy{ + {FieldName: "distance", Asc: true}, + }, + Return: []redis.FTSearchReturn{ + {FieldName: "user"}, + {FieldName: "namespace"}, + {FieldName: "kind"}, + {FieldName: "source_thread"}, + {FieldName: "text"}, + {FieldName: "created_ts"}, + {FieldName: "hit_count"}, + {FieldName: "distance"}, + }, + LimitOffset: 0, + Limit: k, + }, + ).Result() + if err != nil { + return nil, fmt.Errorf("FT.SEARCH: %w", err) + } + out := make([]MemoryRecord, 0, len(res.Docs)) + for _, doc := range res.Docs { + // `doc.ID` is the full Redis key (e.g. `agent:mem:abc123`). + // Strip the prefix so the returned record exposes only the + // opaque id the UI and `DeleteMemory` work with. + memoryID := strings.TrimPrefix(doc.ID, m.KeyPrefix) + ttl, _ := m.Client.TTL(ctx, m.MemoryKey(memoryID)).Result() + var ttlSeconds *int + if ttl > 0 { + s := int(ttl / time.Second) + ttlSeconds = &s + } + distanceVal, _ := strconv.ParseFloat(doc.Fields["distance"], 64) + distance := distanceVal + hitCount, _ := strconv.Atoi(doc.Fields["hit_count"]) + createdTS, _ := strconv.ParseFloat(doc.Fields["created_ts"], 64) + out = append(out, MemoryRecord{ + ID: memoryID, + User: doc.Fields["user"], + Namespace: doc.Fields["namespace"], + Kind: doc.Fields["kind"], + SourceThread: doc.Fields["source_thread"], + Text: doc.Fields["text"], + CreatedTS: createdTS, + HitCount: hitCount, + Distance: &distance, + TTLSeconds: ttlSeconds, + }) + } + return out, nil +} + +func (m *LongTermMemory) bumpHitCount(ctx context.Context, memoryID string) { + // Fire-and-forget — the doc may have expired between recall and + // bump (search index lags TTL by its periodic scan). Discarding + // the error keeps the demo from blowing up on that race; we just + // lose the hit count update. + _, _ = m.Client.JSONNumIncrBy(ctx, m.MemoryKey(memoryID), "$.hit_count", 1).Result() +} + +// IndexInfo returns a small subset of FT.INFO. Failures (e.g. an +// index that hasn't been created yet) return zeroed counters rather +// than surface as an error, since the demo UI just renders "0 +// entries" in that case. +func (m *LongTermMemory) IndexInfo(ctx context.Context) IndexSnapshot { + info, err := m.Client.FTInfo(ctx, m.IndexName).Result() + if err != nil { + return IndexSnapshot{} + } + return IndexSnapshot{ + NumDocs: info.NumDocs, + IndexingFailures: info.HashIndexingFailures, + } +} + +// ListMemories returns memories matching the filters, newest first. +func (m *LongTermMemory) ListMemories( + ctx context.Context, + user, namespace, kind string, + limit int, +) ([]MemoryRecord, error) { + if limit <= 0 { + limit = 100 + } + // Match `Remember` defaulting so a list called with an empty + // user/namespace stays scoped to the same `"default"` keys + // `Remember` writes under, not an unscoped `(*)` query. + if user == "" { + user = "default" + } + if namespace == "" { + namespace = "default" + } + filterClause := buildMemoryFilter(user, namespace, kind) + res, err := m.Client.FTSearchWithArgs(ctx, + m.IndexName, + filterClause, + &redis.FTSearchOptions{ + DialectVersion: 2, + Return: []redis.FTSearchReturn{ + {FieldName: "user"}, + {FieldName: "namespace"}, + {FieldName: "kind"}, + {FieldName: "source_thread"}, + {FieldName: "text"}, + {FieldName: "created_ts"}, + {FieldName: "hit_count"}, + }, + SortBy: []redis.FTSearchSortBy{ + {FieldName: "created_ts", Desc: true}, + }, + LimitOffset: 0, + Limit: limit, + }, + ).Result() + if err != nil { + return nil, fmt.Errorf("FT.SEARCH: %w", err) + } + out := make([]MemoryRecord, 0, len(res.Docs)) + for _, doc := range res.Docs { + memoryID := strings.TrimPrefix(doc.ID, m.KeyPrefix) + ttl, _ := m.Client.TTL(ctx, m.MemoryKey(memoryID)).Result() + var ttlSeconds *int + if ttl > 0 { + s := int(ttl / time.Second) + ttlSeconds = &s + } + hitCount, _ := strconv.Atoi(doc.Fields["hit_count"]) + createdTS, _ := strconv.ParseFloat(doc.Fields["created_ts"], 64) + out = append(out, MemoryRecord{ + ID: memoryID, + User: doc.Fields["user"], + Namespace: doc.Fields["namespace"], + Kind: doc.Fields["kind"], + SourceThread: doc.Fields["source_thread"], + Text: doc.Fields["text"], + CreatedTS: createdTS, + HitCount: hitCount, + Distance: nil, + TTLSeconds: ttlSeconds, + }) + } + // Belt-and-braces sort in case Redis returns an unsorted top-N. + sort.SliceStable(out, func(i, j int) bool { + return out[i].CreatedTS > out[j].CreatedTS + }) + return out, nil +} + +// DeleteMemory drops a single memory. Returns true if the key +// existed. +func (m *LongTermMemory) DeleteMemory(ctx context.Context, memoryID string) (bool, error) { + n, err := m.Client.Del(ctx, m.MemoryKey(memoryID)).Result() + if err != nil { + return false, fmt.Errorf("DEL: %w", err) + } + return n > 0, nil +} + +// Clear drops the index and every memory document, then recreates +// the index. Returns the count of documents that were removed. +func (m *LongTermMemory) Clear(ctx context.Context) (int, error) { + before := m.IndexInfo(ctx).NumDocs + if err := m.DropIndex(ctx, true); err != nil { + return 0, err + } + if err := m.CreateIndex(ctx); err != nil { + return 0, err + } + return before, nil +} + +// ----- Filter clause helpers ----------------------------------------- + +func buildMemoryFilter(user, namespace, kind string) string { + var clauses []string + if user != "" { + clauses = append(clauses, "@user:{"+escapeTagValue(user)+"}") + } + if namespace != "" { + clauses = append(clauses, "@namespace:{"+escapeTagValue(namespace)+"}") + } + if kind != "" { + clauses = append(clauses, "@kind:{"+escapeTagValue(kind)+"}") + } + if len(clauses) == 0 { + return "(*)" + } + return "(" + strings.Join(clauses, " ") + ")" +} + +// tagSpecialMem is the set of characters Redis Search treats as +// syntax inside a TAG value; any of them in a user-supplied filter +// must be backslash-escaped or the surrounding `{...}` block won't +// parse correctly. +var tagSpecialMem = map[rune]struct{}{ + '\\': {}, ',': {}, '.': {}, '<': {}, '>': {}, '{': {}, '}': {}, + '[': {}, ']': {}, '"': {}, '\'': {}, ':': {}, ';': {}, '!': {}, + '@': {}, '#': {}, '$': {}, '%': {}, '^': {}, '&': {}, '*': {}, + '(': {}, ')': {}, '-': {}, '+': {}, '=': {}, '~': {}, '|': {}, + ' ': {}, +} + +func escapeTagValue(v string) string { + var b strings.Builder + b.Grow(len(v)) + for _, r := range v { + if _, ok := tagSpecialMem[r]; ok { + b.WriteByte('\\') + } + b.WriteRune(r) + } + return b.String() +} + +func newMemoryID() (string, error) { + return newThreadID() +} diff --git a/content/develop/use-cases/agent-memory/go/main.go b/content/develop/use-cases/agent-memory/go/main.go new file mode 100644 index 0000000000..792920c07f --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/main.go @@ -0,0 +1,666 @@ +// Redis agent-memory demo server (Go). +// +// Run this file and visit http://localhost:8090 to drive a small +// agent-memory demo backed by Redis Hashes, JSON, Search, and +// Streams. The UI lets you type a turn, watch working memory update, +// see semantically similar long-term memories recalled, watch the +// write-time deduplication skip near-duplicates, and inspect the +// per-thread event log. +// +// The server holds a single `LocalEmbedder`, one `AgentSession`, one +// `LongTermMemory`, and one `AgentEventLog` for the lifetime of the +// process. The first run downloads the embedding model into the +// local `./models` directory; everything after is local. + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "math" + "net/http" + "os" + "os/signal" + "path/filepath" + "runtime/debug" + "strconv" + "strings" + "sync" + "syscall" + "time" + + "github.com/redis/go-redis/v9" +) + +// stackLabel is the badge the same HTML uses to identify which +// language demo the user is looking at. +const stackLabel = "go-redis + Hugot + net/http" + +// maxBodyBytes caps POST bodies so a runaway client (or a `curl +// --data-binary @big-file` by mistake) can't accumulate unbounded +// memory before the handler runs. The demo's largest legitimate body +// is a few hundred bytes of form-encoded query fields. +const maxBodyBytes = 1 * 1024 * 1024 + +// AgentMemoryDemo owns the three Redis-backed helpers and the +// embedder for the lifetime of the process. +// +// `SeedAll` / `NewThread` / `HandleTurn` all touch `currentThreadID` +// — `currentThreadID` is wrapped in a `sync.Mutex`, but the lock is +// released after each rotation or read, so a turn racing with +// `/new_thread` or `/reset` can capture the old id and apply to the +// previous thread. The demo is single-user in practice, so the race +// never triggers; a multi-user agent would carry the thread id on +// each request instead of holding it as shared server state. See the +// walkthrough's "Concurrency caveats" section. +type AgentMemoryDemo struct { + Session *AgentSession + Memory *LongTermMemory + Events *AgentEventLog + Embedder *LocalEmbedder + DefaultUser string + DefaultNamespace string + + mu sync.Mutex + currentThreadID string +} + +// NewAgentMemoryDemo wires the helpers together and seeds an initial +// thread id. +func NewAgentMemoryDemo( + session *AgentSession, + memory *LongTermMemory, + events *AgentEventLog, + embedder *LocalEmbedder, +) *AgentMemoryDemo { + return &AgentMemoryDemo{ + Session: session, + Memory: memory, + Events: events, + Embedder: embedder, + DefaultUser: "default", + DefaultNamespace: "default", + currentThreadID: session.NewThreadID(), + } +} + +// CurrentThreadID returns the demo's active thread id. +func (d *AgentMemoryDemo) CurrentThreadID() string { + d.mu.Lock() + defer d.mu.Unlock() + return d.currentThreadID +} + +// SeedAll drops every long-term memory, every working-memory hash +// for the active thread, and the active event stream, then re-seeds +// the canonical memories and starts a fresh thread. +func (d *AgentMemoryDemo) SeedAll(ctx context.Context, user, namespace string) (int, error) { + if _, err := d.Memory.Clear(ctx); err != nil { + return 0, err + } + threadID := d.CurrentThreadID() + if _, err := d.Session.Delete(ctx, threadID); err != nil { + return 0, err + } + if _, err := d.Events.Clear(ctx, threadID); err != nil { + return 0, err + } + written, err := Seed(ctx, d.Memory, d.Embedder, user, namespace, "seed") + if err != nil { + return written, err + } + d.mu.Lock() + d.currentThreadID = d.Session.NewThreadID() + d.mu.Unlock() + return written, nil +} + +// NewThread starts a fresh thread. Long-term memory is unaffected. +func (d *AgentMemoryDemo) NewThread(ctx context.Context, user, namespace string) (string, error) { + oldID := d.CurrentThreadID() + if _, err := d.Events.Clear(ctx, oldID); err != nil { + return "", err + } + newID := d.Session.NewThreadID() + if _, err := d.Session.Start(ctx, newID, StartParams{ + User: user, + Agent: "demo-agent", + Goal: "", + }); err != nil { + return "", err + } + if _, err := d.Events.Record(ctx, newID, "thread_started", + fmt.Sprintf("user=%s namespace=%s", user, namespace)); err != nil { + return "", err + } + d.mu.Lock() + d.currentThreadID = newID + d.mu.Unlock() + return newID, nil +} + +// TurnParams collects what /turn accepts. +type TurnParams struct { + Text string + User string + Namespace string + Kind string // "episodic", "semantic", or "skip" + Role string + Threshold float64 + Action string // "turn" or "goal" +} + +// HandleTurn runs one pass through the agent loop: append, recall, +// remember, log. +// +// The order matters. We embed once and reuse the vector for both the +// recall and (if asked) the remember step — no point encoding the +// same text twice. Recall runs *before* the remember write so the +// agent doesn't see its own just-written turn as a recalled memory. +func (d *AgentMemoryDemo) HandleTurn(ctx context.Context, p TurnParams) (map[string]any, error) { + threadID := d.CurrentThreadID() + + t0 := time.Now() + vec, err := d.Embedder.EncodeOne(ctx, p.Text) + if err != nil { + return nil, fmt.Errorf("embed: %w", err) + } + embedMs := msSince(t0) + + // `SetGoal` only touches the goal field so existing turns aren't + // wiped; `AppendTurn` carries the request `user` through to the + // auto-create path so a first turn for a new thread doesn't land + // under the default user. + var sessionAction string + if p.Action == "goal" { + if _, err := d.Session.SetGoal(ctx, threadID, p.Text, StartParams{ + User: p.User, + Agent: "demo-agent", + }); err != nil { + return nil, fmt.Errorf("set goal: %w", err) + } + sessionAction = "goal_set" + } else { + if _, err := d.Session.AppendTurn(ctx, threadID, AppendTurnParams{ + Role: p.Role, + Content: p.Text, + User: p.User, + Agent: "demo-agent", + }); err != nil { + return nil, fmt.Errorf("append turn: %w", err) + } + sessionAction = "turn_appended:" + p.Role + } + + t1 := time.Now() + threshold := p.Threshold + recalled, err := d.Memory.Recall(ctx, RecallParams{ + QueryEmbedding: vec, + User: p.User, + Namespace: p.Namespace, + K: 5, + DistanceThreshold: &threshold, + }) + if err != nil { + return nil, fmt.Errorf("recall: %w", err) + } + recallMs := msSince(t1) + + writeSkipped := p.Kind == "skip" || p.Action == "goal" + var writeResult *WriteResult + var writeMs float64 + if !writeSkipped { + t2 := time.Now() + r, err := d.Memory.Remember(ctx, RememberParams{ + Text: p.Text, + Embedding: vec, + User: p.User, + Namespace: p.Namespace, + Kind: p.Kind, + SourceThread: threadID, + }) + if err != nil { + return nil, fmt.Errorf("remember: %w", err) + } + writeResult = &r + writeMs = msSince(t2) + } + + var detail string + if writeResult != nil { + if writeResult.Deduped { + detail = "deduped onto " + writeResult.ID + } else { + detail = "wrote " + writeResult.ID + " as " + p.Kind + } + } + if _, err := d.Events.Record(ctx, threadID, sessionAction, detail); err != nil { + return nil, fmt.Errorf("event log: %w", err) + } + + payload := map[string]any{ + "thread_id": threadID, + "write_skipped": writeSkipped, + "memory_id": nil, + "deduped": false, + "existing_distance": nil, + "kind": nil, + "recalled": recalled, + "embed_ms": embedMs, + "recall_ms": recallMs, + "write_ms": writeMs, + } + if writeResult != nil { + payload["memory_id"] = writeResult.ID + payload["deduped"] = writeResult.Deduped + payload["existing_distance"] = writeResult.ExistingDistance + } + if !writeSkipped { + payload["kind"] = p.Kind + } + return payload, nil +} + +// ---- /state shape --------------------------------------------------- + +type stateIndex struct { + NumDocs int `json:"num_docs"` + IndexingFailures int `json:"indexing_failures"` + IndexName string `json:"index_name"` + Model string `json:"model"` + SessionTTLSeconds int `json:"session_ttl_seconds"` + DedupThreshold float64 `json:"dedup_threshold"` + DefaultRecallThreshold float64 `json:"default_recall_threshold"` + StackLabel string `json:"stack_label"` +} + +type stateResponse struct { + Index stateIndex `json:"index"` + ThreadID string `json:"thread_id"` + Session *SessionState `json:"session"` + Memories []MemoryRecord `json:"memories"` + Events []AgentEvent `json:"events"` + // `recalled` is populated by /turn; on plain /state reads the UI + // keeps showing the last turn's result, which is the useful + // behaviour for an "agent" panel. + Recalled []MemoryRecord `json:"recalled"` +} + +func (d *AgentMemoryDemo) buildState(ctx context.Context, user, namespace string) (stateResponse, error) { + info := d.Memory.IndexInfo(ctx) + threadID := d.CurrentThreadID() + session, err := d.Session.Load(ctx, threadID) + if err != nil { + return stateResponse{}, err + } + memories, err := d.Memory.ListMemories(ctx, user, namespace, "", 200) + if err != nil { + return stateResponse{}, err + } + events, err := d.Events.Recent(ctx, threadID, 20) + if err != nil { + return stateResponse{}, err + } + return stateResponse{ + Index: stateIndex{ + NumDocs: info.NumDocs, + IndexingFailures: info.IndexingFailures, + IndexName: d.Memory.IndexName, + Model: d.Embedder.ModelName, + SessionTTLSeconds: d.Session.DefaultTTLSeconds, + DedupThreshold: d.Memory.DedupThreshold, + DefaultRecallThreshold: d.Memory.RecallThreshold, + StackLabel: stackLabel, + }, + ThreadID: threadID, + Session: session, + Memories: memories, + Events: events, + Recalled: []MemoryRecord{}, + }, nil +} + +// ---- HTTP plumbing -------------------------------------------------- + +func sendJSON(w http.ResponseWriter, payload any, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(payload); err != nil { + log.Printf("[demo] encode: %v", err) + } +} + +func sendHTML(w http.ResponseWriter, html string, status int) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(status) + _, _ = io.WriteString(w, html) +} + +// clampThreshold sanitises the threshold parameter from the form +// body. `strconv.ParseFloat` happily handles "nan" / "inf"; either +// would silently turn recall into "every memory" or "nothing". +// Clamp to the meaningful cosine-distance range so a malformed POST +// can't override the threshold semantics. +func clampThreshold(raw string, fallback float64) float64 { + parsed, err := strconv.ParseFloat(strings.TrimSpace(raw), 64) + if err != nil || math.IsNaN(parsed) || math.IsInf(parsed, 0) { + return fallback + } + if parsed < 0 { + return 0 + } + if parsed > 2 { + return 2 + } + return parsed +} + +// jsonRecover turns any panic in a handler into a JSON 500 instead +// of the default plain-text stack trace. Without this the client's +// `await res.json()` would explode with an opaque parse error. +func jsonRecover(w http.ResponseWriter) { + if rec := recover(); rec != nil { + log.Printf("[demo] panic: %v\n%s", rec, debug.Stack()) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": fmt.Sprintf("%v", rec), + "type": "panic", + }) + } +} + +func makeHandler(demo *AgentMemoryDemo, htmlPage string) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.URL.Path != "/" && r.URL.Path != "/index.html" { + sendJSON(w, map[string]any{"error": "not found"}, http.StatusNotFound) + return + } + if r.Method != http.MethodGet { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + sendHTML(w, htmlPage, http.StatusOK) + }) + + mux.HandleFunc("/state", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.Method != http.MethodGet { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + q := r.URL.Query() + user := orDefault(q.Get("user"), demo.DefaultUser) + namespace := orDefault(q.Get("namespace"), demo.DefaultNamespace) + s, err := demo.buildState(r.Context(), user, namespace) + if err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusInternalServerError) + return + } + sendJSON(w, s, http.StatusOK) + }) + + mux.HandleFunc("/turn", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.Method != http.MethodPost { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + if err := r.ParseForm(); err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusBadRequest) + return + } + text := strings.TrimSpace(r.PostForm.Get("text")) + if text == "" { + sendJSON(w, map[string]any{"error": "text is required"}, http.StatusBadRequest) + return + } + payload, err := demo.HandleTurn(r.Context(), TurnParams{ + Text: text, + User: orDefault(r.PostForm.Get("user"), "default"), + Namespace: orDefault(r.PostForm.Get("namespace"), "default"), + Kind: orDefault(r.PostForm.Get("kind"), "episodic"), + Role: orDefault(r.PostForm.Get("role"), "user"), + Threshold: clampThreshold(r.PostForm.Get("threshold"), demo.Memory.RecallThreshold), + Action: orDefault(r.PostForm.Get("action"), "turn"), + }) + if err != nil { + log.Printf("[demo] turn: %v", err) + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusInternalServerError) + return + } + sendJSON(w, payload, http.StatusOK) + }) + + mux.HandleFunc("/new_thread", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.Method != http.MethodPost { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + if err := r.ParseForm(); err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusBadRequest) + return + } + user := orDefault(r.PostForm.Get("user"), "default") + namespace := orDefault(r.PostForm.Get("namespace"), "default") + tid, err := demo.NewThread(r.Context(), user, namespace) + if err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusInternalServerError) + return + } + sendJSON(w, map[string]any{"thread_id": tid}, http.StatusOK) + }) + + mux.HandleFunc("/reset", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.Method != http.MethodPost { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + if err := r.ParseForm(); err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusBadRequest) + return + } + user := orDefault(r.PostForm.Get("user"), "default") + namespace := orDefault(r.PostForm.Get("namespace"), "default") + seeded, err := demo.SeedAll(r.Context(), user, namespace) + if err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusInternalServerError) + return + } + sendJSON(w, map[string]any{"seeded": seeded}, http.StatusOK) + }) + + mux.HandleFunc("/drop_memory", func(w http.ResponseWriter, r *http.Request) { + defer jsonRecover(w) + if r.Method != http.MethodPost { + sendJSON(w, map[string]any{"error": "method not allowed"}, http.StatusMethodNotAllowed) + return + } + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + if err := r.ParseForm(); err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusBadRequest) + return + } + memoryID := strings.TrimSpace(r.PostForm.Get("memory_id")) + if memoryID == "" { + sendJSON(w, map[string]any{"error": "memory_id is required"}, http.StatusBadRequest) + return + } + deleted, err := demo.Memory.DeleteMemory(r.Context(), memoryID) + if err != nil { + sendJSON(w, map[string]any{"error": err.Error()}, http.StatusInternalServerError) + return + } + sendJSON(w, map[string]any{"deleted": deleted, "memory_id": memoryID}, http.StatusOK) + }) + + return mux +} + +func msSince(t time.Time) float64 { + return float64(time.Since(t)) / float64(time.Millisecond) +} + +// ---- Main ----------------------------------------------------------- + +type flags struct { + host string + port int + redisHost string + redisPort int + memIndexName string + memKeyPrefix string + sessionKeyPrefix string + eventKeyPrefix string + sessionTTLSeconds int + dedupThreshold float64 + recallThreshold float64 + noReset bool +} + +func parseFlags() flags { + var f flags + flag.StringVar(&f.host, "host", "127.0.0.1", "interface to bind to") + flag.IntVar(&f.port, "port", 8090, "HTTP port for the demo UI") + flag.StringVar(&f.redisHost, "redis-host", "localhost", "Redis host") + flag.IntVar(&f.redisPort, "redis-port", 6379, "Redis port") + flag.StringVar(&f.memIndexName, "mem-index-name", "agentmem:idx", "memory search index name") + flag.StringVar(&f.memKeyPrefix, "mem-key-prefix", "agent:mem:", "JSON memory key prefix") + flag.StringVar(&f.sessionKeyPrefix, "session-key-prefix", "agent:session:", "session hash key prefix") + flag.StringVar(&f.eventKeyPrefix, "event-key-prefix", "agent:events:", "event stream key prefix") + flag.IntVar(&f.sessionTTLSeconds, "session-ttl-seconds", 3600, "TTL applied to every session hash write") + flag.Float64Var(&f.dedupThreshold, "dedup-threshold", DefaultDedupThreshold, "cosine-distance cutoff for write-time dedup") + flag.Float64Var(&f.recallThreshold, "recall-threshold", DefaultRecallThreshold, "default cosine-distance cutoff for recall") + flag.BoolVar(&f.noReset, "no-reset", false, "skip clearing and re-seeding on startup") + flag.Parse() + return f +} + +func main() { + f := parseFlags() + + ctx := context.Background() + client := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", f.redisHost, f.redisPort), + Protocol: 2, + }) + if err := client.Ping(ctx).Err(); err != nil { + fmt.Fprintf(os.Stderr, "Error: cannot reach Redis at %s:%d\n (%v)\n", + f.redisHost, f.redisPort, err) + os.Exit(1) + } + + session := NewAgentSession(client, f.sessionKeyPrefix, f.sessionTTLSeconds, DefaultMaxTurns) + memory := NewLongTermMemory( + client, + f.memIndexName, + f.memKeyPrefix, + VectorDim, + f.dedupThreshold, + f.recallThreshold, + nil, + ) + if err := memory.CreateIndex(ctx); err != nil { + log.Fatalf("creating index: %v", err) + } + events := NewAgentEventLog(client, f.eventKeyPrefix, DefaultMaxLen) + + fmt.Println("Loading embedding model (first run downloads the ONNX weights)...") + embedder, err := NewLocalEmbedder(ctx, "", "") + if err != nil { + log.Fatalf("loading embedder: %v", err) + } + defer embedder.Close() + + demo := NewAgentMemoryDemo(session, memory, events, embedder) + + if !f.noReset { + fmt.Printf("Dropping any existing memories under '%s*' and re-seeding from the sample memory list (pass --no-reset to keep).\n", + f.memKeyPrefix) + seeded, err := demo.SeedAll(ctx, "default", "default") + if err != nil { + log.Fatalf("seeding: %v", err) + } + fmt.Printf("Seeded %d memories.\n", seeded) + } + + // Load index.html once and substitute the template tokens so the + // docs panel shows the actual values in use. + htmlPath, err := locateIndexHTML() + if err != nil { + log.Fatalf("locating index.html: %v", err) + } + rawHTML, err := os.ReadFile(htmlPath) + if err != nil { + log.Fatalf("reading %s: %v", htmlPath, err) + } + htmlPage := string(rawHTML) + htmlPage = strings.ReplaceAll(htmlPage, "__SESSION_PREFIX__", f.sessionKeyPrefix) + htmlPage = strings.ReplaceAll(htmlPage, "__MEM_PREFIX__", f.memKeyPrefix) + htmlPage = strings.ReplaceAll(htmlPage, "__MEM_INDEX__", f.memIndexName) + htmlPage = strings.ReplaceAll(htmlPage, "__EVENT_PREFIX__", f.eventKeyPrefix) + + srv := &http.Server{ + Addr: fmt.Sprintf("%s:%d", f.host, f.port), + Handler: makeHandler(demo, htmlPage), + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + } + + fmt.Printf("Redis agent memory demo listening on http://%s:%d\n", f.host, f.port) + fmt.Printf("Using Redis at %s:%d with memory index '%s'\n", + f.redisHost, f.redisPort, f.memIndexName) + + // Run the server in a goroutine so we can shut down cleanly on + // SIGINT/SIGTERM. + errCh := make(chan error, 1) + go func() { + errCh <- srv.ListenAndServe() + }() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + select { + case sig := <-sigCh: + fmt.Printf("\nReceived %s, shutting down...\n", sig) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + case err := <-errCh: + if err != nil && err != http.ErrServerClosed { + log.Fatalf("server: %v", err) + } + } +} + +// locateIndexHTML looks for `index.html` first next to the executable +// (so `go install` + run from anywhere keeps working) and then in the +// process's working directory (the `go run .` case). +func locateIndexHTML() (string, error) { + exe, err := os.Executable() + if err == nil { + candidate := filepath.Join(filepath.Dir(exe), "index.html") + if _, statErr := os.Stat(candidate); statErr == nil { + return candidate, nil + } + } + cwd, err := os.Getwd() + if err == nil { + candidate := filepath.Join(cwd, "index.html") + if _, statErr := os.Stat(candidate); statErr == nil { + return candidate, nil + } + } + return "", fmt.Errorf("index.html not found next to binary or in cwd") +} diff --git a/content/develop/use-cases/agent-memory/go/seed_memory.go b/content/develop/use-cases/agent-memory/go/seed_memory.go new file mode 100644 index 0000000000..d4f200e204 --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/seed_memory.go @@ -0,0 +1,110 @@ +// Pre-seed the long-term memory store with sample memories. +// +// In a real deployment the memory store fills up organically as the +// agent reasons over user turns: each turn produces zero or more +// memories (preferences, facts, episodic summaries) that flow into +// the store with deduplication. To make the demo immediately useful +// — so the first recall query lands on relevant results instead of +// an empty list — we seed a small set of canonical memories for a +// default user at startup. +// +// The seed list mixes `semantic` memories (long-lived preferences +// and facts) with `episodic` memories (snapshots of past sessions), +// matching what the Python, Node, .NET, and Rust demos seed so all +// five implementations behave identically. + +package main + +import ( + "context" + "fmt" +) + +// SeedEntry is one row of the seed list. +type SeedEntry struct { + Text string + Kind string +} + +// SeedMemories is the canonical mixed list. Order matters only for +// the demo's "first eight memories" display; the dedup KNN means a +// re-seed against an existing store will report zero new writes. +var SeedMemories = []SeedEntry{ + { + Text: "The user prefers concise answers without filler phrases.", + Kind: "semantic", + }, + { + Text: "The user is a Python developer working on a logistics platform.", + Kind: "semantic", + }, + { + Text: "The user lives in Berlin and works in the Europe/Berlin time zone.", + Kind: "semantic", + }, + { + Text: "The user dislikes dark mode and prefers a high-contrast light " + + "theme in editors and dashboards.", + Kind: "semantic", + }, + { + Text: "The user is allergic to peanuts; any restaurant suggestion must " + + "avoid dishes that commonly contain them.", + Kind: "semantic", + }, + { + Text: "Last Tuesday the user asked the agent to draft a postmortem for " + + "the order-routing outage. The agent produced a five-section " + + "draft and the user approved sections 1, 2, and 4 with minor " + + "edits.", + Kind: "episodic", + }, + { + Text: "In a previous session the user asked for help debugging a flaky " + + "test in the inventory service. The fix turned out to be a race " + + "condition in the warehouse webhook handler.", + Kind: "episodic", + }, + { + Text: "Two weeks ago the user mentioned they were planning to migrate " + + "the analytics warehouse from Snowflake to BigQuery in Q3.", + Kind: "episodic", + }, +} + +// Seed embeds and writes the seed memories. Returns the count +// actually written (entries that dedup against existing memories +// don't count). +func Seed( + ctx context.Context, + memory *LongTermMemory, + embedder *LocalEmbedder, + user, namespace, sourceThread string, +) (int, error) { + texts := make([]string, len(SeedMemories)) + for i, s := range SeedMemories { + texts[i] = s.Text + } + vectors, err := embedder.EncodeMany(ctx, texts) + if err != nil { + return 0, fmt.Errorf("encoding seed batch: %w", err) + } + written := 0 + for i, entry := range SeedMemories { + res, err := memory.Remember(ctx, RememberParams{ + Text: entry.Text, + Embedding: vectors[i], + User: user, + Namespace: namespace, + Kind: entry.Kind, + SourceThread: sourceThread, + }) + if err != nil { + return written, fmt.Errorf("remembering seed %d: %w", i, err) + } + if !res.Deduped { + written++ + } + } + return written, nil +} diff --git a/content/develop/use-cases/agent-memory/go/session_store.go b/content/develop/use-cases/agent-memory/go/session_store.go new file mode 100644 index 0000000000..886773d97d --- /dev/null +++ b/content/develop/use-cases/agent-memory/go/session_store.go @@ -0,0 +1,351 @@ +// Working-memory store for an agent session, backed by a Redis Hash. +// +// Each session is one Hash document at `agent:session:{ThreadID}`. +// The hash holds the running scratchpad, the current goal, a rolling +// window of recent turns (serialised as a JSON list to fit in one +// field), and a few audit fields. One `HGETALL` returns the whole +// session in a single round trip on every step of the agent loop. +// +// Every write refreshes the key's TTL with `EXPIRE`, so idle sessions +// fall off without a separate cleanup job and active sessions stay +// alive as long as the agent keeps touching them. A separate +// `LongTermMemory` is what survives beyond a session's TTL. +// +// The turn window is bounded to `MaxTurns` in application code; the +// hash itself doesn't grow, so the working set per thread stays +// constant regardless of how long the agent has been running. + +package main + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/redis/go-redis/v9" +) + +// DefaultMaxTurns is how many recent turns to keep inline on the +// session hash. Older turns flow through the event log +// (AgentEventLog) and the long-term memory store (LongTermMemory). +const DefaultMaxTurns = 20 + +// SessionTurn is a single role/content/timestamp triple inside the +// rolling session window. +type SessionTurn struct { + Role string `json:"role"` + Content string `json:"content"` + TS float64 `json:"ts"` +} + +// SessionState is the full per-thread working-memory state. +type SessionState struct { + ThreadID string `json:"thread_id"` + User string `json:"user"` + Agent string `json:"agent"` + Goal string `json:"goal"` + Scratchpad string `json:"scratchpad"` + TurnCount int `json:"turn_count"` + CreatedTS float64 `json:"created_ts"` + LastActiveTS float64 `json:"last_active_ts"` + RecentTurns []SessionTurn `json:"recent_turns"` + TTLSeconds int `json:"ttl_seconds"` +} + +// AgentSession owns the working-memory Hash and the rolling turn +// window for a single agent thread. +type AgentSession struct { + Client *redis.Client + KeyPrefix string + DefaultTTLSeconds int + MaxTurns int +} + +// NewAgentSession returns a session helper with the supplied client. +// Pass zero values for any field to use the defaults +// (agent:session: / 3600 / 20). +func NewAgentSession( + client *redis.Client, + keyPrefix string, + defaultTTLSeconds int, + maxTurns int, +) *AgentSession { + if keyPrefix == "" { + keyPrefix = "agent:session:" + } + if defaultTTLSeconds <= 0 { + defaultTTLSeconds = 3600 + } + if maxTurns <= 0 { + maxTurns = DefaultMaxTurns + } + return &AgentSession{ + Client: client, + KeyPrefix: keyPrefix, + DefaultTTLSeconds: defaultTTLSeconds, + MaxTurns: maxTurns, + } +} + +// SessionKey returns the Redis key for a thread id. +func (s *AgentSession) SessionKey(threadID string) string { + return s.KeyPrefix + threadID +} + +// NewThreadID returns a random 12-hex-character id, matching the +// shape the Python, Node, .NET, and Rust helpers produce. +func (s *AgentSession) NewThreadID() string { + id, err := newThreadID() + if err != nil { + // crypto/rand failing is a kernel-level problem we can't + // gracefully recover from inside a demo; surface a panic so + // the operator sees what's wrong rather than handing back an + // empty string that would silently collide with other + // threads. + panic(fmt.Errorf("generating thread id: %w", err)) + } + return id +} + +// StartParams collects the optional fields for `Start`. Using a +// struct keeps the call site readable when only a couple of fields +// are set. +type StartParams struct { + User string + Agent string + Goal string + TTLSeconds int // 0 => use DefaultTTLSeconds +} + +// Start creates a fresh working memory for a thread. Overwrites any +// existing session at the same key. The agent normally calls this +// once per thread at the first turn and relies on `Load` / +// `AppendTurn` for subsequent steps. +func (s *AgentSession) Start( + ctx context.Context, + threadID string, + p StartParams, +) (*SessionState, error) { + user := p.User + if user == "" { + user = "default" + } + agent := p.Agent + if agent == "" { + agent = "default" + } + ttl := p.TTLSeconds + if ttl <= 0 { + ttl = s.DefaultTTLSeconds + } + now := unixSecs() + state := &SessionState{ + ThreadID: threadID, + User: user, + Agent: agent, + Goal: p.Goal, + Scratchpad: "", + TurnCount: 0, + CreatedTS: now, + LastActiveTS: now, + RecentTurns: []SessionTurn{}, + TTLSeconds: ttl, + } + if err := s.write(ctx, state, ttl); err != nil { + return nil, err + } + return state, nil +} + +// Load returns the session state, or `nil` if it has expired. +func (s *AgentSession) Load(ctx context.Context, threadID string) (*SessionState, error) { + key := s.SessionKey(threadID) + raw, err := s.Client.HGetAll(ctx, key).Result() + if err != nil { + return nil, fmt.Errorf("HGETALL: %w", err) + } + if len(raw) == 0 { + return nil, nil + } + ttl, _ := s.Client.TTL(ctx, key).Result() + ttlSeconds := int(ttl / time.Second) + if ttlSeconds < 0 { + ttlSeconds = 0 + } + var turns []SessionTurn + if blob, ok := raw["recent_turns"]; ok && blob != "" { + _ = json.Unmarshal([]byte(blob), &turns) + } + turnCount, _ := strconv.Atoi(raw["turn_count"]) + createdTS, _ := strconv.ParseFloat(raw["created_ts"], 64) + lastActiveTS, _ := strconv.ParseFloat(raw["last_active_ts"], 64) + return &SessionState{ + ThreadID: threadID, + User: orDefault(raw["user"], "default"), + Agent: orDefault(raw["agent"], "default"), + Goal: raw["goal"], + Scratchpad: raw["scratchpad"], + TurnCount: turnCount, + CreatedTS: createdTS, + LastActiveTS: lastActiveTS, + RecentTurns: turns, + TTLSeconds: ttlSeconds, + }, nil +} + +// AppendTurnParams collects the optional fields for `AppendTurn`. +type AppendTurnParams struct { + Role string + // Content is the turn text. + Content string + // User and Agent are only consulted when the session does not yet + // exist — they seed the auto-created session so the + // working-memory hash matches the user the caller is operating + // against. On an existing session they're ignored; the original + // `Start` values stand. + User string + Agent string + TTLSeconds int // 0 => use DefaultTTLSeconds +} + +// AppendTurn appends a turn, bounds the rolling window, and refreshes +// the TTL. +// +// Read-modify-write here is last-writer-wins on the turn list if two +// concurrent turns reach the same thread; the demo never triggers +// that race in practice (one browser, one turn at a time) but a +// multi-worker agent that shares a thread id would wrap this in +// `WATCH` / `MULTI` / `EXEC` or a Lua script that does the append +// atomically server-side. +func (s *AgentSession) AppendTurn( + ctx context.Context, + threadID string, + p AppendTurnParams, +) (*SessionState, error) { + state, err := s.Load(ctx, threadID) + if err != nil { + return nil, err + } + if state == nil { + state, err = s.Start(ctx, threadID, StartParams{ + User: p.User, + Agent: p.Agent, + TTLSeconds: p.TTLSeconds, + }) + if err != nil { + return nil, err + } + } + state.RecentTurns = append(state.RecentTurns, SessionTurn{ + Role: p.Role, + Content: p.Content, + TS: unixSecs(), + }) + if len(state.RecentTurns) > s.MaxTurns { + state.RecentTurns = state.RecentTurns[len(state.RecentTurns)-s.MaxTurns:] + } + state.TurnCount++ + state.LastActiveTS = unixSecs() + ttl := p.TTLSeconds + if ttl <= 0 { + ttl = s.DefaultTTLSeconds + } + state.TTLSeconds = ttl + if err := s.write(ctx, state, ttl); err != nil { + return nil, err + } + return state, nil +} + +// SetGoal updates the goal field without touching turns or the +// scratchpad. Creates the session if it doesn't exist yet — setting +// a goal on a fresh thread is a sensible first step in the agent +// loop, so this method covers both the "rename the goal mid-session" +// and the "start a thread with this goal" cases. +func (s *AgentSession) SetGoal( + ctx context.Context, + threadID, text string, + p StartParams, +) (*SessionState, error) { + state, err := s.Load(ctx, threadID) + if err != nil { + return nil, err + } + if state == nil { + p.Goal = text + return s.Start(ctx, threadID, p) + } + state.Goal = text + state.LastActiveTS = unixSecs() + ttl := p.TTLSeconds + if ttl <= 0 { + ttl = s.DefaultTTLSeconds + } + state.TTLSeconds = ttl + if err := s.write(ctx, state, ttl); err != nil { + return nil, err + } + return state, nil +} + +// Delete drops the session immediately. Returns true if it existed. +func (s *AgentSession) Delete(ctx context.Context, threadID string) (bool, error) { + n, err := s.Client.Del(ctx, s.SessionKey(threadID)).Result() + if err != nil { + return false, fmt.Errorf("DEL: %w", err) + } + return n > 0, nil +} + +func (s *AgentSession) write(ctx context.Context, state *SessionState, ttl int) error { + key := s.SessionKey(state.ThreadID) + turnsBlob, err := json.Marshal(state.RecentTurns) + if err != nil { + return fmt.Errorf("marshalling recent_turns: %w", err) + } + mapping := map[string]any{ + "thread_id": state.ThreadID, + "user": state.User, + "agent": state.Agent, + "goal": state.Goal, + "scratchpad": state.Scratchpad, + "turn_count": strconv.Itoa(state.TurnCount), + "created_ts": strconv.FormatFloat(state.CreatedTS, 'f', -1, 64), + "last_active_ts": strconv.FormatFloat(state.LastActiveTS, 'f', -1, 64), + "recent_turns": string(turnsBlob), + } + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would otherwise + // leave the session without a TTL. + if _, err := s.Client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.HSet(ctx, key, mapping) + pipe.Expire(ctx, key, time.Duration(ttl)*time.Second) + return nil + }); err != nil { + return fmt.Errorf("session write MULTI/EXEC: %w", err) + } + return nil +} + +func newThreadID() (string, error) { + var b [6]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return hex.EncodeToString(b[:]), nil +} + +func unixSecs() float64 { + return float64(time.Now().UnixNano()) / 1e9 +} + +func orDefault(s, fallback string) string { + if s == "" { + return fallback + } + return s +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/.gitignore b/content/develop/use-cases/agent-memory/java-jedis/.gitignore new file mode 100644 index 0000000000..434dfcf2d3 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/.gitignore @@ -0,0 +1,6 @@ +target/ +.idea/ +*.iml +.classpath +.project +.settings/ diff --git a/content/develop/use-cases/agent-memory/java-jedis/_index.md b/content/develop/use-cases/agent-memory/java-jedis/_index.md new file mode 100644 index 0000000000..50187de1c1 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/_index.md @@ -0,0 +1,344 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Java with Jedis, DJL (PyTorch), and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: Jedis example (Java) +title: Redis agent memory with Jedis +weight: 6 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Java with [Jedis]({{< relref "/develop/clients/jedis" >}}) and [DJL](https://djl.ai/) (the Deep Java Library), using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with the JDK's [`com.sun.net.httpserver`](https://docs.oracle.com/en/java/javase/17/docs/api/jdk.httpserver/com/sun/net/httpserver/package-summary.html) so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [DJL](https://djl.ai/) (`ai.djl.huggingface.tokenizers` + `ai.djl.pytorch.pytorch-model-zoo`) running the canonical `sentence-transformers/all-MiniLM-L6-v2` PyTorch checkpoint — the same library and model the existing [Jedis vector-search example]({{< relref "/develop/clients/jedis/vecsearch" >}}) uses, and the same encoder the [Python]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) example loads. DJL drives libtorch through the same C++ runtime as Python's PyTorch bindings, so the vectors produced here are numerically identical to the Python ones to within rounding noise, and the distance bands the Python walkthrough quotes carry over to this demo without recalibration. A memory written by one demo can be recalled by the other against the same Redis instance. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* One Redis Search call per recall: [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) does the KNN + TAG pre-filter in a single round trip (a per-row [`TTL`]({{< relref "/commands/ttl" >}}) follow-up is the only other read the helper issues, just to populate the `ttl_seconds` field for the admin panel). Working memory is one [`HGETALL`]({{< relref "/commands/hgetall" >}}); the event log is one [`XADD`]({{< relref "/commands/xadd" >}}). +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encodeOne(text)` to turn the incoming turn into a 384-element `float[]`. +2. `session.appendTurn(threadId, role, content, user, agent, null)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside Jedis's `multi()` transaction. The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall(vec, user, namespace, null, 5, threshold)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember(text, vec, user, namespace, kind, threadId, null)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}) — best-effort: if the memory's TTL has elapsed between the recall and the bump, the increment quietly fails and the hit count for that recall is lost. Otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) inside the same `multi()` transaction. +5. `events.record(threadId, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) (`XAddParams.approximateTrimming()`), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentSession.java)): + +```java +import com.redis.agentmem.AgentSession; +import com.redis.agentmem.SessionState; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.JedisPooled; + +JedisPooled jedis = new JedisPooled(new HostAndPort("localhost", 6379)); +AgentSession session = new AgentSession(jedis, "agent:session:", 3600, 20); + +String threadId = session.newThreadId(); +SessionState state = session.start(threadId, "alice", "demo-agent", + "Plan next week's meetings.", null); +state = session.appendTurn( + threadId, "user", "Schedule a budget review with finance.", + "alice", "demo-agent", null); +System.out.println(state.turnCount() + " " + state.recentTurns().size() + + " " + state.ttlSeconds()); +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `appendTurn`, `setGoal` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside `jedis.multi()` so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LongTermMemory.java)): + +```java +import com.redis.agentmem.LocalEmbedder; +import com.redis.agentmem.LongTermMemory; +import com.redis.agentmem.MemoryRecord; +import com.redis.agentmem.WriteResult; + +LongTermMemory memory = new LongTermMemory( + jedis, + "agentmem:idx", + "agent:mem:", + 384, + 0.20, // dedup threshold — tight at write time + 0.55, // recall threshold — looser at read time + null); // default per-kind TTL map +LocalEmbedder embedder = LocalEmbedder.create(); +memory.createIndex(); // idempotent + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +float[] vec = embedder.encodeOne("The user prefers light mode in editors."); +WriteResult result = memory.remember( + "The user prefers light mode in editors.", + vec, + "alice", + "default", + "semantic", + "9f3d2a4b8c61", + null); +System.out.printf("deduped=%s id=%s dist=%s%n", + result.deduped(), result.id(), result.existingDistance()); + +// Recall against a later question. +float[] q = embedder.encodeOne("Which theme does this user like?"); +for (MemoryRecord h : memory.recall(q, "alice", "default", null, 5, null)) { + System.out.printf("%.3f [%s] %s%n", h.distance(), h.kind(), h.text()); +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is stored as a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (`LocalEmbedder.toBytes()` packs them in little-endian order), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with alias names on each path so the query syntax stays compact. Jedis spells the alias with `.as("...")`: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write by passing a non-null `ttlSeconds` to `remember`, or hand a different `Map` to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEventLog.java)): + +```java +import com.redis.agentmem.AgentEvent; +import com.redis.agentmem.AgentEventLog; + +AgentEventLog events = new AgentEventLog(jedis, "agent:events:", 1000); +events.record(threadId, "turn_appended:user", + "Schedule a budget review with finance."); +events.record(threadId, "memory_written", + "wrote 7c3f8a1b9e02 as semantic"); + +for (AgentEvent e : events.recent(threadId, 20)) { + System.out.println(e.action() + " " + e.detail()); +} +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000` via `XAddParams.xAddParams().maxLen(1000).approximateTrimming()`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.appendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recentTurns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recentTurns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `currentThreadId` synchronized through an explicit mutex; `seedAll`, `newThread`, and `handleTurn` each release the lock between operations, so a turn racing with a thread rotation can capture the old id and apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +A separate concern specific to Java + DJL: the `LocalEmbedder.encodeOne` / `encodeMany` methods are `synchronized` because the underlying `Predictor` is not thread-safe. The demo's `Executors.newCachedThreadPool` could otherwise call into one predictor from several handler threads at once and corrupt the inference state. A higher-throughput deployment would replace that lock with a small pool of `Predictor` instances or a dedicated single-threaded inference executor. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `SeedMemory` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SeedMemory.java)): + +```java +import com.redis.agentmem.SeedMemory; + +LongTermMemory memory = new LongTermMemory(jedis, "agentmem:idx", + "agent:mem:", 384, 0.20, 0.55, null); +LocalEmbedder embedder = LocalEmbedder.create(); +memory.createIndex(); +int written = SeedMemory.seed(memory, embedder, "default", "default", "seed"); +System.out.println("seeded " + written + " memories"); +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`DemoServer` runs the JDK's [`HttpServer`](https://docs.oracle.com/en/java/javase/17/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html) on port 8092, with a cached thread pool dispatching requests to handlers. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a mutex-protected `String` field that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/java-jedis + ``` + +2. Build the fat jar. You'll need a [JDK 17](https://adoptium.net/) or later and + [Maven](https://maven.apache.org/): + + ```bash + mvn -q package + ``` + + The first build pulls Jedis, DJL, and the PyTorch native libraries — that takes + a couple of minutes the first time and is cached afterwards. + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo. The first run downloads the `sentence-transformers/all-MiniLM-L6-v2` + PyTorch weights into the local DJL cache (~90 MB): + + ```bash + java -jar target/agent-memory-jedis.jar + ``` + + Or via Maven: `mvn -q exec:java`. + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + DJL drives libtorch through the same C++ kernel as Python's PyTorch + bindings, so distances here match the Python demo to four decimal + places. `sentence-transformers/all-MiniLM-L6-v2` puts a faithful + paraphrase in the 0.15 – 0.50 cosine-distance range, a loose paraphrase + or related topic in the 0.50 – 0.80 range, and unrelated queries above + 0.8 — which is what motivates the 0.55 default recall threshold and the + 0.20 default dedup threshold. A stricter embedding model (or a + domain-tuned one) would let you tighten both; a noisier one would push + them up. The right thresholds are always a function of the model, the + corpus, and how conservative the agent needs to be about accepting a + memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags (pass them after the jar): + +* `--host` / `--port` — change the HTTP bind address (default `127.0.0.1:8092`). +* `--redis-host` / `--redis-port` — point at a non-local Redis (default `localhost:6379`). +* `--mem-index-name` / `--mem-key-prefix` / `--session-key-prefix` / `--event-key-prefix` — relocate the index name and the three key prefixes (to run several demos against one Redis without colliding, for example). +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/java-jedis/index.html b/content/develop/use-cases/agent-memory/java-jedis/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/java-jedis/pom.xml b/content/develop/use-cases/agent-memory/java-jedis/pom.xml new file mode 100644 index 0000000000..a25b43e75e --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/pom.xml @@ -0,0 +1,136 @@ + + + 4.0.0 + + com.redis + agent-memory-jedis + 1.0.0 + jar + + Redis Agent Memory Demo (Jedis) + + Interactive agent-memory demo backed by Redis Hashes, JSON, + Search, and Streams, using Jedis for Redis access and DJL + (PyTorch) for local sentence embeddings. + + + + 17 + 17 + UTF-8 + 5.2.0 + 0.33.0 + 20240303 + + + + + + redis.clients + jedis + ${jedis.version} + + + + + ai.djl + api + ${djl.version} + + + ai.djl.huggingface + tokenizers + ${djl.version} + + + ai.djl.pytorch + pytorch-model-zoo + ${djl.version} + + + + + org.json + json + ${json.version} + + + + + agent-memory-jedis + + + + ${project.basedir} + + index.html + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + 17 + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.3 + + + package + shade + + false + + + com.redis.agentmem.DemoServer + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.5.0 + + com.redis.agentmem.DemoServer + + + + + diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEvent.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEvent.java new file mode 100644 index 0000000000..e4b085def8 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEvent.java @@ -0,0 +1,18 @@ +package com.redis.agentmem; + +/** + * One entry from the per-thread event Stream. + * + *

{@code eventId} is the {@code XADD}-assigned stream id (e.g. + * {@code 1715990400123-0}); {@code ts} is the wall-clock time the + * action happened, stored as a Redis Stream field rather than + * inferred from the stream id because the demo timestamps the action + * on the agent side. + */ +public record AgentEvent( + String eventId, + String threadId, + String action, + String detail, + double ts) { +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEventLog.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEventLog.java new file mode 100644 index 0000000000..832c8a0758 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentEventLog.java @@ -0,0 +1,115 @@ +package com.redis.agentmem; + +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.StreamEntryID; +import redis.clients.jedis.params.XAddParams; +import redis.clients.jedis.resps.StreamEntry; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Append-only event log for an agent thread, backed by a Redis + * Stream. + * + *

Each thread gets a stream at {@code agent:events:{threadId}}. + * Every action the agent takes (a user turn arriving, a memory being + * recalled, a memory being written, a tool being called) is one + * {@code XADD} to that stream. Replay with {@code XREVRANGE} for the + * most recent N events; bound retention with {@code XTRIM MAXLEN ~} + * (Jedis's {@code XAddParams.approximateTrimming()}) so the log + * stays cheap regardless of how long the thread has been running. + * + *

The stream is independent of the session hash and the long-term + * memory store: it answers the "what just happened" question without + * competing with either of those for indexing or memory budget. + * Consumer groups (not used in this demo) would let downstream + * workers — summarisers, consolidators, audit pipelines — replay the + * log without losing position. + */ +public final class AgentEventLog { + + /** Approximate cap on stream length. */ + public static final long DEFAULT_MAX_LEN = 1000L; + + private final JedisPooled jedis; + private final String keyPrefix; + private final long maxLen; + + public AgentEventLog(JedisPooled jedis, String keyPrefix, long maxLen) { + this.jedis = jedis; + this.keyPrefix = keyPrefix; + this.maxLen = maxLen > 0 ? maxLen : DEFAULT_MAX_LEN; + } + + public String keyPrefix() { + return keyPrefix; + } + + public long maxLen() { + return maxLen; + } + + public String streamKey(String threadId) { + return keyPrefix + threadId; + } + + /** + * Append one event and return its stream id. + * + *

{@code MAXLEN ~ N} ({@code approximateTrimming()}) keeps the + * stream bounded with near-zero overhead; the exact form forces a + * scan and is rarely worth the cost. + */ + public String record(String threadId, String action, String detail) { + Map hash = new HashMap<>(); + hash.put("action", action == null ? "" : action); + hash.put("detail", detail == null ? "" : detail); + hash.put("ts", String.format(Locale.ROOT, "%.6f", unixSecs())); + + XAddParams params = XAddParams.xAddParams() + .maxLen(maxLen) + .approximateTrimming(); + StreamEntryID id = jedis.xadd(streamKey(threadId), params, hash); + return id == null ? "" : id.toString(); + } + + /** Return the most recent events, newest first. */ + public List recent(String threadId, int count) { + // xrevrange(key, end, start, count): newest-first iteration + // from "+" (highest id) towards "-" (lowest id). + List entries = jedis.xrevrange( + streamKey(threadId), + StreamEntryID.MAXIMUM_ID, + StreamEntryID.MINIMUM_ID, + count); + List out = new ArrayList<>(entries.size()); + for (StreamEntry entry : entries) { + Map fields = entry.getFields(); + out.add(new AgentEvent( + entry.getID().toString(), + threadId, + fields.getOrDefault("action", ""), + fields.getOrDefault("detail", ""), + parseDouble(fields.get("ts"), 0.0))); + } + return out; + } + + /** Drop the entire stream for a thread. */ + public boolean clear(String threadId) { + return jedis.del(streamKey(threadId)) > 0; + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } + + private static double parseDouble(String s, double fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Double.parseDouble(s); } catch (NumberFormatException ex) { return fallback; } + } +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentSession.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentSession.java new file mode 100644 index 0000000000..f05bfa0c69 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/AgentSession.java @@ -0,0 +1,281 @@ +package com.redis.agentmem; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import redis.clients.jedis.AbstractTransaction; +import redis.clients.jedis.JedisPooled; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; + +/** + * Working-memory store for an agent session, backed by a Redis Hash. + * + *

Each session is one Hash document at + * {@code agent:session:{threadId}}. The hash holds the running + * scratchpad, the current goal, a rolling window of recent turns + * (serialized as a JSON list to fit in one field), and a few audit + * fields. One {@code HGETALL} returns the whole session in a single + * round trip on every step of the agent loop. + * + *

Every write refreshes the key's TTL with {@code EXPIRE}, so + * idle sessions fall off without a separate cleanup job and active + * sessions stay alive as long as the agent keeps touching them. A + * separate {@link LongTermMemory} is what survives beyond a session's + * TTL. + * + *

The turn window is bounded to {@link #maxTurns()} in application + * code; the hash itself doesn't grow, so the working set per thread + * stays constant regardless of how long the agent has been running. + */ +public final class AgentSession { + + /** Default rolling-window size. */ + public static final int DEFAULT_MAX_TURNS = 20; + + private final JedisPooled jedis; + private final String keyPrefix; + private final long defaultTtlSeconds; + private final int maxTurns; + + public AgentSession( + JedisPooled jedis, + String keyPrefix, + long defaultTtlSeconds, + int maxTurns) { + this.jedis = jedis; + this.keyPrefix = keyPrefix; + this.defaultTtlSeconds = defaultTtlSeconds; + this.maxTurns = maxTurns > 0 ? maxTurns : DEFAULT_MAX_TURNS; + } + + public String keyPrefix() { + return keyPrefix; + } + + public long defaultTtlSeconds() { + return defaultTtlSeconds; + } + + public int maxTurns() { + return maxTurns; + } + + public String sessionKey(String threadId) { + return keyPrefix + threadId; + } + + /** Random 12-hex-character thread id, matching the other ports. */ + public String newThreadId() { + return UUID.randomUUID().toString().replace("-", "").substring(0, 12); + } + + /** + * Create a fresh working memory for a thread. Overwrites any + * existing session at the same key. The agent normally calls + * this once per thread at the first turn and relies on + * {@link #load} / {@link #appendTurn} for subsequent steps. + */ + public SessionState start( + String threadId, String user, String agent, String goal, Long ttlSeconds) { + if (user == null || user.isEmpty()) user = "default"; + if (agent == null || agent.isEmpty()) agent = "default"; + if (goal == null) goal = ""; + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + double now = unixSecs(); + SessionState state = new SessionState( + threadId, user, agent, goal, "", + 0L, now, now, List.of(), ttl); + write(state, ttl); + return state; + } + + /** Return the session state, or {@code null} if it has expired. */ + public SessionState load(String threadId) { + String key = sessionKey(threadId); + Map raw = jedis.hgetAll(key); + if (raw == null || raw.isEmpty()) { + return null; + } + long ttl = jedis.ttl(key); + if (ttl < 0) ttl = 0; + List turns = parseTurns(raw.get("recent_turns")); + return new SessionState( + threadId, + orDefault(raw.get("user"), "default"), + orDefault(raw.get("agent"), "default"), + orEmpty(raw.get("goal")), + orEmpty(raw.get("scratchpad")), + parseLong(raw.get("turn_count"), 0L), + parseDouble(raw.get("created_ts"), 0.0), + parseDouble(raw.get("last_active_ts"), 0.0), + turns, + ttl); + } + + /** + * Append a turn, bound the rolling window, refresh the TTL. + * + *

{@code user} and {@code agent} are only consulted when the + * session does not yet exist — they seed the auto-created session + * so the working-memory hash matches the user the caller is + * operating against. On an existing session they're ignored; the + * original {@code start} values stand. + * + *

Read-modify-write here is last-writer-wins on the turn list + * if two concurrent turns reach the same thread; the demo never + * triggers that race in practice (one browser, one turn at a + * time) but a multi-worker agent that shares a thread id would + * wrap this in {@code WATCH} / {@code MULTI} / {@code EXEC} or a + * Lua script that does the append atomically server-side. + */ + public SessionState appendTurn( + String threadId, + String role, + String content, + String user, + String agent, + Long ttlSeconds) { + SessionState state = load(threadId); + if (state == null) { + state = start(threadId, user, agent, "", ttlSeconds); + } + List turns = new ArrayList<>(state.recentTurns()); + turns.add(new SessionTurn(role, content == null ? "" : content, unixSecs())); + if (turns.size() > maxTurns) { + turns = turns.subList(turns.size() - maxTurns, turns.size()); + } + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + SessionState next = new SessionState( + state.threadId(), + state.user(), + state.agent(), + state.goal(), + state.scratchpad(), + state.turnCount() + 1, + state.createdTs(), + unixSecs(), + turns, + ttl); + write(next, ttl); + return next; + } + + /** + * Update the goal field without touching turns or the scratchpad. + * Creates the session if it doesn't exist yet — setting a goal + * on a fresh thread is a sensible first step in the agent loop, + * so this method covers both the "rename the goal mid-session" + * and the "start a thread with this goal" cases. + */ + public SessionState setGoal( + String threadId, + String text, + String user, + String agent, + Long ttlSeconds) { + SessionState state = load(threadId); + if (state == null) { + return start(threadId, user, agent, text == null ? "" : text, ttlSeconds); + } + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + SessionState next = new SessionState( + state.threadId(), + state.user(), + state.agent(), + text == null ? "" : text, + state.scratchpad(), + state.turnCount(), + state.createdTs(), + unixSecs(), + state.recentTurns(), + ttl); + write(next, ttl); + return next; + } + + /** Drop the session immediately. Returns {@code true} if it existed. */ + public boolean delete(String threadId) { + return jedis.del(sessionKey(threadId)) > 0; + } + + private void write(SessionState state, long ttl) { + String key = sessionKey(state.threadId()); + + JSONArray turnsArr = new JSONArray(); + for (SessionTurn t : state.recentTurns()) { + JSONObject obj = new JSONObject(); + obj.put("role", t.role()); + obj.put("content", t.content()); + obj.put("ts", t.ts()); + turnsArr.put(obj); + } + + Map mapping = new HashMap<>(); + mapping.put("thread_id", state.threadId()); + mapping.put("user", state.user()); + mapping.put("agent", state.agent()); + mapping.put("goal", state.goal()); + mapping.put("scratchpad", state.scratchpad()); + mapping.put("turn_count", Long.toString(state.turnCount())); + mapping.put("created_ts", + String.format(Locale.ROOT, "%.6f", state.createdTs())); + mapping.put("last_active_ts", + String.format(Locale.ROOT, "%.6f", state.lastActiveTs())); + mapping.put("recent_turns", turnsArr.toString()); + + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would + // otherwise leave the session without a TTL. + try (AbstractTransaction tx = jedis.multi()) { + tx.hset(key, mapping); + tx.expire(key, ttl); + tx.exec(); + } + } + + private static List parseTurns(String blob) { + if (blob == null || blob.isEmpty()) return List.of(); + try { + JSONArray arr = new JSONArray(blob); + List out = new ArrayList<>(arr.length()); + for (int i = 0; i < arr.length(); i++) { + JSONObject o = arr.getJSONObject(i); + out.add(new SessionTurn( + o.optString("role", ""), + o.optString("content", ""), + o.optDouble("ts", 0.0))); + } + return out; + } catch (JSONException ex) { + return List.of(); + } + } + + private static String orDefault(String s, String fallback) { + return (s == null || s.isEmpty()) ? fallback : s; + } + + private static String orEmpty(String s) { + return s == null ? "" : s; + } + + private static long parseLong(String s, long fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Long.parseLong(s); } catch (NumberFormatException ex) { return fallback; } + } + + private static double parseDouble(String s, double fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Double.parseDouble(s); } catch (NumberFormatException ex) { return fallback; } + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/DemoServer.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/DemoServer.java new file mode 100644 index 0000000000..086b7cfa0b --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/DemoServer.java @@ -0,0 +1,683 @@ +package com.redis.agentmem; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import org.json.JSONArray; +import org.json.JSONObject; +import redis.clients.jedis.ConnectionPoolConfig; +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.JedisPooled; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; + +/** + * Redis agent-memory demo server (Java + Jedis). + * + *

Run this main and visit {@code http://localhost:8092} to drive + * a small agent-memory demo backed by Redis Hashes, JSON, Search, + * and Streams. The UI lets you type a turn, watch working memory + * update, see semantically similar long-term memories recalled, and + * inspect the per-thread event log. + * + *

The server holds a single {@link LocalEmbedder}, one + * {@link AgentSession}, one {@link LongTermMemory}, and one + * {@link AgentEventLog} for the lifetime of the process. The first + * run downloads the embedding model into the local DJL cache; + * everything after is local. + */ +public final class DemoServer { + + static final class Args { + String host = "127.0.0.1"; + int port = 8092; + String redisHost = "localhost"; + int redisPort = 6379; + String memIndexName = "agentmem:idx"; + String memKeyPrefix = "agent:mem:"; + String sessionKeyPrefix = "agent:session:"; + String eventKeyPrefix = "agent:events:"; + long sessionTtlSeconds = 3600; + double dedupThreshold = LongTermMemory.DEFAULT_DEDUP_THRESHOLD; + double recallThreshold = LongTermMemory.DEFAULT_RECALL_THRESHOLD; + boolean resetOnStart = true; + } + + public static void main(String[] argv) throws Exception { + Args args = parseArgs(argv); + + ConnectionPoolConfig poolConfig = new ConnectionPoolConfig(); + poolConfig.setMaxTotal(16); + poolConfig.setMaxIdle(4); + poolConfig.setMinIdle(1); + JedisPooled jedis = new JedisPooled( + poolConfig, + new HostAndPort(args.redisHost, args.redisPort), + DefaultJedisClientConfig.builder() + .socketTimeoutMillis(2000) + .connectionTimeoutMillis(2000) + .build()); + try { + jedis.ping(); + } catch (Exception ex) { + System.err.println("Error: cannot reach Redis at " + + args.redisHost + ":" + args.redisPort); + System.err.println(" (" + ex.getMessage() + ")"); + jedis.close(); + System.exit(1); + } + + AgentSession session = new AgentSession( + jedis, + args.sessionKeyPrefix, + args.sessionTtlSeconds, + AgentSession.DEFAULT_MAX_TURNS); + LongTermMemory memory = new LongTermMemory( + jedis, + args.memIndexName, + args.memKeyPrefix, + LocalEmbedder.defaultVectorDim(), + args.dedupThreshold, + args.recallThreshold, + null); + memory.createIndex(); + AgentEventLog events = new AgentEventLog( + jedis, args.eventKeyPrefix, AgentEventLog.DEFAULT_MAX_LEN); + + System.out.println("Loading embedding model " + + "(first run downloads the PyTorch weights)..."); + LocalEmbedder embedder = LocalEmbedder.create(); + + AgentMemoryDemo demo = new AgentMemoryDemo(session, memory, events, embedder); + + if (args.resetOnStart) { + System.out.println( + "Dropping any existing memories under '" + args.memKeyPrefix + + "*' and re-seeding from the sample memory list " + + "(pass --no-reset to keep)."); + int seeded = demo.seedAll("default", "default"); + System.out.println("Seeded " + seeded + " memories."); + } + + // Load index.html once and substitute the template tokens so + // the docs panel shows the actual values in use. + String rawHtml = loadIndexHtml(); + String htmlPage = rawHtml + .replace("__SESSION_PREFIX__", args.sessionKeyPrefix) + .replace("__MEM_PREFIX__", args.memKeyPrefix) + .replace("__MEM_INDEX__", args.memIndexName) + .replace("__EVENT_PREFIX__", args.eventKeyPrefix); + + HttpServer server = HttpServer.create( + new InetSocketAddress(args.host, args.port), 0); + server.setExecutor(Executors.newCachedThreadPool()); + server.createContext("/", new RootHandler(demo, htmlPage)); + + System.out.println("Redis agent memory demo listening on " + + "http://" + args.host + ":" + args.port); + System.out.println("Using Redis at " + args.redisHost + ":" + args.redisPort + + " with memory index '" + args.memIndexName + "'"); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.out.println("\nShutting down..."); + server.stop(0); + try { embedder.close(); } catch (Exception ignored) {} + jedis.close(); + })); + + server.start(); + } + + // ------------------------------------------------------------------ + // Demo orchestrator + // ------------------------------------------------------------------ + + /** + * Demo state: working memory, long-term memory, event log. + * + *

{@code seedAll} / {@code newThread} / {@code handleTurn} + * all touch {@code currentThreadId} — synchronized through the + * mutex below, but the lock is released between operations so a + * turn racing with a thread rotation can capture the old id and + * apply to the previous thread. The demo is single-user in + * practice, so the race never triggers; a multi-user agent would + * carry the thread id on each request instead of holding it as + * shared server state. See the walkthrough's "Concurrency + * caveats" section. + */ + static final class AgentMemoryDemo { + private final AgentSession session; + private final LongTermMemory memory; + private final AgentEventLog events; + private final LocalEmbedder embedder; + private final String defaultUser = "default"; + private final String defaultNamespace = "default"; + private final Object threadIdLock = new Object(); + private String currentThreadId; + + AgentMemoryDemo(AgentSession session, LongTermMemory memory, + AgentEventLog events, LocalEmbedder embedder) { + this.session = session; + this.memory = memory; + this.events = events; + this.embedder = embedder; + this.currentThreadId = session.newThreadId(); + } + + String currentThreadId() { + synchronized (threadIdLock) { + return currentThreadId; + } + } + + /** Drop everything in scope and pre-populate with seed memories. */ + int seedAll(String user, String namespace) throws Exception { + memory.clear(); + String threadId = currentThreadId(); + session.delete(threadId); + events.clear(threadId); + int written = SeedMemory.seed(memory, embedder, user, namespace, "seed"); + synchronized (threadIdLock) { + currentThreadId = session.newThreadId(); + } + return written; + } + + /** Start a fresh thread. Long-term memory is unaffected. */ + String newThread(String user, String namespace) { + String oldId = currentThreadId(); + events.clear(oldId); + String newId = session.newThreadId(); + session.start(newId, user, "demo-agent", "", null); + events.record(newId, "thread_started", + "user=" + user + " namespace=" + namespace); + synchronized (threadIdLock) { + currentThreadId = newId; + } + return newId; + } + + /** + * One pass through the agent loop: append, recall, remember, log. + * + *

The order matters. We embed once and reuse the vector + * for both the recall and (if asked) the remember step — no + * point encoding the same text twice. Recall runs + * before the remember write so the agent doesn't + * see its own just-written turn as a recalled memory. + */ + Map handleTurn( + String text, + String user, + String namespace, + String kind, + String role, + double threshold, + String action) throws Exception { + String threadId = currentThreadId(); + + long t0 = System.nanoTime(); + float[] vec = embedder.encodeOne(text); + double embedMs = (System.nanoTime() - t0) / 1_000_000.0; + + // `setGoal` only touches the goal field so existing turns + // aren't wiped; `appendTurn` carries the request `user` + // through to the auto-create path so a first turn for a + // new thread doesn't land under the default user. + String sessionAction; + if ("goal".equals(action)) { + session.setGoal(threadId, text, user, "demo-agent", null); + sessionAction = "goal_set"; + } else { + session.appendTurn(threadId, role, text, user, "demo-agent", null); + sessionAction = "turn_appended:" + role; + } + + long t1 = System.nanoTime(); + List recalled = memory.recall( + vec, user, namespace, null, 5, threshold); + double recallMs = (System.nanoTime() - t1) / 1_000_000.0; + + boolean writeSkipped = "skip".equals(kind) || "goal".equals(action); + WriteResult writeResult = null; + double writeMs = 0.0; + if (!writeSkipped) { + long t2 = System.nanoTime(); + writeResult = memory.remember( + text, vec, user, namespace, kind, threadId, null); + writeMs = (System.nanoTime() - t2) / 1_000_000.0; + } + + String detail; + if (writeResult == null) { + detail = ""; + } else if (writeResult.deduped()) { + detail = "deduped onto " + writeResult.id(); + } else { + detail = "wrote " + writeResult.id() + " as " + kind; + } + events.record(threadId, sessionAction, detail); + + Map payload = new LinkedHashMap<>(); + payload.put("thread_id", threadId); + payload.put("write_skipped", writeSkipped); + payload.put("memory_id", writeResult == null ? null : writeResult.id()); + payload.put("deduped", writeResult != null && writeResult.deduped()); + payload.put("existing_distance", + writeResult == null ? null : writeResult.existingDistance()); + payload.put("kind", writeSkipped ? null : kind); + payload.put("recalled", memoriesToJsonArray(recalled)); + payload.put("embed_ms", embedMs); + payload.put("recall_ms", recallMs); + payload.put("write_ms", writeMs); + return payload; + } + + // ----- Build the /state response shape ------------------------ + + JSONObject buildState(String user, String namespace) { + Map info = memory.indexInfo(); + JSONObject index = new JSONObject(); + index.put("num_docs", info.getOrDefault("num_docs", 0L)); + index.put("indexing_failures", info.getOrDefault("indexing_failures", 0L)); + index.put("index_name", memory.indexName()); + index.put("model", embedder.modelName()); + index.put("session_ttl_seconds", session.defaultTtlSeconds()); + index.put("dedup_threshold", memory.dedupThreshold()); + index.put("default_recall_threshold", memory.recallThreshold()); + index.put("stack_label", + "Jedis + DJL (PyTorch + HuggingFace) + Java standard library HTTP server"); + + String threadId = currentThreadId(); + SessionState state = session.load(threadId); + JSONArray memories = memoriesToJsonArray( + memory.listMemories(user, namespace, null, 200)); + JSONArray eventArr = new JSONArray(); + for (AgentEvent e : events.recent(threadId, 20)) { + eventArr.put(eventToJson(e)); + } + + JSONObject out = new JSONObject(); + out.put("index", index); + out.put("thread_id", threadId); + out.put("session", state == null ? JSONObject.NULL : sessionToJson(state)); + out.put("memories", memories); + out.put("events", eventArr); + // `recalled` is populated by /turn; on plain /state reads + // the UI keeps showing the last turn's result. + out.put("recalled", new JSONArray()); + return out; + } + } + + // ------------------------------------------------------------------ + // Serialisation helpers (match the Python/Node demo payloads + // exactly so the same index.html JS works) + // ------------------------------------------------------------------ + + static JSONObject sessionToJson(SessionState s) { + JSONObject obj = new JSONObject(); + obj.put("thread_id", s.threadId()); + obj.put("user", s.user()); + obj.put("agent", s.agent()); + obj.put("goal", s.goal()); + obj.put("scratchpad", s.scratchpad()); + obj.put("turn_count", s.turnCount()); + obj.put("created_ts", s.createdTs()); + obj.put("last_active_ts", s.lastActiveTs()); + JSONArray turns = new JSONArray(); + for (SessionTurn t : s.recentTurns()) { + JSONObject turn = new JSONObject(); + turn.put("role", t.role()); + turn.put("content", t.content()); + turn.put("ts", t.ts()); + turns.put(turn); + } + obj.put("recent_turns", turns); + obj.put("ttl_seconds", s.ttlSeconds()); + return obj; + } + + static JSONObject memoryToJson(MemoryRecord m) { + JSONObject obj = new JSONObject(); + obj.put("id", m.id()); + obj.put("user", m.user()); + obj.put("namespace", m.namespace()); + obj.put("kind", m.kind()); + obj.put("source_thread", m.sourceThread()); + obj.put("text", m.text()); + obj.put("created_ts", m.createdTs()); + obj.put("hit_count", m.hitCount()); + obj.put("distance", m.distance() == null ? JSONObject.NULL : m.distance()); + obj.put("ttl_seconds", m.ttlSeconds() == null ? JSONObject.NULL : m.ttlSeconds()); + return obj; + } + + static JSONArray memoriesToJsonArray(List records) { + JSONArray arr = new JSONArray(); + for (MemoryRecord m : records) { + arr.put(memoryToJson(m)); + } + return arr; + } + + static JSONObject eventToJson(AgentEvent e) { + JSONObject obj = new JSONObject(); + obj.put("event_id", e.eventId()); + obj.put("thread_id", e.threadId()); + obj.put("action", e.action()); + obj.put("detail", e.detail()); + obj.put("ts", e.ts()); + return obj; + } + + // ------------------------------------------------------------------ + // HTTP plumbing + // ------------------------------------------------------------------ + + static final class RootHandler implements HttpHandler { + private final AgentMemoryDemo demo; + private final String htmlPage; + + RootHandler(AgentMemoryDemo demo, String htmlPage) { + this.demo = demo; + this.htmlPage = htmlPage; + } + + @Override + public void handle(HttpExchange ex) throws IOException { + try { + String method = ex.getRequestMethod(); + URI uri = ex.getRequestURI(); + String path = uri.getPath(); + + if ("GET".equalsIgnoreCase(method)) { + if (path.equals("/") || path.equals("/index.html")) { + sendHtml(ex, 200, htmlPage); + return; + } + if (path.equals("/state")) { + Map q = parseForm(uri.getRawQuery()); + String user = nonEmpty(q.get("user"), "default"); + String ns = nonEmpty(q.get("namespace"), "default"); + sendJson(ex, 200, demo.buildState(user, ns)); + return; + } + sendJson(ex, 404, errorPayload("not found", null)); + return; + } + if ("POST".equalsIgnoreCase(method)) { + String body = readBody(ex); + Map params = parseForm(body); + + if (path.equals("/turn")) { + handleTurn(ex, params); + return; + } + if (path.equals("/new_thread")) { + String user = nonEmpty(params.get("user"), "default"); + String ns = nonEmpty(params.get("namespace"), "default"); + String tid = demo.newThread(user, ns); + JSONObject body2 = new JSONObject(); + body2.put("thread_id", tid); + sendJson(ex, 200, body2); + return; + } + if (path.equals("/reset")) { + String user = nonEmpty(params.get("user"), "default"); + String ns = nonEmpty(params.get("namespace"), "default"); + try { + int seeded = demo.seedAll(user, ns); + JSONObject ok = new JSONObject(); + ok.put("seeded", seeded); + sendJson(ex, 200, ok); + } catch (Exception inner) { + handleException(ex, inner); + } + return; + } + if (path.equals("/drop_memory")) { + String memoryId = params.getOrDefault("memory_id", "").trim(); + if (memoryId.isEmpty()) { + sendJson(ex, 400, errorPayload("memory_id is required", null)); + return; + } + boolean deleted = demo.memory.deleteMemory(memoryId); + JSONObject out = new JSONObject(); + out.put("deleted", deleted); + out.put("memory_id", memoryId); + sendJson(ex, 200, out); + return; + } + sendJson(ex, 404, errorPayload("not found", null)); + return; + } + sendJson(ex, 405, errorPayload("method not allowed", null)); + } catch (Exception exc) { + handleException(ex, exc); + } + } + + private void handleTurn(HttpExchange ex, Map params) + throws IOException { + String text = params.getOrDefault("text", "").trim(); + if (text.isEmpty()) { + sendJson(ex, 400, errorPayload("text is required", null)); + return; + } + double threshold = clampThreshold( + params.get("threshold"), demo.memory.recallThreshold()); + try { + Map payload = demo.handleTurn( + text, + nonEmpty(params.get("user"), "default"), + nonEmpty(params.get("namespace"), "default"), + nonEmpty(params.get("kind"), "episodic"), + nonEmpty(params.get("role"), "user"), + threshold, + nonEmpty(params.get("action"), "turn")); + sendJson(ex, 200, toJson(payload)); + } catch (Exception inner) { + handleException(ex, inner); + } + } + + private void handleException(HttpExchange ex, Exception exc) { + System.err.println("[demo] handler error: " + + exc.getClass().getSimpleName() + ": " + exc.getMessage()); + exc.printStackTrace(System.err); + try { + JSONObject body = errorPayload( + exc.getMessage() == null ? exc.getClass().getSimpleName() : exc.getMessage(), + exc.getClass().getSimpleName()); + sendJson(ex, 500, body); + } catch (Exception ignored) { + // Headers may already be partially flushed; nothing + // useful left to do beyond letting the connection drop. + } + } + } + + // ------------------------------------------------------------------ + // Helpers + // ------------------------------------------------------------------ + + /** + * Parse a threshold value, clamping NaN/Infinity to + * {@code fallback} and otherwise clamping to {@code [0.0, 2.0]}. + * {@code parseDouble} happily handles "nan" → NaN and "inf" → +Inf; + * either would silently turn recall into "every memory" or + * "nothing", so clamping stops a malformed POST from overriding + * the threshold semantics. + */ + static double clampThreshold(String raw, double fallback) { + if (raw == null || raw.isEmpty()) return fallback; + double parsed; + try { + parsed = Double.parseDouble(raw); + } catch (NumberFormatException ex) { + return fallback; + } + if (Double.isNaN(parsed) || Double.isInfinite(parsed)) return fallback; + return Math.max(0.0, Math.min(2.0, parsed)); + } + + private static String nonEmpty(String value, String fallback) { + return (value == null || value.isEmpty()) ? fallback : value; + } + + /** + * Cap POST bodies so a runaway client can't accumulate unbounded + * memory before the handler runs. {@code com.sun.net.httpserver} + * provides no built-in limit on request bodies; the demo's + * largest legitimate body is a few hundred bytes of form-encoded + * query fields. 1 MiB matches the Node, Go, Rust, and .NET caps. + */ + private static final int MAX_BODY_BYTES = 1 * 1024 * 1024; + + private static String readBody(HttpExchange ex) throws IOException { + try (InputStream in = ex.getRequestBody()) { + byte[] bytes = in.readNBytes(MAX_BODY_BYTES + 1); + if (bytes.length > MAX_BODY_BYTES) { + throw new IOException( + "request body exceeds " + MAX_BODY_BYTES + " bytes"); + } + return new String(bytes, StandardCharsets.UTF_8); + } + } + + static Map parseForm(String body) { + Map out = new HashMap<>(); + if (body == null || body.isEmpty()) return out; + for (String pair : body.split("&")) { + if (pair.isEmpty()) continue; + int eq = pair.indexOf('='); + String key, value; + if (eq < 0) { + key = URLDecoder.decode(pair, StandardCharsets.UTF_8); + value = ""; + } else { + key = URLDecoder.decode(pair.substring(0, eq), StandardCharsets.UTF_8); + value = URLDecoder.decode(pair.substring(eq + 1), StandardCharsets.UTF_8); + } + out.put(key, value); + } + return out; + } + + private static void sendHtml(HttpExchange ex, int status, String html) throws IOException { + byte[] bytes = html.getBytes(StandardCharsets.UTF_8); + ex.getResponseHeaders().set("Content-Type", "text/html; charset=utf-8"); + ex.sendResponseHeaders(status, bytes.length); + ex.getResponseBody().write(bytes); + ex.getResponseBody().close(); + } + + private static void sendJson(HttpExchange ex, int status, JSONObject body) throws IOException { + byte[] bytes = body.toString().getBytes(StandardCharsets.UTF_8); + ex.getResponseHeaders().set("Content-Type", "application/json"); + ex.sendResponseHeaders(status, bytes.length); + ex.getResponseBody().write(bytes); + ex.getResponseBody().close(); + } + + private static JSONObject errorPayload(String message, String type) { + JSONObject out = new JSONObject(); + out.put("error", message); + if (type != null) out.put("type", type); + return out; + } + + private static JSONObject toJson(Map map) { + JSONObject out = new JSONObject(); + for (Map.Entry entry : map.entrySet()) { + Object value = entry.getValue(); + if (value == null) { + out.put(entry.getKey(), JSONObject.NULL); + } else { + out.put(entry.getKey(), value); + } + } + return out; + } + + private static String loadIndexHtml() throws IOException { + // index.html is shipped as a classpath resource (Maven pulls + // it from the project root via the entry in + // pom.xml). + try (InputStream in = DemoServer.class.getResourceAsStream("/index.html")) { + if (in == null) { + throw new IOException( + "index.html not found on classpath; rebuild with `mvn package`"); + } + return new String(in.readAllBytes(), StandardCharsets.UTF_8); + } + } + + // ------------------------------------------------------------------ + // CLI parsing + // ------------------------------------------------------------------ + + static Args parseArgs(String[] argv) { + Args args = new Args(); + for (int i = 0; i < argv.length; i++) { + String a = argv[i]; + switch (a) { + case "--host": args.host = require(argv, ++i, a); break; + case "--port": args.port = Integer.parseInt(require(argv, ++i, a)); break; + case "--redis-host": args.redisHost = require(argv, ++i, a); break; + case "--redis-port": args.redisPort = Integer.parseInt(require(argv, ++i, a)); break; + case "--mem-index-name": args.memIndexName = require(argv, ++i, a); break; + case "--mem-key-prefix": args.memKeyPrefix = require(argv, ++i, a); break; + case "--session-key-prefix": args.sessionKeyPrefix = require(argv, ++i, a); break; + case "--event-key-prefix": args.eventKeyPrefix = require(argv, ++i, a); break; + case "--session-ttl-seconds": args.sessionTtlSeconds = Long.parseLong(require(argv, ++i, a)); break; + case "--dedup-threshold": args.dedupThreshold = Double.parseDouble(require(argv, ++i, a)); break; + case "--recall-threshold": args.recallThreshold = Double.parseDouble(require(argv, ++i, a)); break; + case "--no-reset": args.resetOnStart = false; break; + case "-h": + case "--help": + printHelp(); + System.exit(0); + break; + default: + throw new IllegalArgumentException("Unknown flag: " + a); + } + } + return args; + } + + private static String require(String[] argv, int i, String flag) { + if (i >= argv.length) { + throw new IllegalArgumentException("Missing value for " + flag); + } + return argv[i]; + } + + private static void printHelp() { + System.out.println("Usage: java -jar agent-memory-jedis.jar [options]"); + System.out.println(" --host HOST HTTP bind host (default 127.0.0.1)"); + System.out.println(" --port PORT HTTP bind port (default 8092)"); + System.out.println(" --redis-host HOST Redis host (default localhost)"); + System.out.println(" --redis-port PORT Redis port (default 6379)"); + System.out.println(" --mem-index-name NAME Memory search index (default agentmem:idx)"); + System.out.println(" --mem-key-prefix PREFIX JSON memory key prefix (default agent:mem:)"); + System.out.println(" --session-key-prefix PREFIX Session hash key prefix (default agent:session:)"); + System.out.println(" --event-key-prefix PREFIX Event stream key prefix (default agent:events:)"); + System.out.println(" --session-ttl-seconds N Working-memory TTL (default 3600)"); + System.out.println(" --dedup-threshold F Cosine distance for dedup (default 0.20)"); + System.out.println(" --recall-threshold F Cosine distance for recall (default 0.55)"); + System.out.println(" --no-reset Skip clearing and re-seeding on startup"); + } +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LocalEmbedder.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LocalEmbedder.java new file mode 100644 index 0000000000..e9e1bfb524 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LocalEmbedder.java @@ -0,0 +1,149 @@ +package com.redis.agentmem; + +import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory; +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +/** + * Local text-embedding helper backed by DJL + PyTorch. + * + *

This is a thin wrapper around the + * {@code sentence-transformers/all-MiniLM-L6-v2} model loaded from + * DJL's model zoo: a 384-dimensional encoder that runs in-process on + * CPU through libtorch, needs no API key, and produces vectors that + * are numerically very close to the equivalent Python and Node ports + * (close enough that paraphrase distances differ only at the fourth + * decimal place). + * + *

DJL's {@link TextEmbeddingTranslatorFactory} returns mean-pooled, + * L2-normalized vectors by default, so a Redis Search index declared + * with {@code DISTANCE_METRIC COSINE} returns scores that are + * directly comparable across entries. The model is downloaded into + * the local DJL cache on the first call; every later call runs + * offline. + */ +public final class LocalEmbedder implements AutoCloseable { + + private static final String DEFAULT_MODEL_URL = + "djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2"; + private static final String DEFAULT_MODEL_NAME = + "sentence-transformers/all-MiniLM-L6-v2"; + private static final int DEFAULT_VECTOR_DIM = 384; + + private final String modelName; + private final ZooModel model; + private final Predictor predictor; + private final int dim; + + private LocalEmbedder( + String modelName, + ZooModel model, + Predictor predictor, + int dim) { + this.modelName = modelName; + this.model = model; + this.predictor = predictor; + this.dim = dim; + } + + /** + * Load the default model. Blocks while DJL downloads the + * PyTorch weights on the first run, then keeps a single loaded + * predictor for the lifetime of the embedder. + */ + public static LocalEmbedder create() throws Exception { + Criteria criteria = Criteria.builder() + .setTypes(String.class, float[].class) + .optModelUrls(DEFAULT_MODEL_URL) + .optEngine("PyTorch") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .optProgress(new ProgressBar()) + .build(); + ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + // Probe the output shape once so we fail loudly if a + // different model is wired up against the 384-dim Redis + // Search field. + float[] probe = predictor.predict("dimension probe"); + int dim = probe.length; + return new LocalEmbedder(DEFAULT_MODEL_NAME, model, predictor, dim); + } + + public String modelName() { + return modelName; + } + + public int dim() { + return dim; + } + + /** + * Encode a single string. Returns a {@code float[]} of length + * {@link #dim()}. + * + *

The DJL PyTorch {@code Predictor} is not thread-safe — its + * underlying NDManager and tokenizer state mutate per call. The + * demo server uses a cached thread pool, so two browser tabs + * could land on different handler threads and call this method + * concurrently. We {@code synchronized}-guard both encode entry + * points to serialize access to the shared predictor; encoding + * is the bottleneck either way and a single CPU-bound model + * won't usefully run two requests in parallel. A higher- + * throughput deployment would replace this with a small pool + * of {@code Predictor} instances or a dedicated single-threaded + * inference executor. + */ + public synchronized float[] encodeOne(String text) throws Exception { + return predictor.predict(text); + } + + /** Encode several strings sequentially. See {@link #encodeOne} + * for the rationale behind the synchronisation. */ + public synchronized List encodeMany(List texts) throws Exception { + List out = new ArrayList<>(texts.size()); + for (String text : texts) { + out.add(predictor.predict(text)); + } + return out; + } + + /** + * Pack a {@code float[]} into the bytes Redis Search expects. + * Vectors are little-endian {@code float32}; this matches the + * encoding the Python and Node ports write. + */ + public static byte[] toBytes(float[] vector) { + byte[] bytes = new byte[Float.BYTES * vector.length]; + ByteBuffer + .wrap(bytes) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer() + .put(vector); + return bytes; + } + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignored) { + // best-effort cleanup + } + try { + model.close(); + } catch (Exception ignored) { + // best-effort cleanup + } + } + + public static int defaultVectorDim() { + return DEFAULT_VECTOR_DIM; + } +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LongTermMemory.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LongTermMemory.java new file mode 100644 index 0000000000..d9c3707a5c --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/LongTermMemory.java @@ -0,0 +1,530 @@ +package com.redis.agentmem; + +import org.json.JSONObject; +import redis.clients.jedis.AbstractTransaction; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.exceptions.JedisDataException; +import redis.clients.jedis.json.Path2; +import redis.clients.jedis.search.Document; +import redis.clients.jedis.search.FTCreateParams; +import redis.clients.jedis.search.IndexDataType; +import redis.clients.jedis.search.Query; +import redis.clients.jedis.search.SearchResult; +import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; +import redis.clients.jedis.search.schemafields.VectorField; +import redis.clients.jedis.search.schemafields.VectorField.VectorAlgorithm; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; + +/** + * Long-term memory store for an agent, backed by Redis JSON and + * Search. + * + *

Each memory lives as one JSON document at + * {@code agent:mem:}. The document holds the memory text, its + * embedding vector, and a small metadata block — user, namespace, + * kind, source thread, timestamps — that lets the recall query scope + * results without falling back to application-side filtering. + * + *

A single Redis Search index covers the embedding plus every + * metadata field, so one {@code FT.SEARCH} call performs + * approximate-nearest-neighbour over the in-scope subset and returns + * the top-k memories ranked by cosine distance. The same KNN check + * runs at write time to deduplicate near-identical memories + * before they enter the store, which keeps the index from filling + * with paraphrases of the same fact as the agent reasons over + * similar topics across sessions. + * + *

Memories carry one of two kinds: + * + *

    + *
  • {@code episodic} — "what happened" snapshots from a specific + * thread, written with a medium TTL so old session detail + * decays naturally.
  • + *
  • {@code semantic} — distilled facts and preferences the agent + * should carry forward indefinitely. Written with no TTL by + * default.
  • + *
+ * + *

The split is enforced as a TAG on the index, so the recall + * query can ask for one kind or both with a filter — no separate + * keyspaces. + */ +public final class LongTermMemory { + + public static final int VECTOR_DIM_DEFAULT = 384; + + /** + * Cosine-distance cutoff for write-time deduplication. Smaller = + * stricter. 0.20 is calibrated to the + * {@code sentence-transformers/all-MiniLM-L6-v2} embedding model: + * a paraphrase of an existing memory lands in the 0.10 – 0.20 + * range and a distinct memory lands above 0.50. + */ + public static final double DEFAULT_DEDUP_THRESHOLD = 0.20; + + /** + * Cosine-distance cutoff for recall results. Larger than the + * dedup threshold so the agent gets a wider net at read time + * than at write time. + */ + public static final double DEFAULT_RECALL_THRESHOLD = 0.55; + + /** + * Default per-kind TTLs (seconds). A {@code null} pointer means + * "no TTL" — the memory persists until explicitly deleted or + * evicted under memory pressure. + */ + public static Map defaultTtlByKind() { + Map out = new HashMap<>(); + out.put("episodic", 7L * 24 * 3600); + out.put("semantic", null); + return out; + } + + /** + * Characters Redis Search treats as syntax inside a TAG value; + * any of them in a user-supplied filter must be backslash-escaped + * or the surrounding {@code {...}} block won't parse correctly. + */ + private static final String TAG_SPECIAL = "\\,.<>{}[]\"':;!@#$%^&*()-+=~| "; + + private final JedisPooled jedis; + private final String indexName; + private final String keyPrefix; + private final int vectorDim; + private final double dedupThreshold; + private final double recallThreshold; + private final Map ttlByKind; + + public LongTermMemory( + JedisPooled jedis, + String indexName, + String keyPrefix, + int vectorDim, + double dedupThreshold, + double recallThreshold, + Map ttlByKind) { + this.jedis = jedis; + this.indexName = indexName; + this.keyPrefix = keyPrefix; + this.vectorDim = vectorDim > 0 ? vectorDim : VECTOR_DIM_DEFAULT; + // Thresholds are honored as-is. Zero is a legitimate value + // ("exact matches only" for dedup, "nothing recalls" for + // recall); silently rewriting them would make + // --dedup-threshold 0 uncallable. + this.dedupThreshold = dedupThreshold < 0 ? DEFAULT_DEDUP_THRESHOLD : dedupThreshold; + this.recallThreshold = recallThreshold < 0 ? DEFAULT_RECALL_THRESHOLD : recallThreshold; + this.ttlByKind = ttlByKind != null ? ttlByKind : defaultTtlByKind(); + } + + public String indexName() { + return indexName; + } + + public String keyPrefix() { + return keyPrefix; + } + + public int vectorDim() { + return vectorDim; + } + + public double dedupThreshold() { + return dedupThreshold; + } + + public double recallThreshold() { + return recallThreshold; + } + + public String memoryKey(String memoryId) { + return keyPrefix + memoryId; + } + + // ------------------------------------------------------------------ + // Index management + // ------------------------------------------------------------------ + + /** + * Create the Redis Search index if it doesn't already exist. + * + *

The index is declared on the JSON document type with alias + * names on each path; the same {@code FT.SEARCH} filter clause + * works here as on a HASH-backed index, and the field paths + * ({@code $.user}, {@code $.embedding}, ...) only show up in + * {@code FT.CREATE}. + */ + public void createIndex() { + List schema = List.of( + TextField.of("$.text").as("text"), + TagField.of("$.user").as("user"), + TagField.of("$.namespace").as("namespace"), + TagField.of("$.kind").as("kind"), + TagField.of("$.source_thread").as("source_thread"), + NumericField.of("$.created_ts").as("created_ts").sortable(), + NumericField.of("$.hit_count").as("hit_count").sortable(), + VectorField.builder() + .fieldName("$.embedding").as("embedding") + .algorithm(VectorAlgorithm.HNSW) + .attributes(Map.of( + "TYPE", "FLOAT32", + "DIM", vectorDim, + "DISTANCE_METRIC", "COSINE" + )) + .build() + ); + try { + jedis.ftCreate( + indexName, + FTCreateParams.createParams() + .on(IndexDataType.JSON) + .addPrefix(keyPrefix), + schema + ); + } catch (JedisDataException ex) { + if (!String.valueOf(ex.getMessage()).contains("Index already exists")) { + throw ex; + } + } + } + + /** Drop the search index. Optionally also delete the JSON docs. */ + public void dropIndex(boolean deleteDocuments) { + try { + if (deleteDocuments) { + jedis.ftDropIndexDD(indexName); + } else { + jedis.ftDropIndex(indexName); + } + } catch (JedisDataException ex) { + String msg = String.valueOf(ex.getMessage()).toLowerCase(Locale.ROOT); + if (!msg.contains("no such index") && !msg.contains("unknown index name")) { + throw ex; + } + } + } + + // ------------------------------------------------------------------ + // Write + // ------------------------------------------------------------------ + + /** + * Write a new memory, deduplicating against existing entries. + * + *

Runs one in-scope KNN(1) against the index first. If the + * nearest existing memory is within {@link #dedupThreshold()}, + * the new memory is skipped (its content is already represented) + * and the existing memory's {@code hit_count} is bumped via + * {@code JSON.NUMINCRBY}. Otherwise a fresh JSON document is + * written under a new id with a TTL derived from the memory's + * {@code kind}. + * + *

The KNN-then-write sequence is not atomic; two workers that + * remember the same fact at the same time can both miss each + * other's in-flight write and insert duplicate memories. See the + * walkthrough's "Concurrency caveats" section for the production + * fix (periodic background consolidator that merges + * near-duplicates). + */ + public WriteResult remember( + String text, + float[] embedding, + String user, + String namespace, + String kind, + String sourceThread, + Long ttlSeconds) { + if (embedding.length != vectorDim) { + throw new IllegalArgumentException( + "embedding length is " + embedding.length + + "; index expects " + vectorDim); + } + if (user == null || user.isEmpty()) user = "default"; + if (namespace == null || namespace.isEmpty()) namespace = "default"; + if (kind == null || kind.isEmpty()) kind = "episodic"; + + List nearest = nearest(embedding, user, namespace, kind, 1); + Double existingDistance = !nearest.isEmpty() ? nearest.get(0).distance() : null; + if (!nearest.isEmpty() + && existingDistance != null + && existingDistance <= dedupThreshold) { + bumpHitCount(nearest.get(0).id()); + return new WriteResult(nearest.get(0).id(), true, existingDistance); + } + + String id = UUID.randomUUID().toString().replace("-", "").substring(0, 12); + String key = memoryKey(id); + double now = unixSecs(); + JSONObject doc = new JSONObject(); + doc.put("id", id); + doc.put("user", user); + doc.put("namespace", namespace); + doc.put("kind", kind); + doc.put("source_thread", sourceThread == null ? "" : sourceThread); + doc.put("text", text == null ? "" : text); + // org.json's JSONObject.put(String, Object) serializes a + // float[] as a JSON array of numbers — exactly what the JSON + // vector field expects at index time. + doc.put("embedding", embedding); + doc.put("created_ts", now); + doc.put("hit_count", 0); + + Long ttl = ttlSeconds != null ? ttlSeconds : ttlByKind.get(kind); + + // MULTI/EXEC so JSON.SET and EXPIRE either both apply or + // neither does. A connection drop between the two writes + // would otherwise leave the memory without an expiry — the + // index entry would still be there, but an `episodic` doc + // would outlive its intended seven-day TTL. + try (AbstractTransaction tx = jedis.multi()) { + tx.jsonSet(key, Path2.ROOT_PATH, doc); + if (ttl != null && ttl > 0) { + tx.expire(key, ttl); + } + tx.exec(); + } + return new WriteResult(id, false, existingDistance); + } + + // ------------------------------------------------------------------ + // Recall + // ------------------------------------------------------------------ + + /** + * Return the top-k in-scope memories ranked by similarity. + * Memories beyond {@code distanceThreshold} (or the instance + * default) are dropped — the index always returns + * something for KNN, so a recall result on an unrelated + * query would otherwise be a confidently-wrong false positive. + */ + public List recall( + float[] queryEmbedding, + String user, + String namespace, + String kind, + int k, + Double distanceThreshold) { + if (k <= 0) k = 5; + double threshold = distanceThreshold != null ? distanceThreshold : recallThreshold; + List candidates = nearest(queryEmbedding, user, namespace, kind, k); + List out = new ArrayList<>(candidates.size()); + for (MemoryRecord c : candidates) { + if (c.distance() != null && c.distance() <= threshold) { + out.add(c); + } + } + return out; + } + + // ------------------------------------------------------------------ + // Admin / inspection + // ------------------------------------------------------------------ + + public Map indexInfo() { + Map out = new HashMap<>(); + out.put("num_docs", 0L); + out.put("indexing_failures", 0L); + try { + Map info = jedis.ftInfo(indexName); + out.put("num_docs", parseLong(info.get("num_docs"), 0L)); + out.put("indexing_failures", + parseLong(info.get("hash_indexing_failures"), 0L)); + } catch (JedisDataException ignored) { + // index does not exist + } + return out; + } + + /** Return memories matching the filters, newest first. */ + public List listMemories( + String user, String namespace, String kind, int limit) { + if (limit <= 0) limit = 100; + String filterClause = buildFilterClause(user, namespace, kind); + Query q = new Query(filterClause) + .returnFields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count") + .limit(0, limit) + .setSortBy("created_ts", false) + .dialect(2); + List out = new ArrayList<>(); + SearchResult result; + try { + result = jedis.ftSearch(indexName, q); + } catch (JedisDataException ex) { + return out; + } + for (Document doc : result.getDocuments()) { + String memoryId = stripPrefix(doc.getId()); + long ttl = jedis.ttl(memoryKey(memoryId)); + Long ttlSeconds = ttl > 0 ? ttl : null; + out.add(toRecord(memoryId, doc, null, ttlSeconds)); + } + return out; + } + + public boolean deleteMemory(String memoryId) { + return jedis.del(memoryKey(memoryId)) > 0; + } + + /** + * Drop the index and every memory document, then re-create the + * index. Returns the count of documents that were removed. + */ + public long clear() { + long before = parseLong(indexInfo().get("num_docs"), 0L); + dropIndex(true); + createIndex(); + return before; + } + + // ------------------------------------------------------------------ + // Internals + // ------------------------------------------------------------------ + + private List nearest( + float[] embedding, String user, String namespace, String kind, int k) { + if (embedding.length != vectorDim) { + throw new IllegalArgumentException( + "embedding length is " + embedding.length + + "; index expects " + vectorDim); + } + String filterClause = buildFilterClause(user, namespace, kind); + String knnQuery = filterClause + "=>[KNN " + k + " @embedding $vec AS distance]"; + byte[] vecBytes = LocalEmbedder.toBytes(embedding); + + Query q = new Query(knnQuery) + .returnFields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count", "distance") + .setSortBy("distance", true) + .limit(0, k) + .addParam("vec", vecBytes) + .dialect(2); + + SearchResult result = jedis.ftSearch(indexName, q); + List out = new ArrayList<>(result.getDocuments().size()); + for (Document doc : result.getDocuments()) { + // `doc.getId()` is the full Redis key (e.g. + // `agent:mem:abc123`). Strip the prefix so the returned + // record exposes only the opaque id the UI and + // `deleteMemory` work with. + String memoryId = stripPrefix(doc.getId()); + long ttl = jedis.ttl(memoryKey(memoryId)); + Long ttlSeconds = ttl > 0 ? ttl : null; + Double distance = parseDoubleOrNull(doc.get("distance")); + out.add(toRecord(memoryId, doc, distance, ttlSeconds)); + } + return out; + } + + private void bumpHitCount(String memoryId) { + try { + // Fire-and-forget: the doc may have expired between + // recall and bump, and discarding the error keeps the + // demo from blowing up on that race; we just lose the + // hit-count update. + jedis.jsonNumIncrBy(memoryKey(memoryId), Path2.of("$.hit_count"), 1.0); + } catch (JedisDataException ignored) { + // memory expired or path not found + } + } + + private static MemoryRecord toRecord( + String memoryId, Document doc, Double distance, Long ttlSeconds) { + return new MemoryRecord( + memoryId, + nullSafe(doc.getString("user")), + nullSafe(doc.getString("namespace")), + nullSafe(doc.getString("kind")), + nullSafe(doc.getString("source_thread")), + nullSafe(doc.getString("text")), + parseDouble(doc.get("created_ts"), 0.0), + parseLong(doc.get("hit_count"), 0L), + distance, + ttlSeconds); + } + + private String stripPrefix(String rawKey) { + return rawKey.startsWith(keyPrefix) ? rawKey.substring(keyPrefix.length()) : rawKey; + } + + static String escapeTagValue(String value) { + StringBuilder out = new StringBuilder(value.length()); + for (int i = 0; i < value.length(); i++) { + char ch = value.charAt(i); + if (TAG_SPECIAL.indexOf(ch) >= 0) { + out.append('\\'); + } + out.append(ch); + } + return out.toString(); + } + + static String buildFilterClause(String user, String namespace, String kind) { + List clauses = new ArrayList<>(3); + if (user != null && !user.isEmpty()) { + clauses.add("@user:{" + escapeTagValue(user) + "}"); + } + if (namespace != null && !namespace.isEmpty()) { + clauses.add("@namespace:{" + escapeTagValue(namespace) + "}"); + } + if (kind != null && !kind.isEmpty()) { + clauses.add("@kind:{" + escapeTagValue(kind) + "}"); + } + if (clauses.isEmpty()) return "(*)"; + return "(" + String.join(" ", clauses) + ")"; + } + + private static String nullSafe(String s) { + return s == null ? "" : s; + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } + + private static double parseDouble(Object value, double dflt) { + if (value == null) return dflt; + if (value instanceof Number n) return n.doubleValue(); + try { + return Double.parseDouble(value.toString()); + } catch (NumberFormatException ex) { + return dflt; + } + } + + private static Double parseDoubleOrNull(Object value) { + if (value == null) return null; + if (value instanceof Number n) return n.doubleValue(); + try { + return Double.parseDouble(value.toString()); + } catch (NumberFormatException ex) { + return null; + } + } + + private static long parseLong(Object value, long dflt) { + if (value == null) return dflt; + if (value instanceof Number n) return n.longValue(); + try { + return Long.parseLong(value.toString()); + } catch (NumberFormatException ex) { + try { + return (long) Double.parseDouble(value.toString()); + } catch (NumberFormatException ignored) { + return dflt; + } + } + } + +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/MemoryRecord.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/MemoryRecord.java new file mode 100644 index 0000000000..1723172230 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/MemoryRecord.java @@ -0,0 +1,21 @@ +package com.redis.agentmem; + +/** + * A single long-term memory document. + * + *

{@code distance} is set only when the record comes back from a + * KNN query; {@code ttlSeconds} is {@code null} for memories with no + * TTL (e.g. {@code kind=semantic} under the default tier map). + */ +public record MemoryRecord( + String id, + String user, + String namespace, + String kind, + String sourceThread, + String text, + double createdTs, + long hitCount, + Double distance, + Long ttlSeconds) { +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SeedMemory.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SeedMemory.java new file mode 100644 index 0000000000..68df260546 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SeedMemory.java @@ -0,0 +1,96 @@ +package com.redis.agentmem; + +import java.util.ArrayList; +import java.util.List; + +/** + * Pre-seed the long-term memory store with sample memories. + * + *

In a real deployment the memory store fills up organically as + * the agent reasons over user turns: each turn produces zero or more + * memories (preferences, facts, episodic summaries) that flow into + * the store with deduplication. To make the demo immediately useful + * — so the first recall query lands on relevant results instead of + * an empty list — we seed a small set of canonical memories for a + * default user at startup. + * + *

The seed list mixes {@code semantic} memories (long-lived + * preferences and facts) with {@code episodic} memories (snapshots + * of past sessions), matching what the Python, Node, .NET, Rust, and + * Go demos seed so all six implementations behave identically. + */ +public final class SeedMemory { + + private SeedMemory() {} + + public record SeedEntry(String text, String kind) {} + + public static final List SEED_MEMORIES = List.of( + new SeedEntry( + "The user prefers concise answers without filler phrases.", + "semantic"), + new SeedEntry( + "The user is a Python developer working on a logistics platform.", + "semantic"), + new SeedEntry( + "The user lives in Berlin and works in the Europe/Berlin time zone.", + "semantic"), + new SeedEntry( + "The user dislikes dark mode and prefers a high-contrast light " + + "theme in editors and dashboards.", + "semantic"), + new SeedEntry( + "The user is allergic to peanuts; any restaurant suggestion must " + + "avoid dishes that commonly contain them.", + "semantic"), + new SeedEntry( + "Last Tuesday the user asked the agent to draft a postmortem for " + + "the order-routing outage. The agent produced a five-section " + + "draft and the user approved sections 1, 2, and 4 with minor " + + "edits.", + "episodic"), + new SeedEntry( + "In a previous session the user asked for help debugging a flaky " + + "test in the inventory service. The fix turned out to be a " + + "race condition in the warehouse webhook handler.", + "episodic"), + new SeedEntry( + "Two weeks ago the user mentioned they were planning to migrate " + + "the analytics warehouse from Snowflake to BigQuery in Q3.", + "episodic") + ); + + /** + * Embed and write the seed memories. Returns the count actually + * written (entries that dedup against existing memories don't + * count). + */ + public static int seed( + LongTermMemory memory, + LocalEmbedder embedder, + String user, + String namespace, + String sourceThread) throws Exception { + List texts = new ArrayList<>(SEED_MEMORIES.size()); + for (SeedEntry s : SEED_MEMORIES) { + texts.add(s.text()); + } + List vectors = embedder.encodeMany(texts); + int written = 0; + for (int i = 0; i < SEED_MEMORIES.size(); i++) { + SeedEntry entry = SEED_MEMORIES.get(i); + WriteResult result = memory.remember( + entry.text(), + vectors.get(i), + user, + namespace, + entry.kind(), + sourceThread, + null); + if (!result.deduped()) { + written++; + } + } + return written; + } +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionState.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionState.java new file mode 100644 index 0000000000..67b0f6af2d --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionState.java @@ -0,0 +1,23 @@ +package com.redis.agentmem; + +import java.util.List; + +/** + * The full per-thread working-memory state. + * + *

{@code recentTurns} is bounded by {@code AgentSession.maxTurns()}; + * the hash itself never grows in size or field count as the + * conversation goes on. + */ +public record SessionState( + String threadId, + String user, + String agent, + String goal, + String scratchpad, + long turnCount, + double createdTs, + double lastActiveTs, + List recentTurns, + long ttlSeconds) { +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionTurn.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionTurn.java new file mode 100644 index 0000000000..f9c88bd4ef --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/SessionTurn.java @@ -0,0 +1,11 @@ +package com.redis.agentmem; + +/** + * One turn inside the rolling session window. + * + *

Stored as part of a JSON array on the + * {@code agent:session:{threadId}} hash; the embedder helper does not + * see this directly. + */ +public record SessionTurn(String role, String content, double ts) { +} diff --git a/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/WriteResult.java b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/WriteResult.java new file mode 100644 index 0000000000..b2cf903b49 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-jedis/src/main/java/com/redis/agentmem/WriteResult.java @@ -0,0 +1,13 @@ +package com.redis.agentmem; + +/** + * Outcome of a {@link LongTermMemory#remember} call. + * + *

{@code deduped} is {@code true} when the write skipped because a + * similar memory already existed; {@code id} is then the existing + * memory's id. {@code existingDistance} is the cosine distance to + * that nearest memory regardless of which branch was taken — useful + * for tracing. + */ +public record WriteResult(String id, boolean deduped, Double existingDistance) { +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/.gitignore b/content/develop/use-cases/agent-memory/java-lettuce/.gitignore new file mode 100644 index 0000000000..434dfcf2d3 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/.gitignore @@ -0,0 +1,6 @@ +target/ +.idea/ +*.iml +.classpath +.project +.settings/ diff --git a/content/develop/use-cases/agent-memory/java-lettuce/_index.md b/content/develop/use-cases/agent-memory/java-lettuce/_index.md new file mode 100644 index 0000000000..08c8397f83 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/_index.md @@ -0,0 +1,404 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Java with Lettuce, DJL (PyTorch), and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: Lettuce example (Java) +title: Redis agent memory with Lettuce +weight: 7 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Java with [Lettuce]({{< relref "/develop/clients/lettuce" >}}) and [DJL](https://djl.ai/) (the Deep Java Library), using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with the JDK's [`com.sun.net.httpserver`](https://docs.oracle.com/en/java/javase/17/docs/api/jdk.httpserver/com/sun/net/httpserver/package-summary.html) so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [DJL](https://djl.ai/) (`ai.djl.huggingface.tokenizers` + `ai.djl.pytorch.pytorch-model-zoo`) running the canonical `sentence-transformers/all-MiniLM-L6-v2` PyTorch checkpoint — the same library and model the existing [Lettuce vector-search example]({{< relref "/develop/clients/lettuce/vecsearch" >}}) uses, and the same encoder the [Python]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) example loads. DJL drives libtorch through the same C++ runtime as Python's PyTorch bindings, so the vectors produced here are numerically identical to the Python ones to within rounding noise, and the distance bands the Python walkthrough quotes carry over to this demo without recalibration. A memory written by one demo can be recalled by the other against the same Redis instance. + +The big shape difference from the [sister Jedis port]({{< relref "/develop/use-cases/agent-memory/java-jedis" >}}) is that Lettuce 6.7 doesn't ship first-class `FT.*` or `JSON.*` bindings yet. The helper classes send those commands through Lettuce's generic `dispatch()` API with a custom `ProtocolKeyword`, and the FT.SEARCH/FT.INFO replies are parsed by walking the resulting `List` — see [`LongTermMemory.java`](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LongTermMemory.java) for the parser. The same connection is pinned to RESP2 at startup so the reply shape stays a flat array (Lettuce 6.7 negotiates RESP3 by default, which wraps the data in a map structure the parser would have to special-case). + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* One Redis Search call per recall: [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) does the KNN + TAG pre-filter in a single round trip (a per-row [`TTL`]({{< relref "/commands/ttl" >}}) follow-up is the only other read the helper issues, just to populate the `ttl_seconds` field for the admin panel). Working memory is one [`HGETALL`]({{< relref "/commands/hgetall" >}}); the event log is one [`XADD`]({{< relref "/commands/xadd" >}}). +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encodeOne(text)` to turn the incoming turn into a 384-element `float[]`. +2. `session.appendTurn(threadId, role, content, user, agent, null)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside Lettuce's `multi()` / `exec()` transaction. The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall(vec, user, namespace, null, 5, threshold)` issues `FT.SEARCH` via `dispatch()` with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember(text, vec, user, namespace, kind, threadId, null)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with `JSON.NUMINCRBY` (also via `dispatch()`) — best-effort: if the memory's TTL has elapsed between the recall and the bump, the increment quietly fails and the hit count for that recall is lost. Otherwise a fresh JSON document is written with `JSON.SET` and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) inside the same `multi()` transaction. +5. `events.record(threadId, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) (`XAddArgs.Builder.maxlen(N).approximateTrimming()`), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentSession.java)): + +```java +import com.redis.agentmem.AgentSession; +import com.redis.agentmem.SessionState; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisClient; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.protocol.ProtocolVersion; + +RedisClient client = RedisClient.create("redis://localhost:6379"); +// Pin to RESP2 so FT.SEARCH / FT.INFO reply with the flat arrays +// the helper's parser expects (Lettuce 6.7 otherwise negotiates +// RESP3 against Redis 7+ and wraps the same data in a map shape — +// see the "Sending FT.* / JSON.* commands through Lettuce" section +// below for the longer story). +client.setOptions(ClientOptions.builder() + .protocolVersion(ProtocolVersion.RESP2) + .build()); +StatefulRedisConnection connection = client.connect(ByteArrayCodec.INSTANCE); + +// Shared lock across helpers so concurrent MULTI/EXEC spans on the +// single connection don't interleave queued commands. +Object txLock = new Object(); +AgentSession session = new AgentSession( + connection, txLock, "agent:session:", 3600, 20); + +String threadId = session.newThreadId(); +SessionState state = session.start(threadId, "alice", "demo-agent", + "Plan next week's meetings.", null); +state = session.appendTurn( + threadId, "user", "Schedule a budget review with finance.", + "alice", "demo-agent", null); +System.out.println(state.turnCount() + " " + state.recentTurns().size() + + " " + state.ttlSeconds()); +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `appendTurn`, `setGoal` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside `multi()` / `exec()` so a connection drop between the two writes can't leave the session without a TTL. The shared `txLock` serializes this whole MULTI…EXEC span against any other transaction on the connection — Lettuce's transaction state is connection-scoped, so two concurrent handler threads queueing into the same MULTI would interleave their writes. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LongTermMemory.java)): + +```java +import com.redis.agentmem.LocalEmbedder; +import com.redis.agentmem.LongTermMemory; +import com.redis.agentmem.MemoryRecord; +import com.redis.agentmem.WriteResult; + +LongTermMemory memory = new LongTermMemory( + connection, + txLock, + "agentmem:idx", + "agent:mem:", + 384, + 0.20, // dedup threshold — tight at write time + 0.55, // recall threshold — looser at read time + null); // default per-kind TTL map +LocalEmbedder embedder = LocalEmbedder.create(); +memory.createIndex(); // idempotent + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +float[] vec = embedder.encodeOne("The user prefers light mode in editors."); +WriteResult result = memory.remember( + "The user prefers light mode in editors.", + vec, + "alice", + "default", + "semantic", + "9f3d2a4b8c61", + null); +System.out.printf("deduped=%s id=%s dist=%s%n", + result.deduped(), result.id(), result.existingDistance()); + +// Recall against a later question. +float[] q = embedder.encodeOne("Which theme does this user like?"); +for (MemoryRecord h : memory.recall(q, "alice", "default", null, 5, null)) { + System.out.printf("%.3f [%s] %s%n", h.distance(), h.kind(), h.text()); +} +``` + +### Sending FT.* / JSON.* commands through Lettuce + +Lettuce 6.7 doesn't ship first-class `FT.*` or `JSON.*` bindings, so the helper sends them through `dispatch()` with a custom `ProtocolKeyword`. Each command is its own keyword whose `getBytes()` returns the literal command name; the rest is built up with `CommandArgs`: + +```java +private enum ModuleCommand implements ProtocolKeyword { + FT_CREATE("FT.CREATE"), + FT_SEARCH("FT.SEARCH"), + FT_INFO("FT.INFO"), + FT_DROPINDEX("FT.DROPINDEX"), + JSON_SET("JSON.SET"), + JSON_NUMINCRBY("JSON.NUMINCRBY"); + // ...implementation elided +} + +CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes) + .add(filterClause + "=>[KNN 5 @embedding $vec AS distance]") + .add("PARAMS").add(2).add("vec".getBytes(UTF_8)).add(vecBytes) + .add("RETURN").add(8).add("user").add("namespace").add("kind") + .add("source_thread").add("text").add("created_ts") + .add("hit_count").add("distance") + .add("SORTBY").add("distance").add("ASC") + .add("LIMIT").add(0).add(5) + .add("DIALECT").add(2); + +List raw = sync.dispatch( + ModuleCommand.FT_SEARCH, + new NestedMultiOutput<>(ByteArrayCodec.INSTANCE), + args); +``` + +The `NestedMultiOutput` collects the RESP reply into a `List` — a flat `[count, key1, fields1, key2, fields2, ...]` array under RESP2, where each `fields` is itself a nested list of alternating field/value pairs. The helper walks that with `parseAllHits()` and `fieldsToMap()` (both private to `LongTermMemory`). + +The connection is pinned to RESP2 at startup with `ClientOptions.builder().protocolVersion(ProtocolVersion.RESP2)`. Lettuce 6.7 otherwise negotiates RESP3 against Redis 7+, which wraps `FT.SEARCH` and `FT.INFO` results in map / set shapes that the flat-array parser would have to special-case across every reply path. + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is stored as a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (`LocalEmbedder.toBytes()` packs them in little-endian order), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with alias names on each path so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write by passing a non-null `ttlSeconds` to `remember`, or hand a different `Map` to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEventLog.java)): + +```java +import com.redis.agentmem.AgentEvent; +import com.redis.agentmem.AgentEventLog; + +AgentEventLog events = new AgentEventLog(connection, "agent:events:", 1000); +events.record(threadId, "turn_appended:user", + "Schedule a budget review with finance."); +events.record(threadId, "memory_written", + "wrote 7c3f8a1b9e02 as semantic"); + +for (AgentEvent e : events.recent(threadId, 20)) { + System.out.println(e.action() + " " + e.detail()); +} +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000` via `XAddArgs.Builder.maxlen(1000).approximateTrimming()`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.appendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recentTurns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recentTurns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `currentThreadId` synchronized through an explicit mutex; `seedAll`, `newThread`, and `handleTurn` each release the lock between operations, so a turn racing with a thread rotation can capture the old id and apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Two concerns specific to Java + Lettuce: + +* `LocalEmbedder.encodeOne` / `encodeMany` are `synchronized` because the underlying `Predictor` is not thread-safe. The demo's `Executors.newCachedThreadPool` could otherwise call into one predictor from several handler threads at once and corrupt the inference state. A higher-throughput deployment would replace that lock with a small pool of `Predictor` instances or a dedicated single-threaded inference executor. + +* All three helpers share a single Lettuce connection. Lettuce connections are thread-safe for individual command dispatch, but transaction state (`MULTI` / queued commands / `EXEC`) is connection-scoped — two concurrent threads queueing into the same `MULTI` would interleave their writes and each thread's `EXEC` would return a mix of replies. The shared `txLock` is passed into all three helpers (`AgentSession`, `LongTermMemory`, and `AgentEventLog`) and is held around every Redis call they issue, not only the `MULTI…EXEC` spans. That serializes every operation on the connection, which is fine for a single-user demo but trades concurrency for safety; a higher-throughput deployment would use a small pool of connections via Lettuce's `ConnectionPoolSupport` instead so handlers can run their commands in parallel. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `SeedMemory` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SeedMemory.java)): + +```java +import com.redis.agentmem.SeedMemory; + +LongTermMemory memory = new LongTermMemory(connection, txLock, + "agentmem:idx", "agent:mem:", 384, 0.20, 0.55, null); +LocalEmbedder embedder = LocalEmbedder.create(); +memory.createIndex(); +int written = SeedMemory.seed(memory, embedder, "default", "default", "seed"); +System.out.println("seeded " + written + " memories"); +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`DemoServer` runs the JDK's [`HttpServer`](https://docs.oracle.com/en/java/javase/17/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html) on port 8093, with a cached thread pool dispatching requests to handlers. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a mutex-protected `String` field that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/java-lettuce + ``` + +2. Build the fat jar. You'll need a [JDK 17](https://adoptium.net/) or later and + [Maven](https://maven.apache.org/): + + ```bash + mvn -q package + ``` + + The first build pulls Lettuce, DJL, and the PyTorch native libraries — that takes + a couple of minutes the first time and is cached afterwards. + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo. The first run downloads the `sentence-transformers/all-MiniLM-L6-v2` + PyTorch weights into the local DJL cache (~90 MB): + + ```bash + java -jar target/agent-memory-lettuce.jar + ``` + + Or via Maven: `mvn -q exec:java`. + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + DJL drives libtorch through the same C++ kernel as Python's PyTorch + bindings, so distances here match the Python demo to four decimal + places. `sentence-transformers/all-MiniLM-L6-v2` puts a faithful + paraphrase in the 0.15 – 0.50 cosine-distance range, a loose paraphrase + or related topic in the 0.50 – 0.80 range, and unrelated queries above + 0.8 — which is what motivates the 0.55 default recall threshold and the + 0.20 default dedup threshold. A stricter embedding model (or a + domain-tuned one) would let you tighten both; a noisier one would push + them up. The right thresholds are always a function of the model, the + corpus, and how conservative the agent needs to be about accepting a + memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags (pass them after the jar): + +* `--host` / `--port` — change the HTTP bind address (default `127.0.0.1:8093`). +* `--redis-host` / `--redis-port` — point at a non-local Redis (default `localhost:6379`). +* `--mem-index-name` / `--mem-key-prefix` / `--session-key-prefix` / `--event-key-prefix` — relocate the index name and the three key prefixes (to run several demos against one Redis without colliding, for example). +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/java-lettuce/index.html b/content/develop/use-cases/agent-memory/java-lettuce/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/java-lettuce/pom.xml b/content/develop/use-cases/agent-memory/java-lettuce/pom.xml new file mode 100644 index 0000000000..bcafb79856 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/pom.xml @@ -0,0 +1,135 @@ + + + 4.0.0 + + com.redis + agent-memory-lettuce + 1.0.0 + jar + + Redis Agent Memory Demo (Lettuce) + + Interactive agent-memory demo backed by Redis Hashes, JSON, + Search, and Streams, using Lettuce for Redis access and DJL + (PyTorch) for local sentence embeddings. + + + + 17 + 17 + UTF-8 + 6.7.1.RELEASE + 0.33.0 + 20240303 + + + + + + io.lettuce + lettuce-core + ${lettuce.version} + + + + + ai.djl + api + ${djl.version} + + + ai.djl.huggingface + tokenizers + ${djl.version} + + + ai.djl.pytorch + pytorch-model-zoo + ${djl.version} + + + + + org.json + json + ${json.version} + + + + + agent-memory-lettuce + + + + ${project.basedir} + + index.html + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + 17 + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.3 + + + package + shade + + false + + + com.redis.agentmem.DemoServer + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + org.codehaus.mojo + exec-maven-plugin + 3.5.0 + + com.redis.agentmem.DemoServer + + + + + diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEvent.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEvent.java new file mode 100644 index 0000000000..e4b085def8 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEvent.java @@ -0,0 +1,18 @@ +package com.redis.agentmem; + +/** + * One entry from the per-thread event Stream. + * + *

{@code eventId} is the {@code XADD}-assigned stream id (e.g. + * {@code 1715990400123-0}); {@code ts} is the wall-clock time the + * action happened, stored as a Redis Stream field rather than + * inferred from the stream id because the demo timestamps the action + * on the agent side. + */ +public record AgentEvent( + String eventId, + String threadId, + String action, + String detail, + double ts) { +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEventLog.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEventLog.java new file mode 100644 index 0000000000..c404ef524d --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentEventLog.java @@ -0,0 +1,145 @@ +package com.redis.agentmem; + +import io.lettuce.core.Limit; +import io.lettuce.core.Range; +import io.lettuce.core.StreamMessage; +import io.lettuce.core.XAddArgs; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Append-only event log for an agent thread, backed by a Redis + * Stream. + * + *

Each thread gets a stream at {@code agent:events:{threadId}}. + * Every action the agent takes (a user turn arriving, a memory being + * recalled, a memory being written, a tool being called) is one + * {@code XADD} to that stream. Replay with {@code XREVRANGE} for the + * most recent N events; bound retention with {@code XTRIM MAXLEN ~} + * (Lettuce's {@code XAddArgs.approximateTrimming()}) so the log + * stays cheap regardless of how long the thread has been running. + */ +public final class AgentEventLog { + + public static final long DEFAULT_MAX_LEN = 1000L; + + private final RedisCommands sync; + private final Object txLock; + private final String keyPrefix; + private final long maxLen; + + public AgentEventLog( + StatefulRedisConnection connection, + Object txLock, + String keyPrefix, + long maxLen) { + this.sync = connection.sync(); + this.txLock = txLock; + this.keyPrefix = keyPrefix; + this.maxLen = maxLen > 0 ? maxLen : DEFAULT_MAX_LEN; + } + + public String keyPrefix() { + return keyPrefix; + } + + public long maxLen() { + return maxLen; + } + + public String streamKey(String threadId) { + return keyPrefix + threadId; + } + + private byte[] streamKeyBytes(String threadId) { + return streamKey(threadId).getBytes(StandardCharsets.UTF_8); + } + + /** + * Append one event and return its stream id. + * + *

{@code MAXLEN ~ N} ({@code approximateTrimming()}) keeps the + * stream bounded with near-zero overhead; the exact form forces a + * scan and is rarely worth the cost. + */ + public String record(String threadId, String action, String detail) { + Map fields = new LinkedHashMap<>(); + putUtf8(fields, "action", action == null ? "" : action); + putUtf8(fields, "detail", detail == null ? "" : detail); + putUtf8(fields, "ts", String.format(Locale.ROOT, "%.6f", unixSecs())); + + XAddArgs args = XAddArgs.Builder.maxlen(maxLen).approximateTrimming(); + // The shared `txLock` serializes this XADD against any + // MULTI/EXEC span the session or memory helpers might have + // open on the same connection. Without the lock a concurrent + // `XADD` would land inside the other thread's transaction + // and the demo would see surprising results. + String id; + synchronized (txLock) { + id = sync.xadd(streamKeyBytes(threadId), args, fields); + } + return id == null ? "" : id; + } + + /** Return the most recent events, newest first. */ + public List recent(String threadId, int count) { + // `xrevrange` walks the stream from the highest id back to the + // lowest, so an unbounded range with {@code Limit.from(count)} + // returns the most recent {@code count} entries newest-first. + // Locked against {@code txLock} for the same reason + // {@code record} is — an unguarded read can otherwise land + // inside another helper's open MULTI on the shared connection. + List> entries; + synchronized (txLock) { + entries = sync.xrevrange( + streamKeyBytes(threadId), + Range.unbounded(), + Limit.from(count)); + } + List out = new ArrayList<>(entries.size()); + for (StreamMessage msg : entries) { + Map body = new LinkedHashMap<>(); + for (Map.Entry e : msg.getBody().entrySet()) { + body.put( + new String(e.getKey(), StandardCharsets.UTF_8), + new String(e.getValue(), StandardCharsets.UTF_8)); + } + out.add(new AgentEvent( + msg.getId(), + threadId, + body.getOrDefault("action", ""), + body.getOrDefault("detail", ""), + parseDouble(body.get("ts"), 0.0))); + } + return out; + } + + /** Drop the entire stream for a thread. */ + public boolean clear(String threadId) { + synchronized (txLock) { + return sync.del(streamKeyBytes(threadId)) > 0L; + } + } + + private static void putUtf8(Map fields, String name, String value) { + fields.put( + name.getBytes(StandardCharsets.UTF_8), + (value == null ? "" : value).getBytes(StandardCharsets.UTF_8)); + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } + + private static double parseDouble(String s, double fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Double.parseDouble(s); } catch (NumberFormatException ex) { return fallback; } + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentSession.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentSession.java new file mode 100644 index 0000000000..d4a22dfdb4 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/AgentSession.java @@ -0,0 +1,310 @@ +package com.redis.agentmem; + +import io.lettuce.core.RedisException; +import io.lettuce.core.TransactionResult; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; + +/** + * Working-memory store for an agent session, backed by a Redis Hash. + * + *

Each session is one Hash document at + * {@code agent:session:{threadId}}. The hash holds the running + * scratchpad, the current goal, a rolling window of recent turns + * (serialized as a JSON list to fit in one field), and a few audit + * fields. One {@code HGETALL} returns the whole session in a single + * round trip on every step of the agent loop. + * + *

Every write refreshes the key's TTL with {@code EXPIRE}, so + * idle sessions fall off without a separate cleanup job and active + * sessions stay alive as long as the agent keeps touching them. + * + *

The Lettuce connection is shared with {@link LongTermMemory} + * and {@link AgentEventLog}; the {@code txLock} passed in at + * construction time serializes {@code MULTI}/{@code EXEC} spans + * across the three helpers so concurrent transactions on the same + * connection don't interleave queued commands. + */ +public final class AgentSession { + + public static final int DEFAULT_MAX_TURNS = 20; + + private final StatefulRedisConnection connection; + private final RedisCommands sync; + private final Object txLock; + private final String keyPrefix; + private final byte[] keyPrefixBytes; + private final long defaultTtlSeconds; + private final int maxTurns; + + public AgentSession( + StatefulRedisConnection connection, + Object txLock, + String keyPrefix, + long defaultTtlSeconds, + int maxTurns) { + this.connection = connection; + this.sync = connection.sync(); + this.txLock = txLock; + this.keyPrefix = keyPrefix; + this.keyPrefixBytes = keyPrefix.getBytes(StandardCharsets.UTF_8); + this.defaultTtlSeconds = defaultTtlSeconds; + this.maxTurns = maxTurns > 0 ? maxTurns : DEFAULT_MAX_TURNS; + } + + public String keyPrefix() { + return keyPrefix; + } + + public long defaultTtlSeconds() { + return defaultTtlSeconds; + } + + public int maxTurns() { + return maxTurns; + } + + public String sessionKey(String threadId) { + return keyPrefix + threadId; + } + + private byte[] sessionKeyBytes(String threadId) { + return sessionKey(threadId).getBytes(StandardCharsets.UTF_8); + } + + public String newThreadId() { + return UUID.randomUUID().toString().replace("-", "").substring(0, 12); + } + + /** + * Create a fresh working memory for a thread. Overwrites any + * existing session at the same key. + */ + public SessionState start( + String threadId, String user, String agent, String goal, Long ttlSeconds) { + if (user == null || user.isEmpty()) user = "default"; + if (agent == null || agent.isEmpty()) agent = "default"; + if (goal == null) goal = ""; + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + double now = unixSecs(); + SessionState state = new SessionState( + threadId, user, agent, goal, "", + 0L, now, now, List.of(), ttl); + write(state, ttl); + return state; + } + + /** Return the session state, or {@code null} if it has expired. */ + public SessionState load(String threadId) { + byte[] key = sessionKeyBytes(threadId); + // Lock the HGETALL+TTL pair against the shared connection's + // transaction state — an unguarded read can otherwise land + // inside another helper's open MULTI/EXEC and the demo would + // see queued responses come back as null. + Map raw; + long ttl; + synchronized (txLock) { + raw = sync.hgetall(key); + ttl = sync.ttl(key); + } + if (raw == null || raw.isEmpty()) return null; + Map fields = new LinkedHashMap<>(); + for (Map.Entry e : raw.entrySet()) { + fields.put( + new String(e.getKey(), StandardCharsets.UTF_8), + new String(e.getValue(), StandardCharsets.UTF_8)); + } + if (ttl < 0) ttl = 0; + return new SessionState( + threadId, + orDefault(fields.get("user"), "default"), + orDefault(fields.get("agent"), "default"), + orEmpty(fields.get("goal")), + orEmpty(fields.get("scratchpad")), + parseLong(fields.get("turn_count"), 0L), + parseDouble(fields.get("created_ts"), 0.0), + parseDouble(fields.get("last_active_ts"), 0.0), + parseTurns(fields.get("recent_turns")), + ttl); + } + + /** + * Append a turn, bound the rolling window, refresh the TTL. + * + *

{@code user} and {@code agent} are only consulted when the + * session does not yet exist — they seed the auto-created session + * so the working-memory hash matches the user the caller is + * operating against. On an existing session they're ignored; the + * original {@code start} values stand. + * + *

Read-modify-write here is last-writer-wins on the turn list + * if two concurrent turns reach the same thread; the demo never + * triggers that race in practice (one browser, one turn at a + * time) but a multi-worker agent that shares a thread id would + * wrap this in {@code WATCH} / {@code MULTI} / {@code EXEC} or a + * Lua script that does the append atomically server-side. + */ + public SessionState appendTurn( + String threadId, + String role, + String content, + String user, + String agent, + Long ttlSeconds) { + SessionState state = load(threadId); + if (state == null) { + state = start(threadId, user, agent, "", ttlSeconds); + } + List turns = new ArrayList<>(state.recentTurns()); + turns.add(new SessionTurn(role, content == null ? "" : content, unixSecs())); + if (turns.size() > maxTurns) { + turns = turns.subList(turns.size() - maxTurns, turns.size()); + } + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + SessionState next = new SessionState( + state.threadId(), state.user(), state.agent(), + state.goal(), state.scratchpad(), + state.turnCount() + 1, + state.createdTs(), + unixSecs(), + turns, + ttl); + write(next, ttl); + return next; + } + + /** + * Update the goal field without touching turns or the scratchpad. + * Creates the session if it doesn't exist yet. + */ + public SessionState setGoal( + String threadId, String text, String user, String agent, Long ttlSeconds) { + SessionState state = load(threadId); + if (state == null) { + return start(threadId, user, agent, text == null ? "" : text, ttlSeconds); + } + long ttl = ttlSeconds != null ? ttlSeconds : defaultTtlSeconds; + SessionState next = new SessionState( + state.threadId(), state.user(), state.agent(), + text == null ? "" : text, + state.scratchpad(), + state.turnCount(), + state.createdTs(), + unixSecs(), + state.recentTurns(), + ttl); + write(next, ttl); + return next; + } + + /** Drop the session immediately. Returns {@code true} if it existed. */ + public boolean delete(String threadId) { + synchronized (txLock) { + return sync.del(sessionKeyBytes(threadId)) > 0L; + } + } + + public StatefulRedisConnection connection() { + return connection; + } + + private void write(SessionState state, long ttl) { + byte[] key = sessionKeyBytes(state.threadId()); + + JSONArray turnsArr = new JSONArray(); + for (SessionTurn t : state.recentTurns()) { + JSONObject obj = new JSONObject(); + obj.put("role", t.role()); + obj.put("content", t.content()); + obj.put("ts", t.ts()); + turnsArr.put(obj); + } + + Map mapping = new LinkedHashMap<>(); + putUtf8(mapping, "thread_id", state.threadId()); + putUtf8(mapping, "user", state.user()); + putUtf8(mapping, "agent", state.agent()); + putUtf8(mapping, "goal", state.goal()); + putUtf8(mapping, "scratchpad", state.scratchpad()); + putUtf8(mapping, "turn_count", Long.toString(state.turnCount())); + putUtf8(mapping, "created_ts", + String.format(Locale.ROOT, "%.6f", state.createdTs())); + putUtf8(mapping, "last_active_ts", + String.format(Locale.ROOT, "%.6f", state.lastActiveTs())); + putUtf8(mapping, "recent_turns", turnsArr.toString()); + + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. The shared `txLock` serializes this whole MULTI…EXEC + // span against any other transaction on the same connection + // — Lettuce's transaction state is connection-scoped, so two + // concurrent threads queueing into the same MULTI would + // interleave their writes. + TransactionResult txResult; + synchronized (txLock) { + sync.multi(); + sync.hset(key, mapping); + sync.expire(key, ttl); + txResult = sync.exec(); + } + if (txResult == null || txResult.wasDiscarded()) { + throw new RedisException("MULTI/EXEC for session write was discarded"); + } + } + + private static void putUtf8(Map mapping, String field, String value) { + mapping.put( + field.getBytes(StandardCharsets.UTF_8), + (value == null ? "" : value).getBytes(StandardCharsets.UTF_8)); + } + + private static List parseTurns(String blob) { + if (blob == null || blob.isEmpty()) return List.of(); + try { + JSONArray arr = new JSONArray(blob); + List out = new ArrayList<>(arr.length()); + for (int i = 0; i < arr.length(); i++) { + JSONObject o = arr.getJSONObject(i); + out.add(new SessionTurn( + o.optString("role", ""), + o.optString("content", ""), + o.optDouble("ts", 0.0))); + } + return out; + } catch (JSONException ex) { + return List.of(); + } + } + + private static String orDefault(String s, String fallback) { + return (s == null || s.isEmpty()) ? fallback : s; + } + + private static String orEmpty(String s) { + return s == null ? "" : s; + } + + private static long parseLong(String s, long fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Long.parseLong(s); } catch (NumberFormatException ex) { return fallback; } + } + + private static double parseDouble(String s, double fallback) { + if (s == null || s.isEmpty()) return fallback; + try { return Double.parseDouble(s); } catch (NumberFormatException ex) { return fallback; } + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/DemoServer.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/DemoServer.java new file mode 100644 index 0000000000..026b167ce6 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/DemoServer.java @@ -0,0 +1,667 @@ +package com.redis.agentmem; + +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.protocol.ProtocolVersion; +import org.json.JSONArray; +import org.json.JSONObject; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; + +/** + * Redis agent-memory demo server (Java + Lettuce). + * + *

Run this main and visit {@code http://localhost:8093} to drive + * a small agent-memory demo backed by Redis Hashes, JSON, Search, + * and Streams. The UI lets you type a turn, watch working memory + * update, see semantically similar long-term memories recalled, and + * inspect the per-thread event log. + * + *

The server holds a single {@link LocalEmbedder}, one + * {@link AgentSession}, one {@link LongTermMemory}, and one + * {@link AgentEventLog} for the lifetime of the process. All three + * helpers share a single Lettuce connection with a {@code byte[]} + * codec and a shared {@code txLock} so {@code MULTI}/{@code EXEC} + * spans don't interleave across the helpers. + */ +public final class DemoServer { + + static final class Args { + String host = "127.0.0.1"; + int port = 8093; + String redisHost = "localhost"; + int redisPort = 6379; + String memIndexName = "agentmem:idx"; + String memKeyPrefix = "agent:mem:"; + String sessionKeyPrefix = "agent:session:"; + String eventKeyPrefix = "agent:events:"; + long sessionTtlSeconds = 3600; + double dedupThreshold = LongTermMemory.DEFAULT_DEDUP_THRESHOLD; + double recallThreshold = LongTermMemory.DEFAULT_RECALL_THRESHOLD; + boolean resetOnStart = true; + } + + public static void main(String[] argv) throws Exception { + Args args = parseArgs(argv); + + RedisClient client = RedisClient.create( + RedisURI.Builder.redis(args.redisHost, args.redisPort).build()); + // Pin the connection to RESP2 so FT.SEARCH and FT.INFO come + // back as the flat arrays our `NestedMultiOutput` parser + // expects. Lettuce 6.7 negotiates RESP3 by default with + // Redis 7+, which wraps the same data in a map/set structure + // we'd have to special-case across both the recall query + // and the FT.INFO snapshot; RESP2 keeps the wire format + // identical to what `redis-cli` shows on the command line. + client.setOptions(ClientOptions.builder() + .protocolVersion(ProtocolVersion.RESP2) + .build()); + StatefulRedisConnection connection; + try { + connection = client.connect(ByteArrayCodec.INSTANCE); + connection.sync().ping(); + } catch (Exception ex) { + System.err.println("Error: cannot reach Redis at " + + args.redisHost + ":" + args.redisPort); + System.err.println(" (" + ex.getMessage() + ")"); + client.shutdown(); + System.exit(1); + return; + } + + // The shared `txLock` serializes MULTI/EXEC spans across the + // three helpers (working memory, long-term memory, event + // log) so the cached-thread-pool handlers can't interleave + // their queued commands. Lettuce connections are otherwise + // thread-safe for individual command dispatch. + Object txLock = new Object(); + + AgentSession session = new AgentSession( + connection, + txLock, + args.sessionKeyPrefix, + args.sessionTtlSeconds, + AgentSession.DEFAULT_MAX_TURNS); + LongTermMemory memory = new LongTermMemory( + connection, + txLock, + args.memIndexName, + args.memKeyPrefix, + LocalEmbedder.defaultVectorDim(), + args.dedupThreshold, + args.recallThreshold, + null); + memory.createIndex(); + AgentEventLog events = new AgentEventLog( + connection, txLock, args.eventKeyPrefix, AgentEventLog.DEFAULT_MAX_LEN); + + System.out.println("Loading embedding model " + + "(first run downloads the PyTorch weights)..."); + LocalEmbedder embedder = LocalEmbedder.create(); + + AgentMemoryDemo demo = new AgentMemoryDemo(session, memory, events, embedder); + + if (args.resetOnStart) { + System.out.println( + "Dropping any existing memories under '" + args.memKeyPrefix + + "*' and re-seeding from the sample memory list " + + "(pass --no-reset to keep)."); + int seeded = demo.seedAll("default", "default"); + System.out.println("Seeded " + seeded + " memories."); + } + + String rawHtml = loadIndexHtml(); + String htmlPage = rawHtml + .replace("__SESSION_PREFIX__", args.sessionKeyPrefix) + .replace("__MEM_PREFIX__", args.memKeyPrefix) + .replace("__MEM_INDEX__", args.memIndexName) + .replace("__EVENT_PREFIX__", args.eventKeyPrefix); + + HttpServer server = HttpServer.create( + new InetSocketAddress(args.host, args.port), 0); + server.setExecutor(Executors.newCachedThreadPool()); + server.createContext("/", new RootHandler(demo, htmlPage)); + + System.out.println("Redis agent memory demo listening on " + + "http://" + args.host + ":" + args.port); + System.out.println("Using Redis at " + args.redisHost + ":" + args.redisPort + + " with memory index '" + args.memIndexName + "'"); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.out.println("\nShutting down..."); + server.stop(0); + try { embedder.close(); } catch (Exception ignored) {} + try { connection.close(); } catch (Exception ignored) {} + try { client.shutdown(); } catch (Exception ignored) {} + })); + + server.start(); + } + + // ------------------------------------------------------------------ + // Demo orchestrator + // ------------------------------------------------------------------ + + /** + * Demo state: working memory, long-term memory, event log. + * + *

{@code seedAll} / {@code newThread} / {@code handleTurn} + * all touch {@code currentThreadId} — synchronized through the + * mutex below, but the lock is released between operations so a + * turn racing with a thread rotation can capture the old id and + * apply to the previous thread. The demo is single-user in + * practice, so the race never triggers; a multi-user agent would + * carry the thread id on each request instead of holding it as + * shared server state. See the walkthrough's "Concurrency + * caveats" section. + */ + static final class AgentMemoryDemo { + final AgentSession session; + final LongTermMemory memory; + final AgentEventLog events; + final LocalEmbedder embedder; + final String defaultUser = "default"; + final String defaultNamespace = "default"; + private final Object threadIdLock = new Object(); + private String currentThreadId; + + AgentMemoryDemo(AgentSession session, LongTermMemory memory, + AgentEventLog events, LocalEmbedder embedder) { + this.session = session; + this.memory = memory; + this.events = events; + this.embedder = embedder; + this.currentThreadId = session.newThreadId(); + } + + String currentThreadId() { + synchronized (threadIdLock) { + return currentThreadId; + } + } + + int seedAll(String user, String namespace) throws Exception { + memory.clear(); + String threadId = currentThreadId(); + session.delete(threadId); + events.clear(threadId); + int written = SeedMemory.seed(memory, embedder, user, namespace, "seed"); + synchronized (threadIdLock) { + currentThreadId = session.newThreadId(); + } + return written; + } + + String newThread(String user, String namespace) { + String oldId = currentThreadId(); + events.clear(oldId); + String newId = session.newThreadId(); + session.start(newId, user, "demo-agent", "", null); + events.record(newId, "thread_started", + "user=" + user + " namespace=" + namespace); + synchronized (threadIdLock) { + currentThreadId = newId; + } + return newId; + } + + /** + * One pass through the agent loop: append, recall, remember, log. + */ + Map handleTurn( + String text, + String user, + String namespace, + String kind, + String role, + double threshold, + String action) throws Exception { + String threadId = currentThreadId(); + + long t0 = System.nanoTime(); + float[] vec = embedder.encodeOne(text); + double embedMs = (System.nanoTime() - t0) / 1_000_000.0; + + // `setGoal` only touches the goal field so existing turns + // aren't wiped; `appendTurn` carries the request `user` + // through to the auto-create path so a first turn for a + // new thread doesn't land under the default user. + String sessionAction; + if ("goal".equals(action)) { + session.setGoal(threadId, text, user, "demo-agent", null); + sessionAction = "goal_set"; + } else { + session.appendTurn(threadId, role, text, user, "demo-agent", null); + sessionAction = "turn_appended:" + role; + } + + long t1 = System.nanoTime(); + List recalled = memory.recall( + vec, user, namespace, null, 5, threshold); + double recallMs = (System.nanoTime() - t1) / 1_000_000.0; + + boolean writeSkipped = "skip".equals(kind) || "goal".equals(action); + WriteResult writeResult = null; + double writeMs = 0.0; + if (!writeSkipped) { + long t2 = System.nanoTime(); + writeResult = memory.remember( + text, vec, user, namespace, kind, threadId, null); + writeMs = (System.nanoTime() - t2) / 1_000_000.0; + } + + String detail; + if (writeResult == null) { + detail = ""; + } else if (writeResult.deduped()) { + detail = "deduped onto " + writeResult.id(); + } else { + detail = "wrote " + writeResult.id() + " as " + kind; + } + events.record(threadId, sessionAction, detail); + + Map payload = new LinkedHashMap<>(); + payload.put("thread_id", threadId); + payload.put("write_skipped", writeSkipped); + payload.put("memory_id", writeResult == null ? null : writeResult.id()); + payload.put("deduped", writeResult != null && writeResult.deduped()); + payload.put("existing_distance", + writeResult == null ? null : writeResult.existingDistance()); + payload.put("kind", writeSkipped ? null : kind); + payload.put("recalled", memoriesToJsonArray(recalled)); + payload.put("embed_ms", embedMs); + payload.put("recall_ms", recallMs); + payload.put("write_ms", writeMs); + return payload; + } + + JSONObject buildState(String user, String namespace) { + Map info = memory.indexInfo(); + JSONObject index = new JSONObject(); + index.put("num_docs", info.getOrDefault("num_docs", 0L)); + index.put("indexing_failures", info.getOrDefault("indexing_failures", 0L)); + index.put("index_name", memory.indexName()); + index.put("model", embedder.modelName()); + index.put("session_ttl_seconds", session.defaultTtlSeconds()); + index.put("dedup_threshold", memory.dedupThreshold()); + index.put("default_recall_threshold", memory.recallThreshold()); + index.put("stack_label", + "Lettuce + DJL (PyTorch + HuggingFace) + Java standard library HTTP server"); + + String threadId = currentThreadId(); + SessionState state = session.load(threadId); + JSONArray memories = memoriesToJsonArray( + memory.listMemories(user, namespace, null, 200)); + JSONArray eventArr = new JSONArray(); + for (AgentEvent e : events.recent(threadId, 20)) { + eventArr.put(eventToJson(e)); + } + + JSONObject out = new JSONObject(); + out.put("index", index); + out.put("thread_id", threadId); + out.put("session", state == null ? JSONObject.NULL : sessionToJson(state)); + out.put("memories", memories); + out.put("events", eventArr); + // `recalled` is populated by /turn; on plain /state reads + // the UI keeps showing the last turn's result. + out.put("recalled", new JSONArray()); + return out; + } + } + + // ------------------------------------------------------------------ + // Serialisation helpers + // ------------------------------------------------------------------ + + static JSONObject sessionToJson(SessionState s) { + JSONObject obj = new JSONObject(); + obj.put("thread_id", s.threadId()); + obj.put("user", s.user()); + obj.put("agent", s.agent()); + obj.put("goal", s.goal()); + obj.put("scratchpad", s.scratchpad()); + obj.put("turn_count", s.turnCount()); + obj.put("created_ts", s.createdTs()); + obj.put("last_active_ts", s.lastActiveTs()); + JSONArray turns = new JSONArray(); + for (SessionTurn t : s.recentTurns()) { + JSONObject turn = new JSONObject(); + turn.put("role", t.role()); + turn.put("content", t.content()); + turn.put("ts", t.ts()); + turns.put(turn); + } + obj.put("recent_turns", turns); + obj.put("ttl_seconds", s.ttlSeconds()); + return obj; + } + + static JSONObject memoryToJson(MemoryRecord m) { + JSONObject obj = new JSONObject(); + obj.put("id", m.id()); + obj.put("user", m.user()); + obj.put("namespace", m.namespace()); + obj.put("kind", m.kind()); + obj.put("source_thread", m.sourceThread()); + obj.put("text", m.text()); + obj.put("created_ts", m.createdTs()); + obj.put("hit_count", m.hitCount()); + obj.put("distance", m.distance() == null ? JSONObject.NULL : m.distance()); + obj.put("ttl_seconds", m.ttlSeconds() == null ? JSONObject.NULL : m.ttlSeconds()); + return obj; + } + + static JSONArray memoriesToJsonArray(List records) { + JSONArray arr = new JSONArray(); + for (MemoryRecord m : records) { + arr.put(memoryToJson(m)); + } + return arr; + } + + static JSONObject eventToJson(AgentEvent e) { + JSONObject obj = new JSONObject(); + obj.put("event_id", e.eventId()); + obj.put("thread_id", e.threadId()); + obj.put("action", e.action()); + obj.put("detail", e.detail()); + obj.put("ts", e.ts()); + return obj; + } + + // ------------------------------------------------------------------ + // HTTP plumbing + // ------------------------------------------------------------------ + + static final class RootHandler implements HttpHandler { + private final AgentMemoryDemo demo; + private final String htmlPage; + + RootHandler(AgentMemoryDemo demo, String htmlPage) { + this.demo = demo; + this.htmlPage = htmlPage; + } + + @Override + public void handle(HttpExchange ex) throws IOException { + try { + String method = ex.getRequestMethod(); + URI uri = ex.getRequestURI(); + String path = uri.getPath(); + + if ("GET".equalsIgnoreCase(method)) { + if (path.equals("/") || path.equals("/index.html")) { + sendHtml(ex, 200, htmlPage); + return; + } + if (path.equals("/state")) { + Map q = parseForm(uri.getRawQuery()); + String user = nonEmpty(q.get("user"), "default"); + String ns = nonEmpty(q.get("namespace"), "default"); + sendJson(ex, 200, demo.buildState(user, ns)); + return; + } + sendJson(ex, 404, errorPayload("not found", null)); + return; + } + if ("POST".equalsIgnoreCase(method)) { + String body = readBody(ex); + Map params = parseForm(body); + + if (path.equals("/turn")) { + handleTurn(ex, params); + return; + } + if (path.equals("/new_thread")) { + String user = nonEmpty(params.get("user"), "default"); + String ns = nonEmpty(params.get("namespace"), "default"); + String tid = demo.newThread(user, ns); + JSONObject body2 = new JSONObject(); + body2.put("thread_id", tid); + sendJson(ex, 200, body2); + return; + } + if (path.equals("/reset")) { + String user = nonEmpty(params.get("user"), "default"); + String ns = nonEmpty(params.get("namespace"), "default"); + try { + int seeded = demo.seedAll(user, ns); + JSONObject ok = new JSONObject(); + ok.put("seeded", seeded); + sendJson(ex, 200, ok); + } catch (Exception inner) { + handleException(ex, inner); + } + return; + } + if (path.equals("/drop_memory")) { + String memoryId = params.getOrDefault("memory_id", "").trim(); + if (memoryId.isEmpty()) { + sendJson(ex, 400, errorPayload("memory_id is required", null)); + return; + } + boolean deleted = demo.memory.deleteMemory(memoryId); + JSONObject out = new JSONObject(); + out.put("deleted", deleted); + out.put("memory_id", memoryId); + sendJson(ex, 200, out); + return; + } + sendJson(ex, 404, errorPayload("not found", null)); + return; + } + sendJson(ex, 405, errorPayload("method not allowed", null)); + } catch (Exception exc) { + handleException(ex, exc); + } + } + + private void handleTurn(HttpExchange ex, Map params) + throws IOException { + String text = params.getOrDefault("text", "").trim(); + if (text.isEmpty()) { + sendJson(ex, 400, errorPayload("text is required", null)); + return; + } + double threshold = clampThreshold( + params.get("threshold"), demo.memory.recallThreshold()); + try { + Map payload = demo.handleTurn( + text, + nonEmpty(params.get("user"), "default"), + nonEmpty(params.get("namespace"), "default"), + nonEmpty(params.get("kind"), "episodic"), + nonEmpty(params.get("role"), "user"), + threshold, + nonEmpty(params.get("action"), "turn")); + sendJson(ex, 200, toJson(payload)); + } catch (Exception inner) { + handleException(ex, inner); + } + } + + private void handleException(HttpExchange ex, Exception exc) { + System.err.println("[demo] handler error: " + + exc.getClass().getSimpleName() + ": " + exc.getMessage()); + exc.printStackTrace(System.err); + try { + JSONObject body = errorPayload( + exc.getMessage() == null ? exc.getClass().getSimpleName() : exc.getMessage(), + exc.getClass().getSimpleName()); + sendJson(ex, 500, body); + } catch (Exception ignored) { + } + } + } + + // ------------------------------------------------------------------ + // Helpers + // ------------------------------------------------------------------ + + static double clampThreshold(String raw, double fallback) { + if (raw == null || raw.isEmpty()) return fallback; + double parsed; + try { + parsed = Double.parseDouble(raw); + } catch (NumberFormatException ex) { + return fallback; + } + if (Double.isNaN(parsed) || Double.isInfinite(parsed)) return fallback; + return Math.max(0.0, Math.min(2.0, parsed)); + } + + private static String nonEmpty(String value, String fallback) { + return (value == null || value.isEmpty()) ? fallback : value; + } + + private static final int MAX_BODY_BYTES = 1 * 1024 * 1024; + + private static String readBody(HttpExchange ex) throws IOException { + try (InputStream in = ex.getRequestBody()) { + byte[] bytes = in.readNBytes(MAX_BODY_BYTES + 1); + if (bytes.length > MAX_BODY_BYTES) { + throw new IOException( + "request body exceeds " + MAX_BODY_BYTES + " bytes"); + } + return new String(bytes, StandardCharsets.UTF_8); + } + } + + static Map parseForm(String body) { + Map out = new HashMap<>(); + if (body == null || body.isEmpty()) return out; + for (String pair : body.split("&")) { + if (pair.isEmpty()) continue; + int eq = pair.indexOf('='); + String key, value; + if (eq < 0) { + key = URLDecoder.decode(pair, StandardCharsets.UTF_8); + value = ""; + } else { + key = URLDecoder.decode(pair.substring(0, eq), StandardCharsets.UTF_8); + value = URLDecoder.decode(pair.substring(eq + 1), StandardCharsets.UTF_8); + } + out.put(key, value); + } + return out; + } + + private static void sendHtml(HttpExchange ex, int status, String html) throws IOException { + byte[] bytes = html.getBytes(StandardCharsets.UTF_8); + ex.getResponseHeaders().set("Content-Type", "text/html; charset=utf-8"); + ex.sendResponseHeaders(status, bytes.length); + ex.getResponseBody().write(bytes); + ex.getResponseBody().close(); + } + + private static void sendJson(HttpExchange ex, int status, JSONObject body) throws IOException { + byte[] bytes = body.toString().getBytes(StandardCharsets.UTF_8); + ex.getResponseHeaders().set("Content-Type", "application/json"); + ex.sendResponseHeaders(status, bytes.length); + ex.getResponseBody().write(bytes); + ex.getResponseBody().close(); + } + + private static JSONObject errorPayload(String message, String type) { + JSONObject out = new JSONObject(); + out.put("error", message); + if (type != null) out.put("type", type); + return out; + } + + private static JSONObject toJson(Map map) { + JSONObject out = new JSONObject(); + for (Map.Entry entry : map.entrySet()) { + Object value = entry.getValue(); + if (value == null) { + out.put(entry.getKey(), JSONObject.NULL); + } else { + out.put(entry.getKey(), value); + } + } + return out; + } + + private static String loadIndexHtml() throws IOException { + try (InputStream in = DemoServer.class.getResourceAsStream("/index.html")) { + if (in == null) { + throw new IOException( + "index.html not found on classpath; rebuild with `mvn package`"); + } + return new String(in.readAllBytes(), StandardCharsets.UTF_8); + } + } + + // ------------------------------------------------------------------ + // CLI parsing + // ------------------------------------------------------------------ + + static Args parseArgs(String[] argv) { + Args args = new Args(); + for (int i = 0; i < argv.length; i++) { + String a = argv[i]; + switch (a) { + case "--host": args.host = require(argv, ++i, a); break; + case "--port": args.port = Integer.parseInt(require(argv, ++i, a)); break; + case "--redis-host": args.redisHost = require(argv, ++i, a); break; + case "--redis-port": args.redisPort = Integer.parseInt(require(argv, ++i, a)); break; + case "--mem-index-name": args.memIndexName = require(argv, ++i, a); break; + case "--mem-key-prefix": args.memKeyPrefix = require(argv, ++i, a); break; + case "--session-key-prefix": args.sessionKeyPrefix = require(argv, ++i, a); break; + case "--event-key-prefix": args.eventKeyPrefix = require(argv, ++i, a); break; + case "--session-ttl-seconds": args.sessionTtlSeconds = Long.parseLong(require(argv, ++i, a)); break; + case "--dedup-threshold": args.dedupThreshold = Double.parseDouble(require(argv, ++i, a)); break; + case "--recall-threshold": args.recallThreshold = Double.parseDouble(require(argv, ++i, a)); break; + case "--no-reset": args.resetOnStart = false; break; + case "-h": + case "--help": + printHelp(); + System.exit(0); + break; + default: + throw new IllegalArgumentException("Unknown flag: " + a); + } + } + return args; + } + + private static String require(String[] argv, int i, String flag) { + if (i >= argv.length) { + throw new IllegalArgumentException("Missing value for " + flag); + } + return argv[i]; + } + + private static void printHelp() { + System.out.println("Usage: java -jar agent-memory-lettuce.jar [options]"); + System.out.println(" --host HOST HTTP bind host (default 127.0.0.1)"); + System.out.println(" --port PORT HTTP bind port (default 8093)"); + System.out.println(" --redis-host HOST Redis host (default localhost)"); + System.out.println(" --redis-port PORT Redis port (default 6379)"); + System.out.println(" --mem-index-name NAME Memory search index (default agentmem:idx)"); + System.out.println(" --mem-key-prefix PREFIX JSON memory key prefix (default agent:mem:)"); + System.out.println(" --session-key-prefix PREFIX Session hash key prefix (default agent:session:)"); + System.out.println(" --event-key-prefix PREFIX Event stream key prefix (default agent:events:)"); + System.out.println(" --session-ttl-seconds N Working-memory TTL (default 3600)"); + System.out.println(" --dedup-threshold F Cosine distance for dedup (default 0.20)"); + System.out.println(" --recall-threshold F Cosine distance for recall (default 0.55)"); + System.out.println(" --no-reset Skip clearing and re-seeding on startup"); + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LocalEmbedder.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LocalEmbedder.java new file mode 100644 index 0000000000..831e56d3bb --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LocalEmbedder.java @@ -0,0 +1,176 @@ +package com.redis.agentmem; + +import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory; +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +/** + * Local text-embedding helper backed by DJL + PyTorch. + * + *

This is a thin wrapper around the + * {@code sentence-transformers/all-MiniLM-L6-v2} model loaded from + * DJL's model zoo: a 384-dimensional encoder that runs in-process on + * CPU through libtorch, needs no API key, and produces vectors that + * are numerically very close to the equivalent Python and Node ports + * (close enough that paraphrase distances differ only at the fourth + * decimal place). + * + *

DJL's {@link TextEmbeddingTranslatorFactory} returns mean-pooled + * vectors. They are normalized by default for cosine similarity, but + * the demo L2-normalises the result explicitly in {@link #encodeOne} + * before returning. That belt-and-braces step makes the cosine + * distance reported by Redis Search numerically equivalent to what + * the Python and Go ports produce, regardless of whether a future + * DJL release tweaks its default normalization behavior. + */ +public final class LocalEmbedder implements AutoCloseable { + + private static final String DEFAULT_MODEL_URL = + "djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2"; + private static final String DEFAULT_MODEL_NAME = + "sentence-transformers/all-MiniLM-L6-v2"; + private static final int DEFAULT_VECTOR_DIM = 384; + + private final String modelName; + private final ZooModel model; + private final Predictor predictor; + private final int dim; + + private LocalEmbedder( + String modelName, + ZooModel model, + Predictor predictor, + int dim) { + this.modelName = modelName; + this.model = model; + this.predictor = predictor; + this.dim = dim; + } + + /** + * Load the default model. Blocks while DJL downloads the + * PyTorch weights on the first run, then keeps a single loaded + * predictor for the lifetime of the embedder. + */ + public static LocalEmbedder create() throws Exception { + Criteria criteria = Criteria.builder() + .setTypes(String.class, float[].class) + .optModelUrls(DEFAULT_MODEL_URL) + .optEngine("PyTorch") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .optProgress(new ProgressBar()) + .build(); + ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + // Probe the output shape once so we fail loudly if a + // different model is wired up against the 384-dim Redis + // Search field. + float[] probe = predictor.predict("dimension probe"); + int dim = probe.length; + return new LocalEmbedder(DEFAULT_MODEL_NAME, model, predictor, dim); + } + + public String modelName() { + return modelName; + } + + public int dim() { + return dim; + } + + /** + * Encode a single string. Returns a {@code float[]} of length + * {@link #dim()}, L2-normalized in place. + * + *

The DJL PyTorch {@code Predictor} is not thread-safe — its + * underlying NDManager and tokenizer state mutate per call. The + * demo server uses a cached thread pool, so two browser tabs + * could land on different handler threads and call this method + * concurrently. We {@code synchronized}-guard both encode entry + * points to serialize access to the shared predictor; encoding + * is the bottleneck either way and a single CPU-bound model + * won't usefully run two requests in parallel. A higher- + * throughput deployment would replace this with a small pool + * of {@code Predictor} instances or a dedicated single-threaded + * inference executor. + */ + public synchronized float[] encodeOne(String text) throws Exception { + float[] vector = predictor.predict(text); + l2Normalise(vector); + return vector; + } + + /** Encode several strings sequentially. See {@link #encodeOne} + * for the rationale behind the synchronisation. */ + public synchronized List encodeMany(List texts) throws Exception { + List out = new ArrayList<>(texts.size()); + for (String text : texts) { + float[] vector = predictor.predict(text); + l2Normalise(vector); + out.add(vector); + } + return out; + } + + /** + * Scale {@code vector} to unit length in place. DJL's default + * translator already returns near-unit vectors for this model, + * so the multiplier sits right on top of {@code 1.0} — but + * re-normalising explicitly insulates the demo from any future + * change in the translator's defaults, and a vector that has + * drifted by even one part in a million would otherwise leak + * into the cosine distance the demo prints to the UI. + */ + private static void l2Normalise(float[] vector) { + double sumSq = 0.0; + for (float v : vector) { + sumSq += (double) v * (double) v; + } + if (sumSq <= 0.0) return; + double norm = Math.sqrt(sumSq); + float scale = (float) (1.0 / norm); + for (int i = 0; i < vector.length; i++) { + vector[i] = vector[i] * scale; + } + } + + /** + * Pack a {@code float[]} into the bytes Redis Search expects. + * Vectors are little-endian {@code float32}; this matches the + * encoding the Python and Node ports write. + */ + public static byte[] toBytes(float[] vector) { + byte[] bytes = new byte[Float.BYTES * vector.length]; + ByteBuffer + .wrap(bytes) + .order(ByteOrder.LITTLE_ENDIAN) + .asFloatBuffer() + .put(vector); + return bytes; + } + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignored) { + // best-effort cleanup + } + try { + model.close(); + } catch (Exception ignored) { + // best-effort cleanup + } + } + + public static int defaultVectorDim() { + return DEFAULT_VECTOR_DIM; + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LongTermMemory.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LongTermMemory.java new file mode 100644 index 0000000000..5dc4476669 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/LongTermMemory.java @@ -0,0 +1,648 @@ +package com.redis.agentmem; + +import io.lettuce.core.RedisException; +import io.lettuce.core.TransactionResult; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.output.NestedMultiOutput; +import io.lettuce.core.output.StatusOutput; +import io.lettuce.core.output.ValueOutput; +import io.lettuce.core.protocol.CommandArgs; +import io.lettuce.core.protocol.ProtocolKeyword; +import org.json.JSONObject; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.UUID; + +/** + * Long-term memory store for an agent, backed by Redis JSON and + * Search. + * + *

Each memory lives as one JSON document at + * {@code agent:mem:}. The document holds the memory text, its + * embedding vector, and a small metadata block — user, namespace, + * kind, source thread, timestamps — that lets the recall query + * scope results without falling back to application-side filtering. + * + *

A single Redis Search index covers the embedding plus every + * metadata field, so one {@code FT.SEARCH} call performs + * approximate-nearest-neighbour over the in-scope subset and + * returns the top-k memories ranked by cosine distance. The same + * KNN check runs at write time to deduplicate + * near-identical memories before they enter the store. + * + *

Memories carry one of two kinds: + * + *

    + *
  • {@code episodic} — "what happened" snapshots from a specific + * thread, written with a medium TTL so old session detail + * decays naturally.
  • + *
  • {@code semantic} — distilled facts and preferences the + * agent should carry forward indefinitely. Written with no + * TTL by default.
  • + *
+ * + *

Lettuce 6.7 doesn't ship first-class {@code FT.*} or + * {@code JSON.*} bindings, so the helper sends them through + * {@code dispatch()} with custom {@link ProtocolKeyword}s. Everything + * else — {@code EXPIRE}, {@code TTL}, {@code DEL}, + * {@code MULTI}/{@code EXEC} — goes through the built-in synchronous + * API on a {@code byte[]}-codec connection. + */ +public final class LongTermMemory { + + public static final int VECTOR_DIM_DEFAULT = 384; + public static final double DEFAULT_DEDUP_THRESHOLD = 0.20; + public static final double DEFAULT_RECALL_THRESHOLD = 0.55; + + /** Default per-kind TTLs (seconds); {@code null} value = no TTL. */ + public static Map defaultTtlByKind() { + Map out = new HashMap<>(); + out.put("episodic", 7L * 24 * 3600); + out.put("semantic", null); + return out; + } + + /** + * Characters Redis Search treats as syntax inside a TAG value; + * any of them in a user-supplied filter must be backslash-escaped + * or the surrounding {@code {...}} block won't parse correctly. + */ + private static final String TAG_SPECIAL = "\\,.<>{}[]\"':;!@#$%^&*()-+=~| "; + + /** + * Custom {@link ProtocolKeyword}s for the Redis Search and Redis + * JSON subcommands the helper sends via {@code dispatch()}. + * Lettuce 6.7 has no native bindings; {@code dispatch()} accepts + * any keyword whose {@code getBytes()} returns the raw command + * name, so spelling each command out as its own keyword keeps + * the wire bytes matching what {@code redis-cli} would send. + */ + private enum ModuleCommand implements ProtocolKeyword { + FT_CREATE("FT.CREATE"), + FT_SEARCH("FT.SEARCH"), + FT_INFO("FT.INFO"), + FT_DROPINDEX("FT.DROPINDEX"), + JSON_SET("JSON.SET"), + JSON_NUMINCRBY("JSON.NUMINCRBY"); + + private final byte[] bytes; + private final String wire; + + ModuleCommand(String wire) { + this.wire = wire; + this.bytes = wire.getBytes(StandardCharsets.US_ASCII); + } + + @Override + public byte[] getBytes() { + return bytes; + } + + @Override + public String toString() { + return wire; + } + } + + private final RedisCommands sync; + private final Object txLock; + private final String indexName; + private final String keyPrefix; + private final byte[] indexNameBytes; + private final byte[] keyPrefixBytes; + private final int vectorDim; + private final double dedupThreshold; + private final double recallThreshold; + private final Map ttlByKind; + + public LongTermMemory( + StatefulRedisConnection connection, + Object txLock, + String indexName, + String keyPrefix, + int vectorDim, + double dedupThreshold, + double recallThreshold, + Map ttlByKind) { + this.sync = connection.sync(); + this.txLock = txLock; + this.indexName = indexName; + this.keyPrefix = keyPrefix; + this.indexNameBytes = indexName.getBytes(StandardCharsets.UTF_8); + this.keyPrefixBytes = keyPrefix.getBytes(StandardCharsets.UTF_8); + this.vectorDim = vectorDim > 0 ? vectorDim : VECTOR_DIM_DEFAULT; + // Thresholds are honored as-is. Zero is a legitimate value + // ("exact matches only" for dedup, "nothing recalls" for + // recall); silently rewriting them would make + // --dedup-threshold 0 uncallable. + this.dedupThreshold = dedupThreshold < 0 ? DEFAULT_DEDUP_THRESHOLD : dedupThreshold; + this.recallThreshold = recallThreshold < 0 ? DEFAULT_RECALL_THRESHOLD : recallThreshold; + this.ttlByKind = ttlByKind != null ? ttlByKind : defaultTtlByKind(); + } + + public String indexName() { return indexName; } + public String keyPrefix() { return keyPrefix; } + public int vectorDim() { return vectorDim; } + public double dedupThreshold() { return dedupThreshold; } + public double recallThreshold() { return recallThreshold; } + + public String memoryKey(String memoryId) { + return keyPrefix + memoryId; + } + + private byte[] memoryKeyBytes(String memoryId) { + return memoryKey(memoryId).getBytes(StandardCharsets.UTF_8); + } + + // ------------------------------------------------------------------ + // Index management + // ------------------------------------------------------------------ + + public void createIndex() { + CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes) + .add("ON").add("JSON") + .add("PREFIX").add(1).add(keyPrefixBytes) + .add("SCHEMA") + .add("$.text").add("AS").add("text").add("TEXT") + .add("$.user").add("AS").add("user").add("TAG") + .add("$.namespace").add("AS").add("namespace").add("TAG") + .add("$.kind").add("AS").add("kind").add("TAG") + .add("$.source_thread").add("AS").add("source_thread").add("TAG") + .add("$.created_ts").add("AS").add("created_ts") + .add("NUMERIC").add("SORTABLE") + .add("$.hit_count").add("AS").add("hit_count") + .add("NUMERIC").add("SORTABLE") + .add("$.embedding").add("AS").add("embedding") + .add("VECTOR").add("HNSW").add(6) + .add("TYPE").add("FLOAT32") + .add("DIM").add(vectorDim) + .add("DISTANCE_METRIC").add("COSINE"); + try { + synchronized (txLock) { + sync.dispatch( + ModuleCommand.FT_CREATE, + new StatusOutput<>(ByteArrayCodec.INSTANCE), + args); + } + } catch (RedisException ex) { + if (!String.valueOf(ex.getMessage()).contains("Index already exists")) { + throw ex; + } + } + } + + public void dropIndex(boolean deleteDocuments) { + CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes); + if (deleteDocuments) args.add("DD"); + try { + synchronized (txLock) { + sync.dispatch( + ModuleCommand.FT_DROPINDEX, + new StatusOutput<>(ByteArrayCodec.INSTANCE), + args); + } + } catch (RedisException ex) { + String msg = String.valueOf(ex.getMessage()).toLowerCase(Locale.ROOT); + if (!msg.contains("no such index") && !msg.contains("unknown index name")) { + throw ex; + } + } + } + + // ------------------------------------------------------------------ + // Write + // ------------------------------------------------------------------ + + /** + * Write a new memory, deduplicating against existing entries. + * + *

Runs one in-scope KNN(1) against the index first. If the + * nearest existing memory is within {@link #dedupThreshold()}, + * the new memory is skipped (its content is already represented) + * and the existing memory's {@code hit_count} is bumped via + * {@code JSON.NUMINCRBY}. Otherwise a fresh JSON document is + * written under a new id with a TTL derived from the memory's + * {@code kind}, inside a {@code MULTI/EXEC} so the JSON.SET and + * EXPIRE apply as a unit. + */ + public WriteResult remember( + String text, + float[] embedding, + String user, + String namespace, + String kind, + String sourceThread, + Long ttlSeconds) { + if (embedding.length != vectorDim) { + throw new IllegalArgumentException( + "embedding length is " + embedding.length + + "; index expects " + vectorDim); + } + if (user == null || user.isEmpty()) user = "default"; + if (namespace == null || namespace.isEmpty()) namespace = "default"; + if (kind == null || kind.isEmpty()) kind = "episodic"; + + List nearest = nearest(embedding, user, namespace, kind, 1); + Double existingDistance = !nearest.isEmpty() ? nearest.get(0).distance() : null; + if (!nearest.isEmpty() + && existingDistance != null + && existingDistance <= dedupThreshold) { + bumpHitCount(nearest.get(0).id()); + return new WriteResult(nearest.get(0).id(), true, existingDistance); + } + + String id = UUID.randomUUID().toString().replace("-", "").substring(0, 12); + byte[] keyBytes = memoryKeyBytes(id); + double now = unixSecs(); + JSONObject doc = new JSONObject(); + doc.put("id", id); + doc.put("user", user); + doc.put("namespace", namespace); + doc.put("kind", kind); + doc.put("source_thread", sourceThread == null ? "" : sourceThread); + doc.put("text", text == null ? "" : text); + // org.json's JSONObject.put(String, Object) serializes a + // float[] as a JSON array of numbers — exactly what the JSON + // vector field expects at index time. + doc.put("embedding", embedding); + doc.put("created_ts", now); + doc.put("hit_count", 0); + byte[] docBytes = doc.toString().getBytes(StandardCharsets.UTF_8); + + Long ttl = ttlSeconds != null ? ttlSeconds : ttlByKind.get(kind); + + // MULTI/EXEC so JSON.SET and EXPIRE either both apply or + // neither does — a connection drop between the two writes + // would otherwise leave an episodic memory without its + // intended seven-day TTL. The shared `txLock` serializes + // this whole MULTI…EXEC span against any other transaction + // on the connection. + TransactionResult txResult; + synchronized (txLock) { + sync.multi(); + sync.dispatch( + ModuleCommand.JSON_SET, + new StatusOutput<>(ByteArrayCodec.INSTANCE), + new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(keyBytes) + .add("$") + .add(docBytes)); + if (ttl != null && ttl > 0) { + sync.expire(keyBytes, ttl); + } + txResult = sync.exec(); + } + if (txResult == null || txResult.wasDiscarded()) { + throw new RedisException("MULTI/EXEC for remember was discarded"); + } + return new WriteResult(id, false, existingDistance); + } + + // ------------------------------------------------------------------ + // Recall + // ------------------------------------------------------------------ + + public List recall( + float[] queryEmbedding, + String user, + String namespace, + String kind, + int k, + Double distanceThreshold) { + if (k <= 0) k = 5; + double threshold = distanceThreshold != null ? distanceThreshold : recallThreshold; + List candidates = nearest(queryEmbedding, user, namespace, kind, k); + List out = new ArrayList<>(candidates.size()); + for (MemoryRecord c : candidates) { + if (c.distance() != null && c.distance() <= threshold) { + out.add(c); + } + } + return out; + } + + // ------------------------------------------------------------------ + // Admin / inspection + // ------------------------------------------------------------------ + + public Map indexInfo() { + Map out = new LinkedHashMap<>(); + out.put("num_docs", 0L); + out.put("indexing_failures", 0L); + CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes); + List raw; + try { + synchronized (txLock) { + raw = sync.dispatch( + ModuleCommand.FT_INFO, + new NestedMultiOutput<>(ByteArrayCodec.INSTANCE), + args); + } + } catch (RedisException ignored) { + return out; + } + Map info = pairsToMap(raw); + out.put("num_docs", parseLong(info.get("num_docs"), 0L)); + out.put("indexing_failures", + parseLong(info.get("hash_indexing_failures"), 0L)); + return out; + } + + public List listMemories( + String user, String namespace, String kind, int limit) { + if (limit <= 0) limit = 100; + String filterClause = buildFilterClause(user, namespace, kind); + + CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes) + .add(filterClause) + .add("RETURN").add(7) + .add("user").add("namespace").add("kind").add("source_thread") + .add("text").add("created_ts").add("hit_count") + .add("SORTBY").add("created_ts").add("DESC") + .add("LIMIT").add(0).add(limit) + .add("DIALECT").add(2); + + // Hold `txLock` for the whole FT.SEARCH + per-row TTL fetch + // so this batch of commands lands together on the connection + // and can't get tangled with another helper's MULTI/EXEC. + synchronized (txLock) { + List raw; + try { + raw = sync.dispatch( + ModuleCommand.FT_SEARCH, + new NestedMultiOutput<>(ByteArrayCodec.INSTANCE), + args); + } catch (RedisException ex) { + return new ArrayList<>(); + } + List hits = parseAllHits(raw); + List out = new ArrayList<>(hits.size()); + for (SearchHit hit : hits) { + String memoryId = stripPrefix(hit.id); + long ttl = sync.ttl(memoryKeyBytes(memoryId)); + Long ttlSeconds = ttl > 0 ? ttl : null; + out.add(toRecord(memoryId, hit, null, ttlSeconds)); + } + return out; + } + } + + public boolean deleteMemory(String memoryId) { + synchronized (txLock) { + return sync.del(memoryKeyBytes(memoryId)) > 0L; + } + } + + public long clear() { + long before = parseLong(indexInfo().get("num_docs"), 0L); + dropIndex(true); + createIndex(); + return before; + } + + // ------------------------------------------------------------------ + // Internals + // ------------------------------------------------------------------ + + private List nearest( + float[] embedding, String user, String namespace, String kind, int k) { + if (embedding.length != vectorDim) { + throw new IllegalArgumentException( + "embedding length is " + embedding.length + + "; index expects " + vectorDim); + } + String filterClause = buildFilterClause(user, namespace, kind); + String knnQuery = filterClause + "=>[KNN " + k + " @embedding $vec AS distance]"; + byte[] vecBytes = LocalEmbedder.toBytes(embedding); + + CommandArgs args = new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(indexNameBytes) + .add(knnQuery) + .add("RETURN").add(8) + .add("user").add("namespace").add("kind").add("source_thread") + .add("text").add("created_ts").add("hit_count").add("distance") + .add("SORTBY").add("distance").add("ASC") + .add("LIMIT").add(0).add(k) + .add("PARAMS").add(2).add("vec".getBytes(StandardCharsets.UTF_8)).add(vecBytes) + .add("DIALECT").add(2); + + // Same connection-discipline as `listMemories`: hold the + // shared lock for FT.SEARCH + the per-row TTL fetches so the + // batch can't tangle with a MULTI/EXEC on another helper. + synchronized (txLock) { + List raw = sync.dispatch( + ModuleCommand.FT_SEARCH, + new NestedMultiOutput<>(ByteArrayCodec.INSTANCE), + args); + List hits = parseAllHits(raw); + List out = new ArrayList<>(hits.size()); + for (SearchHit hit : hits) { + // `hit.id` is the full Redis key (e.g. `agent:mem:abc123`). + // Strip the prefix so the returned record exposes only + // the opaque id the UI and `deleteMemory` work with. + String memoryId = stripPrefix(hit.id); + long ttl = sync.ttl(memoryKeyBytes(memoryId)); + Long ttlSeconds = ttl > 0 ? ttl : null; + Double distance = parseDoubleOrNull(hit.fields.get("distance")); + out.add(toRecord(memoryId, hit, distance, ttlSeconds)); + } + return out; + } + } + + private void bumpHitCount(String memoryId) { + try { + // Fire-and-forget — the doc may have expired between + // recall and bump, and discarding the error keeps the + // demo from blowing up on that race; we just lose the + // hit-count update. The shared lock keeps the + // JSON.NUMINCRBY from landing inside another helper's + // open MULTI on the connection. + synchronized (txLock) { + sync.dispatch( + ModuleCommand.JSON_NUMINCRBY, + new ValueOutput<>(ByteArrayCodec.INSTANCE), + new CommandArgs<>(ByteArrayCodec.INSTANCE) + .add(memoryKeyBytes(memoryId)) + .add("$.hit_count") + .add(1)); + } + } catch (RedisException ignored) { + // memory expired or path not found + } + } + + private static MemoryRecord toRecord( + String memoryId, SearchHit hit, Double distance, Long ttlSeconds) { + return new MemoryRecord( + memoryId, + nullSafe(hit.fields.get("user")), + nullSafe(hit.fields.get("namespace")), + nullSafe(hit.fields.get("kind")), + nullSafe(hit.fields.get("source_thread")), + nullSafe(hit.fields.get("text")), + parseDouble(hit.fields.get("created_ts"), 0.0), + parseLong(hit.fields.get("hit_count"), 0L), + distance, + ttlSeconds); + } + + private String stripPrefix(String rawKey) { + return rawKey.startsWith(keyPrefix) ? rawKey.substring(keyPrefix.length()) : rawKey; + } + + static String escapeTagValue(String value) { + StringBuilder out = new StringBuilder(value.length()); + for (int i = 0; i < value.length(); i++) { + char ch = value.charAt(i); + if (TAG_SPECIAL.indexOf(ch) >= 0) { + out.append('\\'); + } + out.append(ch); + } + return out.toString(); + } + + static String buildFilterClause(String user, String namespace, String kind) { + List clauses = new ArrayList<>(3); + if (user != null && !user.isEmpty()) { + clauses.add("@user:{" + escapeTagValue(user) + "}"); + } + if (namespace != null && !namespace.isEmpty()) { + clauses.add("@namespace:{" + escapeTagValue(namespace) + "}"); + } + if (kind != null && !kind.isEmpty()) { + clauses.add("@kind:{" + escapeTagValue(kind) + "}"); + } + if (clauses.isEmpty()) return "(*)"; + return "(" + String.join(" ", clauses) + ")"; + } + + // ------------------------------------------------------------------ + // FT.SEARCH / FT.INFO parsing + // ------------------------------------------------------------------ + + private static final class SearchHit { + final String id; + final Map fields; + + SearchHit(String id, Map fields) { + this.id = id; + this.fields = fields; + } + } + + /** Parse every hit from an FT.SEARCH reply, preserving order. */ + private static List parseAllHits(List raw) { + List out = new ArrayList<>(); + if (raw == null || raw.size() < 3) return out; + // index 0 holds the total count; entries follow as + // (key, fields)* pairs. + for (int i = 1; i + 1 < raw.size(); i += 2) { + String id = decode(raw.get(i)); + Map fields = fieldsToMap(raw.get(i + 1)); + out.add(new SearchHit(id, fields)); + } + return out; + } + + private static Map fieldsToMap(Object array) { + Map out = new HashMap<>(); + if (!(array instanceof List list)) return out; + for (int i = 0; i + 1 < list.size(); i += 2) { + out.put(decode(list.get(i)), decode(list.get(i + 1))); + } + return out; + } + + private static Map pairsToMap(List raw) { + Map out = new HashMap<>(); + if (raw == null) return out; + for (int i = 0; i + 1 < raw.size(); i += 2) { + String key = decode(raw.get(i)); + Object value = raw.get(i + 1); + if (value instanceof byte[] bytes) { + out.put(key, new String(bytes, StandardCharsets.UTF_8)); + } else if (value != null) { + out.put(key, value); + } + } + return out; + } + + private static String decode(Object value) { + if (value == null) return null; + if (value instanceof byte[] bytes) { + return new String(bytes, StandardCharsets.UTF_8); + } + if (value instanceof String s) return s; + return value.toString(); + } + + // ------------------------------------------------------------------ + // Helpers + // ------------------------------------------------------------------ + + private static String nullSafe(String s) { + return s == null ? "" : s; + } + + private static double unixSecs() { + return System.currentTimeMillis() / 1000.0; + } + + private static double parseDouble(Object value, double dflt) { + if (value == null) return dflt; + if (value instanceof Number n) return n.doubleValue(); + String s = (value instanceof byte[] bytes) + ? new String(bytes, StandardCharsets.UTF_8) + : value.toString(); + try { + return Double.parseDouble(s); + } catch (NumberFormatException ex) { + return dflt; + } + } + + private static Double parseDoubleOrNull(Object value) { + if (value == null) return null; + if (value instanceof Number n) return n.doubleValue(); + String s = (value instanceof byte[] bytes) + ? new String(bytes, StandardCharsets.UTF_8) + : value.toString(); + try { + return Double.parseDouble(s); + } catch (NumberFormatException ex) { + return null; + } + } + + private static long parseLong(Object value, long dflt) { + if (value == null) return dflt; + if (value instanceof Number n) return n.longValue(); + String s = (value instanceof byte[] bytes) + ? new String(bytes, StandardCharsets.UTF_8) + : value.toString(); + try { + return Long.parseLong(s); + } catch (NumberFormatException ex) { + try { + return (long) Double.parseDouble(s); + } catch (NumberFormatException ignored) { + return dflt; + } + } + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/MemoryRecord.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/MemoryRecord.java new file mode 100644 index 0000000000..1723172230 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/MemoryRecord.java @@ -0,0 +1,21 @@ +package com.redis.agentmem; + +/** + * A single long-term memory document. + * + *

{@code distance} is set only when the record comes back from a + * KNN query; {@code ttlSeconds} is {@code null} for memories with no + * TTL (e.g. {@code kind=semantic} under the default tier map). + */ +public record MemoryRecord( + String id, + String user, + String namespace, + String kind, + String sourceThread, + String text, + double createdTs, + long hitCount, + Double distance, + Long ttlSeconds) { +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SeedMemory.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SeedMemory.java new file mode 100644 index 0000000000..68df260546 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SeedMemory.java @@ -0,0 +1,96 @@ +package com.redis.agentmem; + +import java.util.ArrayList; +import java.util.List; + +/** + * Pre-seed the long-term memory store with sample memories. + * + *

In a real deployment the memory store fills up organically as + * the agent reasons over user turns: each turn produces zero or more + * memories (preferences, facts, episodic summaries) that flow into + * the store with deduplication. To make the demo immediately useful + * — so the first recall query lands on relevant results instead of + * an empty list — we seed a small set of canonical memories for a + * default user at startup. + * + *

The seed list mixes {@code semantic} memories (long-lived + * preferences and facts) with {@code episodic} memories (snapshots + * of past sessions), matching what the Python, Node, .NET, Rust, and + * Go demos seed so all six implementations behave identically. + */ +public final class SeedMemory { + + private SeedMemory() {} + + public record SeedEntry(String text, String kind) {} + + public static final List SEED_MEMORIES = List.of( + new SeedEntry( + "The user prefers concise answers without filler phrases.", + "semantic"), + new SeedEntry( + "The user is a Python developer working on a logistics platform.", + "semantic"), + new SeedEntry( + "The user lives in Berlin and works in the Europe/Berlin time zone.", + "semantic"), + new SeedEntry( + "The user dislikes dark mode and prefers a high-contrast light " + + "theme in editors and dashboards.", + "semantic"), + new SeedEntry( + "The user is allergic to peanuts; any restaurant suggestion must " + + "avoid dishes that commonly contain them.", + "semantic"), + new SeedEntry( + "Last Tuesday the user asked the agent to draft a postmortem for " + + "the order-routing outage. The agent produced a five-section " + + "draft and the user approved sections 1, 2, and 4 with minor " + + "edits.", + "episodic"), + new SeedEntry( + "In a previous session the user asked for help debugging a flaky " + + "test in the inventory service. The fix turned out to be a " + + "race condition in the warehouse webhook handler.", + "episodic"), + new SeedEntry( + "Two weeks ago the user mentioned they were planning to migrate " + + "the analytics warehouse from Snowflake to BigQuery in Q3.", + "episodic") + ); + + /** + * Embed and write the seed memories. Returns the count actually + * written (entries that dedup against existing memories don't + * count). + */ + public static int seed( + LongTermMemory memory, + LocalEmbedder embedder, + String user, + String namespace, + String sourceThread) throws Exception { + List texts = new ArrayList<>(SEED_MEMORIES.size()); + for (SeedEntry s : SEED_MEMORIES) { + texts.add(s.text()); + } + List vectors = embedder.encodeMany(texts); + int written = 0; + for (int i = 0; i < SEED_MEMORIES.size(); i++) { + SeedEntry entry = SEED_MEMORIES.get(i); + WriteResult result = memory.remember( + entry.text(), + vectors.get(i), + user, + namespace, + entry.kind(), + sourceThread, + null); + if (!result.deduped()) { + written++; + } + } + return written; + } +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionState.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionState.java new file mode 100644 index 0000000000..67b0f6af2d --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionState.java @@ -0,0 +1,23 @@ +package com.redis.agentmem; + +import java.util.List; + +/** + * The full per-thread working-memory state. + * + *

{@code recentTurns} is bounded by {@code AgentSession.maxTurns()}; + * the hash itself never grows in size or field count as the + * conversation goes on. + */ +public record SessionState( + String threadId, + String user, + String agent, + String goal, + String scratchpad, + long turnCount, + double createdTs, + double lastActiveTs, + List recentTurns, + long ttlSeconds) { +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionTurn.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionTurn.java new file mode 100644 index 0000000000..f9c88bd4ef --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/SessionTurn.java @@ -0,0 +1,11 @@ +package com.redis.agentmem; + +/** + * One turn inside the rolling session window. + * + *

Stored as part of a JSON array on the + * {@code agent:session:{threadId}} hash; the embedder helper does not + * see this directly. + */ +public record SessionTurn(String role, String content, double ts) { +} diff --git a/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/WriteResult.java b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/WriteResult.java new file mode 100644 index 0000000000..b2cf903b49 --- /dev/null +++ b/content/develop/use-cases/agent-memory/java-lettuce/src/main/java/com/redis/agentmem/WriteResult.java @@ -0,0 +1,13 @@ +package com.redis.agentmem; + +/** + * Outcome of a {@link LongTermMemory#remember} call. + * + *

{@code deduped} is {@code true} when the write skipped because a + * similar memory already existed; {@code id} is then the existing + * memory's id. {@code existingDistance} is the cosine distance to + * that nearest memory regardless of which branch was taken — useful + * for tracing. + */ +public record WriteResult(String id, boolean deduped, Double existingDistance) { +} diff --git a/content/develop/use-cases/agent-memory/nodejs/_index.md b/content/develop/use-cases/agent-memory/nodejs/_index.md new file mode 100644 index 0000000000..082b6c7071 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/_index.md @@ -0,0 +1,343 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Node.js with node-redis, @xenova/transformers, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: node-redis example (Node.js) +title: Redis agent memory with node-redis +weight: 2 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Node.js with [`node-redis`]({{< relref "/develop/clients/nodejs" >}}) and the [`@xenova/transformers`](https://www.npmjs.com/package/@xenova/transformers) library, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with Node's standard `http` module so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [`@xenova/transformers`](https://www.npmjs.com/package/@xenova/transformers) running the ONNX-exported [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) model, which is the same encoder the [Python example]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) uses. Embeddings produced by the two implementations are numerically very close — paraphrase distances differ only at the fourth decimal place — so a memory written by one demo can be recalled by the other against the same Redis instance, and the distance bands the Python walkthrough quotes carry over to this one without recalibration. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encodeOne(text)` to turn the incoming turn into a 384-dimensional `Float32Array`. +2. `session.appendTurn(threadId, { role, content })` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with an [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI/EXEC`]({{< relref "/commands/multi" >}}). The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall({ queryEmbedding: vec, user, namespace, k: 5 })` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember({ text, embedding: vec, user, namespace, kind })` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `eventLog.record(threadId, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/nodejs/sessionStore.js)): + +```javascript +import { createClient } from 'redis'; +import { AgentSession } from './sessionStore.js'; + +const client = createClient(); +await client.connect(); + +const session = new AgentSession({ + client, + keyPrefix: 'agent:session:', + defaultTtlSeconds: 3600, // one hour + maxTurns: 20, // rolling window per thread +}); + +const threadId = session.newThreadId(); +await session.start(threadId, { + user: 'alice', + agent: 'demo-agent', + goal: "Plan next week's meetings.", +}); +await session.appendTurn(threadId, { + role: 'user', + content: 'Schedule a budget review with finance.', +}); +const state = await session.load(threadId); +console.log(state.turn_count, state.recent_turns.length, state.ttl_seconds); +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `appendTurn`, `setScratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/nodejs/longTermMemory.js)): + +```javascript +import { LongTermMemory } from './longTermMemory.js'; +import { LocalEmbedder } from './embeddings.js'; + +const memory = new LongTermMemory({ + client, + indexName: 'agentmem:idx', + keyPrefix: 'agent:mem:', + dedupThreshold: 0.20, // cosine distance — tight at write time + recallThreshold: 0.55, // looser at read time +}); +const embedder = await LocalEmbedder.create(); +await memory.createIndex(); // idempotent + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +const vec = await embedder.encodeOne('The user prefers light mode in editors.'); +const result = await memory.remember({ + text: 'The user prefers light mode in editors.', + embedding: vec, + user: 'alice', + namespace: 'default', + kind: 'semantic', + sourceThread: '9f3d2a4b8c61', +}); +console.log(result.deduped, result.id, result.existingDistance); + +// Recall against a later question. +const q = await embedder.encodeOne('Which theme does this user like?'); +const hits = await memory.recall({ + queryEmbedding: q, + user: 'alice', + namespace: 'default', + k: 5, +}); +for (const h of hits) { + console.log(`${h.distance.toFixed(3)} [${h.kind}] ${h.text}`); +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (a `Buffer` view over a `Float32Array` in Node), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `AS` aliases so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write with `ttlSeconds: ...` on `remember`, or pass a different `ttlByKind: { ... }` to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/nodejs/eventLog.js)): + +```javascript +import { AgentEventLog } from './eventLog.js'; + +const events = new AgentEventLog({ client, maxLen: 1000 }); +await events.record(threadId, 'turn_appended:user', + 'Schedule a budget review with finance.'); +await events.record(threadId, 'memory_written', + 'wrote 7c3f8a1b9e02 as semantic'); + +for (const event of await events.recent(threadId, 20)) { + console.log(event.action, event.detail); +} +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.appendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `currentThreadId` that `/new_thread` and `/reset` mutate. `handleTurn` reads it without coordination, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `seedMemory.js` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/nodejs/seedMemory.js)): + +```javascript +import { seed } from './seedMemory.js'; +import { LongTermMemory } from './longTermMemory.js'; +import { LocalEmbedder } from './embeddings.js'; + +const memory = new LongTermMemory({ client }); +const embedder = await LocalEmbedder.create(); +await memory.createIndex(); +await seed(memory, embedder, { user: 'default', namespace: 'default' }); +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`demoServer.js` runs a Node [`http`](https://nodejs.org/api/http.html) server on port 8089. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a single field on the demo object that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/nodejs + ``` + +2. Install the dependencies: + + ```bash + npm install + ``` + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo server. The first run downloads the + [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) ONNX + weights (around 90 MB) into the local Hugging Face cache: + + ```bash + npm start + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.18 (the + ONNX-exported model runs slightly different arithmetic from the + PyTorch one, so paraphrase distances sit a hair higher than in the + Python demo), with `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing recalls + above the threshold (every seed lands above 0.8), and the new memory is + written as `episodic` (or `semantic`, depending on the dropdown) under a + fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + `Xenova/all-MiniLM-L6-v2` puts a faithful paraphrase in the 0.15 – 0.50 + cosine-distance range, a loose paraphrase or related topic in the 0.50 – 0.80 + range, and unrelated queries above 0.8 — which is what motivates the 0.55 + default recall threshold and the 0.20 default dedup threshold. A stricter + embedding model (or a domain-tuned one) would let you tighten both; a + noisier one would push them up. The right thresholds are always a function + of the model, the corpus, and how conservative the agent needs to be about + accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags: + +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/nodejs/demoServer.js b/content/develop/use-cases/agent-memory/nodejs/demoServer.js new file mode 100644 index 0000000000..deceae1047 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/demoServer.js @@ -0,0 +1,464 @@ +#!/usr/bin/env node +// Redis agent-memory demo server (Node.js). +// +// Run this file and visit http://localhost:8089 to drive a small +// agent-memory demo backed by Redis Hashes, JSON, Search, and +// Streams. The UI lets you: +// +// * Type a turn as the user (or paste a goal). The server appends +// the turn to the per-thread working-memory hash, embeds the +// turn, recalls the top-k semantically nearest long-term memories, +// optionally writes the turn back as a new memory with write-time +// deduplication, and appends an event to the per-thread stream. +// * Watch the three memory tiers update in place: working memory in +// one Hash, long-term memories as JSON documents under one index, +// and the event log in one Stream. +// * Switch user, namespace, kind, and recall threshold to see how +// scoping changes which memories the agent sees. +// * Inspect every long-term memory and drop individual memories to +// simulate eviction. +// +// The server holds a single `LocalEmbedder`, one `AgentSession`, +// one `LongTermMemory`, and one `AgentEventLog` for the lifetime of +// the process. The first run downloads the embedding model into the +// local Hugging Face cache; everything after is local. + +import { createServer } from 'node:http'; +import { readFile } from 'node:fs/promises'; +import { fileURLToPath } from 'node:url'; +import { dirname, join } from 'node:path'; +import { parseArgs } from 'node:util'; +import { createClient } from 'redis'; + +import { LocalEmbedder } from './embeddings.js'; +import { AgentSession } from './sessionStore.js'; +import { AgentEventLog } from './eventLog.js'; +import { LongTermMemory } from './longTermMemory.js'; +import { seed } from './seedMemory.js'; + +const HERE = dirname(fileURLToPath(import.meta.url)); + +// Demo state: working memory, long-term memory, event log. +class AgentMemoryDemo { + constructor({ + sessionStore, + memory, + eventLog, + embedder, + defaultUser = 'default', + defaultNamespace = 'default', + }) { + this.sessionStore = sessionStore; + this.memory = memory; + this.eventLog = eventLog; + this.embedder = embedder; + this.defaultUser = defaultUser; + this.defaultNamespace = defaultNamespace; + this.currentThreadId = sessionStore.newThreadId(); + } + + // Drop everything in scope and pre-populate with seed memories. + async seed(user, namespace) { + await this.memory.clear(); + await this.sessionStore.delete(this.currentThreadId); + await this.eventLog.clear(this.currentThreadId); + const written = await seed(this.memory, this.embedder, { + user, namespace, sourceThread: 'seed', + }); + this.currentThreadId = this.sessionStore.newThreadId(); + return written; + } + + // Start a fresh thread. Long-term memory is unaffected. + async newThread(user, namespace) { + await this.eventLog.clear(this.currentThreadId); + this.currentThreadId = this.sessionStore.newThreadId(); + await this.sessionStore.start(this.currentThreadId, { + user, agent: 'demo-agent', goal: '', + }); + await this.eventLog.record( + this.currentThreadId, + 'thread_started', + `user=${user} namespace=${namespace}`, + ); + return this.currentThreadId; + } + + // One pass through the agent loop: append, recall, remember, log. + // + // The order matters. We embed once and reuse the vector for both + // the recall and (if asked) the remember step — no point encoding + // the same text twice. Recall runs *before* the remember write so + // the agent doesn't see its own just-written turn as a recalled + // memory. + async handleTurn({ + text, user, namespace, kind, role, threshold, action, + }) { + const threadId = this.currentThreadId; + + const t0 = performance.now(); + const vec = await this.embedder.encodeOne(text); + const embedMs = performance.now() - t0; + + // `setGoal` only touches the goal field so existing turns aren't + // wiped; `appendTurn` carries the request `user` through to the + // auto-create path so a first turn for a new thread doesn't land + // under the default user. + let sessionAction; + if (action === 'goal') { + await this.sessionStore.setGoal(threadId, text, { + user, agent: 'demo-agent', + }); + sessionAction = 'goal_set'; + } else { + await this.sessionStore.appendTurn(threadId, { + role, content: text, user, agent: 'demo-agent', + }); + sessionAction = `turn_appended:${role}`; + } + + const t1 = performance.now(); + const recalled = await this.memory.recall({ + queryEmbedding: vec, + user, + namespace, + k: 5, + distanceThreshold: threshold, + }); + const recallMs = performance.now() - t1; + + const writeSkipped = (kind === 'skip' || action === 'goal'); + let writeResult = null; + let writeMs = 0; + if (!writeSkipped) { + const t2 = performance.now(); + writeResult = await this.memory.remember({ + text, + embedding: vec, + user, + namespace, + kind, + sourceThread: threadId, + }); + writeMs = performance.now() - t2; + } + + if (writeResult) { + const eventDetail = writeResult.deduped + ? `deduped onto ${writeResult.id}` + : `wrote ${writeResult.id} as ${kind}`; + await this.eventLog.record(threadId, sessionAction, eventDetail); + } else { + await this.eventLog.record(threadId, sessionAction, ''); + } + + return { + thread_id: threadId, + write_skipped: writeSkipped, + memory_id: writeResult?.id ?? null, + deduped: writeResult?.deduped ?? false, + existing_distance: writeResult?.existingDistance ?? null, + kind: writeSkipped ? null : kind, + recalled, + embed_ms: embedMs, + recall_ms: recallMs, + write_ms: writeMs, + }; + } +} + +// ---- HTTP plumbing -------------------------------------------------- + +function sendJson(res, payload, status = 200) { + res.writeHead(status, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify(payload)); +} + +function sendHtml(res, html, status = 200) { + res.writeHead(status, { 'Content-Type': 'text/html; charset=utf-8' }); + res.end(html); +} + +// Cap POST bodies so a runaway client (or a curl --data-binary +// @big-file by mistake) can't accumulate unbounded memory before the +// handler runs. The demo's largest legitimate body is a few hundred +// bytes of form-encoded query fields; 1 MiB is a generous ceiling. +const MAX_BODY_BYTES = 1 * 1024 * 1024; + +async function readBody(req) { + return new Promise((resolve, reject) => { + const chunks = []; + let total = 0; + req.on('data', c => { + total += c.length; + if (total > MAX_BODY_BYTES) { + req.destroy(); + reject(new Error(`request body exceeds ${MAX_BODY_BYTES} bytes`)); + return; + } + chunks.push(c); + }); + req.on('end', () => resolve(Buffer.concat(chunks).toString('utf-8'))); + req.on('error', reject); + }); +} + +function parseForm(body) { + const params = new URLSearchParams(body); + const result = {}; + for (const [k, v] of params) result[k] = v; + return result; +} + +function clampThreshold(raw, fallback) { + const parsed = parseFloat(raw); + // `parseFloat` happily handles "nan" → NaN, "inf" → Infinity. Either + // would silently turn the lookup into a permanent match or a + // permanent miss. Clamp to the meaningful cosine-distance range so + // a malformed POST can't override the threshold semantics. + if (!Number.isFinite(parsed)) return fallback; + return Math.max(0.0, Math.min(2.0, parsed)); +} + +async function buildState({ + demo, memory, sessionStore, eventLog, embedder, user, namespace, stackLabel, +}) { + const info = await memory.indexInfo(); + const threadId = demo.currentThreadId; + const [session, memories, events] = await Promise.all([ + sessionStore.load(threadId), + memory.listMemories({ user, namespace, limit: 200 }), + eventLog.recent(threadId, 20), + ]); + return { + index: { + ...info, + index_name: memory.indexName, + model: embedder.modelName, + session_ttl_seconds: sessionStore.defaultTtlSeconds, + dedup_threshold: memory.dedupThreshold, + default_recall_threshold: memory.recallThreshold, + stack_label: stackLabel, + }, + thread_id: threadId, + session, + memories, + events, + // `recalled` is populated by /turn; on plain /state reads the UI + // keeps showing the last turn's result, which is the useful + // behaviour for an "agent" panel. + recalled: [], + }; +} + +function makeHandler({ + demo, memory, sessionStore, eventLog, embedder, htmlPage, stackLabel, +}) { + return async (req, res) => { + try { + const url = new URL(req.url, 'http://localhost'); + if (req.method === 'GET') { + if (url.pathname === '/' || url.pathname === '/index.html') { + return sendHtml(res, htmlPage); + } + if (url.pathname === '/state') { + const user = url.searchParams.get('user') || demo.defaultUser; + const namespace = url.searchParams.get('namespace') + || demo.defaultNamespace; + return sendJson(res, await buildState({ + demo, memory, sessionStore, eventLog, embedder, + user, namespace, stackLabel, + })); + } + return sendJson(res, { error: 'not found' }, 404); + } + if (req.method === 'POST') { + const body = await readBody(req); + const params = parseForm(body); + + if (url.pathname === '/turn') { + const text = (params.text || '').trim(); + if (!text) return sendJson(res, { error: 'text is required' }, 400); + const threshold = clampThreshold( + params.threshold ?? String(memory.recallThreshold), + memory.recallThreshold, + ); + const payload = await demo.handleTurn({ + text, + user: params.user || 'default', + namespace: params.namespace || 'default', + kind: params.kind || 'episodic', + role: params.role || 'user', + threshold, + action: params.action || 'turn', + }); + return sendJson(res, payload); + } + if (url.pathname === '/new_thread') { + const threadId = await demo.newThread( + params.user || 'default', + params.namespace || 'default', + ); + return sendJson(res, { thread_id: threadId }); + } + if (url.pathname === '/reset') { + const seeded = await demo.seed( + params.user || 'default', + params.namespace || 'default', + ); + return sendJson(res, { seeded }); + } + if (url.pathname === '/drop_memory') { + const memoryId = (params.memory_id || '').trim(); + if (!memoryId) { + return sendJson(res, { error: 'memory_id is required' }, 400); + } + const deleted = await memory.deleteMemory(memoryId); + return sendJson(res, { deleted, memory_id: memoryId }); + } + return sendJson(res, { error: 'not found' }, 404); + } + return sendJson(res, { error: 'method not allowed' }, 405); + } catch (exc) { + // Without this wrapper, an exception escapes to the default + // Node error handler and the client's `await res.json()` would + // explode with an opaque parse error instead of surfacing what + // actually went wrong. + process.stderr.write( + `[demo] handler error: ${exc?.stack || exc}\n`, + ); + try { + sendJson(res, { + error: String(exc?.message || exc), + type: exc?.name || 'Error', + }, 500); + } catch { + // Headers may already be partially flushed; nothing left to do. + } + } + }; +} + +// ---- Main ----------------------------------------------------------- + +function parseFlags() { + const { values } = parseArgs({ + options: { + host: { type: 'string', default: '127.0.0.1' }, + port: { type: 'string', default: '8089' }, + 'redis-host': { type: 'string', default: 'localhost' }, + 'redis-port': { type: 'string', default: '6379' }, + 'mem-index-name': { type: 'string', default: 'agentmem:idx' }, + 'mem-key-prefix': { type: 'string', default: 'agent:mem:' }, + 'session-key-prefix': { type: 'string', default: 'agent:session:' }, + 'event-key-prefix': { type: 'string', default: 'agent:events:' }, + 'session-ttl-seconds': { type: 'string', default: '3600' }, + 'dedup-threshold': { type: 'string', default: '0.20' }, + 'recall-threshold': { type: 'string', default: '0.55' }, + 'no-reset': { type: 'boolean', default: false }, + }, + }); + return values; +} + +async function main() { + const args = parseFlags(); + const port = Number(args.port); + const redisHost = args['redis-host']; + const redisPort = Number(args['redis-port']); + const memIndexName = args['mem-index-name']; + const memKeyPrefix = args['mem-key-prefix']; + const sessionKeyPrefix = args['session-key-prefix']; + const eventKeyPrefix = args['event-key-prefix']; + const sessionTtlSeconds = Number(args['session-ttl-seconds']); + const dedupThreshold = Number(args['dedup-threshold']); + const recallThreshold = Number(args['recall-threshold']); + const resetOnStart = !args['no-reset']; + + const client = createClient({ socket: { host: redisHost, port: redisPort } }); + client.on('error', err => console.error('[redis]', err)); + try { + await client.connect(); + await client.ping(); + } catch (exc) { + console.error(`Error: cannot reach Redis at ${redisHost}:${redisPort}`); + console.error(` (${exc.message || exc})`); + process.exit(1); + } + + const sessionStore = new AgentSession({ + client, + keyPrefix: sessionKeyPrefix, + defaultTtlSeconds: sessionTtlSeconds, + }); + const memory = new LongTermMemory({ + client, + indexName: memIndexName, + keyPrefix: memKeyPrefix, + dedupThreshold, + recallThreshold, + }); + await memory.createIndex(); + const eventLog = new AgentEventLog({ + client, + keyPrefix: eventKeyPrefix, + }); + + console.log( + 'Loading embedding model (first run downloads the ONNX weights)...', + ); + const embedder = await LocalEmbedder.create(); + + const demo = new AgentMemoryDemo({ + sessionStore, memory, eventLog, embedder, + }); + + if (resetOnStart) { + console.log( + `Dropping any existing memories under '${memKeyPrefix}*' and ` + + 're-seeding from the sample memory list (pass --no-reset to keep).', + ); + const seeded = await demo.seed('default', 'default'); + console.log(`Seeded ${seeded} memories.`); + } + + // Load the HTML once and replace the template tokens with the + // configured key prefixes and index name so the docs panel shows + // the actual values in use. + const rawHtml = await readFile(join(HERE, 'index.html'), 'utf-8'); + const htmlPage = rawHtml + .replaceAll('__SESSION_PREFIX__', sessionKeyPrefix) + .replaceAll('__MEM_PREFIX__', memKeyPrefix) + .replaceAll('__MEM_INDEX__', memIndexName) + .replaceAll('__EVENT_PREFIX__', eventKeyPrefix); + + const stackLabel = + 'node-redis + @xenova/transformers + Node.js standard library HTTP server'; + const server = createServer(makeHandler({ + demo, memory, sessionStore, eventLog, embedder, htmlPage, stackLabel, + })); + server.listen(port, args.host, () => { + console.log( + `Redis agent memory demo listening on http://${args.host}:${port}`, + ); + console.log( + `Using Redis at ${redisHost}:${redisPort}` + + ` with memory index '${memIndexName}'`, + ); + }); + + // Clean shutdown so the Redis client closes its socket. + const shutdown = async (signal) => { + console.log(`\nReceived ${signal}, shutting down...`); + server.close(); + try { await client.disconnect(); } catch {} + process.exit(0); + }; + process.on('SIGINT', () => shutdown('SIGINT')); + process.on('SIGTERM', () => shutdown('SIGTERM')); +} + +main().catch(err => { + console.error(err); + process.exit(1); +}); diff --git a/content/develop/use-cases/agent-memory/nodejs/embeddings.js b/content/develop/use-cases/agent-memory/nodejs/embeddings.js new file mode 100644 index 0000000000..997c1a1f5c --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/embeddings.js @@ -0,0 +1,76 @@ +// Local text-embedding helper backed by @xenova/transformers. +// +// This is a thin wrapper around the ONNX-exported sentence-transformers +// model `Xenova/all-MiniLM-L6-v2`: a 384-dimensional encoder that runs +// in-process on CPU through ONNX Runtime Web, needs no API key, and +// produces vectors that are numerically very close to the equivalent +// PyTorch model (close enough that paraphrase distances differ only at +// the fourth decimal place, so the distance bands quoted in the +// Python walkthrough carry over to this demo without re-calibration). +// +// Vectors are L2-normalised so a Redis Search index declared with +// `DISTANCE_METRIC COSINE` returns scores that are directly comparable +// across entries. The model is downloaded into the local Hugging Face +// cache on the first call; every later call runs offline. + +import { env, pipeline } from '@xenova/transformers'; + +// Allow the local cache to satisfy subsequent runs without re-downloading. +env.allowLocalModels = true; + +const DEFAULT_MODEL = 'Xenova/all-MiniLM-L6-v2'; + +export class LocalEmbedder { + // Use `LocalEmbedder.create(...)` instead of `new LocalEmbedder(...)` + // because the pipeline load is async; we want one place that owns + // the wait and the dimension probe. + constructor(modelName, extractor, dim) { + this.modelName = modelName; + this.extractor = extractor; + this.dim = dim; + } + + static async create(modelName = DEFAULT_MODEL) { + const extractor = await pipeline('feature-extraction', modelName); + // Probe the output shape once and record it on the instance so + // callers can compare against the index's expected vectorDim + // before doing any inserts. LongTermMemory also checks length on + // every remember / recall, so a model swap that produces + // wrong-dim vectors fails at the call site with a clear error. + const probe = await extractor('dimension probe', { + pooling: 'mean', normalize: true, + }); + const dim = probe.dims[probe.dims.length - 1]; + return new LocalEmbedder(modelName, extractor, dim); + } + + // Encode a single string. Returns a Float32Array of length `dim`. + async encodeOne(text) { + const out = await this.extractor(text, { + pooling: 'mean', normalize: true, + }); + return new Float32Array(out.data); + } + + // Encode several strings in one pipeline call. Returns an array of + // Float32Array; callers that need raw bytes use `toBytes` per row. + async encodeMany(texts) { + const out = await this.extractor(texts, { + pooling: 'mean', normalize: true, + }); + const rows = out.dims[0]; + const cols = out.dims[1]; + const result = []; + for (let i = 0; i < rows; i++) { + result.push(new Float32Array(out.data.slice(i * cols, (i + 1) * cols))); + } + return result; + } + + // Pack a Float32Array into the bytes Redis Search expects. + // Float32Array.buffer is little-endian on every architecture we care + // about — Node runs on x86_64/arm64, both little-endian. + static toBytes(vector) { + return Buffer.from(vector.buffer, vector.byteOffset, vector.byteLength); + } +} diff --git a/content/develop/use-cases/agent-memory/nodejs/eventLog.js b/content/develop/use-cases/agent-memory/nodejs/eventLog.js new file mode 100644 index 0000000000..824f270d58 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/eventLog.js @@ -0,0 +1,83 @@ +// Append-only event log for an agent thread, backed by a Redis Stream. +// +// Each thread gets a stream at `agent:events:{threadId}`. Every +// action the agent takes (a user turn arriving, a memory being +// recalled, a memory being written, a tool being called) is one +// `XADD` to that stream. Replay with `XREVRANGE` for the most recent +// N events; bound retention with `XTRIM MAXLEN ~` so the log stays +// cheap regardless of how long the thread has been running. +// +// The stream is independent of the session hash (`sessionStore.js`) +// and the long-term memory store (`longTermMemory.js`): it answers +// the "what just happened" question without competing with either of +// those for indexing or memory budget. Consumer groups (not used in +// this demo) would let downstream workers — summarisers, +// consolidators, audit pipelines — replay the log without losing +// position. + +// Approximate cap on stream length. `MAXLEN ~` lets Redis trim in +// whole-node units instead of exactly-N units, which is much cheaper +// at the cost of overshooting the bound by up to a node's worth. +export const DEFAULT_MAXLEN = 1000; + +export class AgentEventLog { + constructor({ + client, + keyPrefix = 'agent:events:', + maxLen = DEFAULT_MAXLEN, + }) { + this.client = client; + this.keyPrefix = keyPrefix; + this.maxLen = maxLen; + } + + streamKey(threadId) { + return `${this.keyPrefix}${threadId}`; + } + + // Append one event and return its stream id. + // + // `MAXLEN ~ N` keeps the stream bounded with near-zero overhead; + // an exact bound (`MAXLEN N` without the tilde) forces a scan and + // is rarely worth the cost. + async record(threadId, action, detail = '') { + return await this.client.xAdd( + this.streamKey(threadId), + '*', + { + action, + detail, + ts: String(Date.now() / 1000), + }, + { + TRIM: { + strategy: 'MAXLEN', + strategyModifier: '~', + threshold: this.maxLen, + }, + }, + ); + } + + // Return the most recent events, newest first. + async recent(threadId, count = 20) { + const rows = await this.client.xRevRange( + this.streamKey(threadId), '+', '-', { COUNT: count }, + ); + return rows.map(r => ({ + event_id: r.id, + thread_id: threadId, + action: r.message?.action ?? '', + detail: r.message?.detail ?? '', + ts: parseFloat(r.message?.ts ?? '0') || 0, + })); + } + + async length(threadId) { + return await this.client.xLen(this.streamKey(threadId)); + } + + async clear(threadId) { + return (await this.client.del(this.streamKey(threadId))) > 0; + } +} diff --git a/content/develop/use-cases/agent-memory/nodejs/index.html b/content/develop/use-cases/agent-memory/nodejs/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +

+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/nodejs/longTermMemory.js b/content/develop/use-cases/agent-memory/nodejs/longTermMemory.js new file mode 100644 index 0000000000..da94582791 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/longTermMemory.js @@ -0,0 +1,394 @@ +// Long-term memory store for an agent, backed by Redis JSON and Search. +// +// Each memory lives as one JSON document at `agent:mem:`. The +// document holds the memory text, its embedding vector, and a small +// metadata block — user, namespace, kind, source thread, timestamps — +// that lets the recall query scope results without falling back to +// application-side filtering. +// +// A single Redis Search index covers the embedding plus every +// metadata field, so one `FT.SEARCH` call performs approximate- +// nearest-neighbour over the in-scope subset and returns the top-k +// memories ranked by cosine distance. The same KNN check runs at +// *write* time to deduplicate near-identical memories before they +// enter the store, which keeps the index from filling with +// paraphrases of the same fact as the agent reasons over similar +// topics across sessions. +// +// Memories carry one of two kinds: +// +// * `episodic` — "what happened" snapshots from a specific thread, +// written with a medium TTL so old session detail decays naturally. +// * `semantic` — distilled facts and preferences the agent should +// carry forward indefinitely. Written with no TTL by default. +// +// The split is enforced as a TAG on the index, so the recall query +// can ask for one kind or both with a filter — no separate keyspaces. + +import { randomUUID } from 'node:crypto'; +import { + SCHEMA_FIELD_TYPE, + SCHEMA_VECTOR_FIELD_ALGORITHM, +} from 'redis'; + +const VECTOR_DIM_DEFAULT = 384; + +// How close (cosine distance) a candidate must be to an existing +// memory to count as a duplicate at write time. Smaller = stricter. +// 0.20 is calibrated to the `Xenova/all-MiniLM-L6-v2` embedding model +// used in the demo, where a paraphrase of an existing memory lands +// in the 0.10 – 0.20 range and a distinct memory lands above 0.50. +export const DEFAULT_DEDUP_THRESHOLD = 0.20; + +// How close (cosine distance) a candidate must be to count as a +// relevant recall result. Larger than the dedup threshold so the +// agent gets a wider net at read time than at write time. +export const DEFAULT_RECALL_THRESHOLD = 0.55; + +// TTL tiers, in seconds. `null` means "no TTL" — the memory persists +// until explicitly deleted or evicted under memory pressure. +export const TTL_BY_KIND = { + episodic: 7 * 24 * 3600, + semantic: null, +}; + +export class LongTermMemory { + constructor({ + client, + indexName = 'agentmem:idx', + keyPrefix = 'agent:mem:', + vectorDim = VECTOR_DIM_DEFAULT, + dedupThreshold = DEFAULT_DEDUP_THRESHOLD, + recallThreshold = DEFAULT_RECALL_THRESHOLD, + ttlByKind, + }) { + this.client = client; + this.indexName = indexName; + this.keyPrefix = keyPrefix; + this.vectorDim = vectorDim; + this.dedupThreshold = dedupThreshold; + this.recallThreshold = recallThreshold; + this.ttlByKind = ttlByKind || { ...TTL_BY_KIND }; + } + + // -- Keys ----------------------------------------------------------- + + memoryKey(memoryId) { + return `${this.keyPrefix}${memoryId}`; + } + + // -- Index management ---------------------------------------------- + + async createIndex() { + // The index is declared on the JSON document type with `as_name` + // aliases on each path; the same `FT.SEARCH` filter clause works + // here as on a HASH-backed index, and the field paths + // (`$.user`, `$.embedding`, ...) only show up in `FT.CREATE`. + const schema = { + '$.text': { type: SCHEMA_FIELD_TYPE.TEXT, AS: 'text' }, + '$.user': { type: SCHEMA_FIELD_TYPE.TAG, AS: 'user' }, + '$.namespace': { type: SCHEMA_FIELD_TYPE.TAG, AS: 'namespace' }, + '$.kind': { type: SCHEMA_FIELD_TYPE.TAG, AS: 'kind' }, + '$.source_thread': { + type: SCHEMA_FIELD_TYPE.TAG, AS: 'source_thread', + }, + '$.created_ts': { + type: SCHEMA_FIELD_TYPE.NUMERIC, AS: 'created_ts', SORTABLE: true, + }, + '$.hit_count': { + type: SCHEMA_FIELD_TYPE.NUMERIC, AS: 'hit_count', SORTABLE: true, + }, + '$.embedding': { + type: SCHEMA_FIELD_TYPE.VECTOR, + ALGORITHM: SCHEMA_VECTOR_FIELD_ALGORITHM.HNSW, + TYPE: 'FLOAT32', + DIM: this.vectorDim, + DISTANCE_METRIC: 'COSINE', + AS: 'embedding', + }, + }; + try { + await this.client.ft.create(this.indexName, schema, { + ON: 'JSON', + PREFIX: this.keyPrefix, + }); + } catch (err) { + if (!String(err.message || err).includes('Index already exists')) { + throw err; + } + } + } + + async dropIndex({ deleteDocuments = false } = {}) { + try { + await this.client.ft.dropIndex(this.indexName, { DD: deleteDocuments }); + } catch (err) { + const msg = String(err.message || err).toLowerCase(); + if (!msg.includes('no such index') && !msg.includes('unknown index name')) { + throw err; + } + } + } + + // -- Write ---------------------------------------------------------- + + // Write a new memory, deduplicating against existing entries. + // + // Runs one in-scope KNN(1) against the index first. If the nearest + // existing memory is within `dedupThreshold`, the new memory is + // skipped (its content is already represented) and the existing + // memory's `hit_count` is bumped. Otherwise a fresh JSON document + // is written under a new id with a TTL derived from the memory's + // `kind`. + // + // The KNN-then-write sequence is not atomic; two workers that + // remember the same fact at the same time can both miss each + // other's in-flight write and insert duplicate memories. See the + // walkthrough's "Concurrency caveats" section for the production + // fix (periodic background consolidator that merges near-duplicates). + async remember({ + text, + embedding, + user = 'default', + namespace = 'default', + kind = 'episodic', + sourceThread = '', + ttlSeconds, + }) { + if (!(embedding instanceof Float32Array)) { + embedding = Float32Array.from(embedding); + } + if (embedding.length !== this.vectorDim) { + throw new Error( + `embedding length is ${embedding.length}; index expects ${this.vectorDim}`, + ); + } + + const nearest = await this._nearest({ + embedding, user, namespace, kind, k: 1, + }); + const nearestDistance = nearest[0]?.distance ?? null; + if (nearest[0] && nearest[0].distance != null + && nearest[0].distance <= this.dedupThreshold) { + await this._bumpHitCount(nearest[0].id); + return { + id: nearest[0].id, + deduped: true, + existingDistance: nearestDistance, + }; + } + + const id = randomUUID().replace(/-/g, '').slice(0, 12); + const key = this.memoryKey(id); + const now = Date.now() / 1000; + const doc = { + id, + user, + namespace, + kind, + source_thread: sourceThread, + text, + embedding: Array.from(embedding), + created_ts: now, + hit_count: 0, + }; + const ttl = ttlSeconds !== undefined + ? ttlSeconds + : this.ttlByKind[kind] ?? null; + + // MULTI / EXEC so the document and its TTL apply together. A + // connection drop between the JSON.SET and EXPIRE would + // otherwise leave the memory without an expiry. + const tx = this.client.multi().json.set(key, '$', doc); + if (ttl !== null && ttl !== undefined) { + tx.expire(key, ttl); + } + await tx.exec(); + return { id, deduped: false, existingDistance: nearestDistance }; + } + + // -- Recall --------------------------------------------------------- + + // Return the top-k in-scope memories ranked by similarity. + // + // Memories beyond `distanceThreshold` (or the instance default) + // are dropped — the index always returns *something* for KNN, so a + // recall result on an unrelated query would otherwise be a + // confidently-wrong false positive. + async recall({ + queryEmbedding, + user = 'default', + namespace = 'default', + kind, + k = 5, + distanceThreshold, + }) { + const threshold = distanceThreshold !== undefined + ? distanceThreshold + : this.recallThreshold; + const candidates = await this._nearest({ + embedding: queryEmbedding, user, namespace, kind, k, + }); + return candidates.filter( + c => c.distance != null && c.distance <= threshold, + ); + } + + // -- Internals ------------------------------------------------------ + + async _nearest({ embedding, user, namespace, kind, k }) { + if (!(embedding instanceof Float32Array)) { + embedding = Float32Array.from(embedding); + } + if (embedding.length !== this.vectorDim) { + throw new Error( + `embedding length is ${embedding.length}; index expects ${this.vectorDim}`, + ); + } + const filterClause = LongTermMemory.buildFilterClause({ + user, namespace, kind, + }); + const queryStr = `${filterClause}=>[KNN ${k} @embedding $vec AS distance]`; + const vecBytes = Buffer.from( + embedding.buffer, embedding.byteOffset, embedding.byteLength, + ); + const result = await this.client.ft.search(this.indexName, queryStr, { + PARAMS: { vec: vecBytes }, + DIALECT: 2, + SORTBY: 'distance', + RETURN: [ + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', 'distance', + ], + LIMIT: { from: 0, size: k }, + }); + if (!result.documents || result.documents.length === 0) return []; + const out = []; + for (const doc of result.documents) { + // `doc.id` is the full Redis key (e.g. `agent:mem:abc123`). + // Strip the prefix so the returned record exposes only the + // opaque id the UI and `deleteMemory` work with. + const memoryId = this._stripPrefix(doc.id); + const ttl = await this.client.ttl(this.memoryKey(memoryId)); + out.push({ + id: memoryId, + user: doc.value.user ?? '', + namespace: doc.value.namespace ?? '', + kind: doc.value.kind ?? '', + source_thread: doc.value.source_thread ?? '', + text: doc.value.text ?? '', + created_ts: parseFloat(doc.value.created_ts ?? '0') || 0, + hit_count: parseInt(doc.value.hit_count ?? '0', 10) || 0, + distance: parseFloat(doc.value.distance ?? '0') || 0, + ttl_seconds: ttl > 0 ? ttl : null, + }); + } + return out; + } + + async _bumpHitCount(memoryId) { + try { + await this.client.json.numIncrBy( + this.memoryKey(memoryId), '$.hit_count', 1, + ); + } catch { + // The doc may have expired between recall and bump — fine, we + // just lose the hit count update. + } + } + + _stripPrefix(rawKey) { + return rawKey.startsWith(this.keyPrefix) + ? rawKey.slice(this.keyPrefix.length) + : rawKey; + } + + // Characters Redis Search treats as syntax inside a TAG value; any + // of them in a user-supplied filter must be backslash-escaped or + // the surrounding `{...}` block won't parse correctly. + static _TAG_SPECIAL = new Set('\\,.<>{}[]"\':;!@#$%^&*()-+=~| '.split('')); + + static escapeTagValue(value) { + let out = ''; + for (const ch of value) { + out += LongTermMemory._TAG_SPECIAL.has(ch) ? '\\' + ch : ch; + } + return out; + } + + static buildFilterClause({ user, namespace, kind }) { + const clauses = []; + if (user) { + clauses.push(`@user:{${LongTermMemory.escapeTagValue(user)}}`); + } + if (namespace) { + clauses.push(`@namespace:{${LongTermMemory.escapeTagValue(namespace)}}`); + } + if (kind) { + clauses.push(`@kind:{${LongTermMemory.escapeTagValue(kind)}}`); + } + return clauses.length === 0 ? '(*)' : `(${clauses.join(' ')})`; + } + + // -- Admin / inspection -------------------------------------------- + + async indexInfo() { + try { + const info = await this.client.ft.info(this.indexName); + return { + num_docs: Number(info.numDocs ?? info.num_docs ?? 0), + indexing_failures: Number( + info.hashIndexingFailures ?? info.hash_indexing_failures ?? 0, + ), + }; + } catch { + return { num_docs: 0, indexing_failures: 0 }; + } + } + + async listMemories({ user, namespace, kind, limit = 100 } = {}) { + const filterClause = LongTermMemory.buildFilterClause({ + user, namespace, kind, + }); + const result = await this.client.ft.search(this.indexName, filterClause, { + DIALECT: 2, + SORTBY: { BY: 'created_ts', DIRECTION: 'DESC' }, + RETURN: [ + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', + ], + LIMIT: { from: 0, size: limit }, + }); + const out = []; + for (const doc of result.documents || []) { + const memoryId = this._stripPrefix(doc.id); + const ttl = await this.client.ttl(this.memoryKey(memoryId)); + out.push({ + id: memoryId, + user: doc.value.user ?? '', + namespace: doc.value.namespace ?? '', + kind: doc.value.kind ?? '', + source_thread: doc.value.source_thread ?? '', + text: doc.value.text ?? '', + created_ts: parseFloat(doc.value.created_ts ?? '0') || 0, + hit_count: parseInt(doc.value.hit_count ?? '0', 10) || 0, + ttl_seconds: ttl > 0 ? ttl : null, + }); + } + return out; + } + + async deleteMemory(memoryId) { + return (await this.client.del(this.memoryKey(memoryId))) > 0; + } + + async clear() { + // Returns the number of memories that were removed. In production + // the equivalent is `FLUSHDB` on a dedicated memory database, or + // letting TTLs and eviction expire entries naturally. + const before = (await this.indexInfo()).num_docs; + await this.dropIndex({ deleteDocuments: true }); + await this.createIndex(); + return before; + } +} diff --git a/content/develop/use-cases/agent-memory/nodejs/package.json b/content/develop/use-cases/agent-memory/nodejs/package.json new file mode 100644 index 0000000000..a7860a2b0c --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/package.json @@ -0,0 +1,18 @@ +{ + "name": "redis-agent-memory-demo-nodejs", + "version": "1.0.0", + "private": true, + "type": "module", + "description": "Redis agent memory demo with node-redis and @xenova/transformers.", + "main": "demoServer.js", + "scripts": { + "start": "node demoServer.js" + }, + "dependencies": { + "@xenova/transformers": "^2.17.2", + "redis": "^5.12.1" + }, + "engines": { + "node": ">=18" + } +} diff --git a/content/develop/use-cases/agent-memory/nodejs/seedMemory.js b/content/develop/use-cases/agent-memory/nodejs/seedMemory.js new file mode 100644 index 0000000000..2f305805f8 --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/seedMemory.js @@ -0,0 +1,87 @@ +// Pre-seed the long-term memory store with sample memories. +// +// In a real deployment the memory store fills up organically as the +// agent reasons over user turns: each turn produces zero or more +// memories (preferences, facts, episodic summaries) that flow into +// the store with deduplication. To make the demo immediately useful +// — so the first recall query lands on relevant results instead of +// an empty list — we seed a small set of canonical memories for a +// default user at startup. +// +// The seed list mixes `semantic` memories (long-lived preferences +// and facts) with `episodic` memories (snapshots of past sessions), +// matching what the Python demo seeds so the two implementations +// behave identically. + +export const SEED_MEMORIES = [ + { + text: 'The user prefers concise answers without filler phrases.', + kind: 'semantic', + }, + { + text: 'The user is a Python developer working on a logistics platform.', + kind: 'semantic', + }, + { + text: 'The user lives in Berlin and works in the Europe/Berlin time zone.', + kind: 'semantic', + }, + { + text: + 'The user dislikes dark mode and prefers a high-contrast light ' + + 'theme in editors and dashboards.', + kind: 'semantic', + }, + { + text: + 'The user is allergic to peanuts; any restaurant suggestion must ' + + 'avoid dishes that commonly contain them.', + kind: 'semantic', + }, + { + text: + 'Last Tuesday the user asked the agent to draft a postmortem for ' + + 'the order-routing outage. The agent produced a five-section ' + + 'draft and the user approved sections 1, 2, and 4 with minor ' + + 'edits.', + kind: 'episodic', + }, + { + text: + 'In a previous session the user asked for help debugging a flaky ' + + 'test in the inventory service. The fix turned out to be a race ' + + 'condition in the warehouse webhook handler.', + kind: 'episodic', + }, + { + text: + 'Two weeks ago the user mentioned they were planning to migrate ' + + 'the analytics warehouse from Snowflake to BigQuery in Q3.', + kind: 'episodic', + }, +]; + +// Embed and write the seed memories. Returns the count actually written +// (entries that dedup against existing memories don't count). +export async function seed(memory, embedder, { + user = 'default', + namespace = 'default', + sourceThread = 'seed', +} = {}) { + const texts = SEED_MEMORIES.map(m => m.text); + const vectors = await embedder.encodeMany(texts); + let written = 0; + for (let i = 0; i < SEED_MEMORIES.length; i++) { + const entry = SEED_MEMORIES[i]; + const result = await memory.remember({ + text: entry.text, + embedding: vectors[i], + user, + namespace, + kind: entry.kind, + sourceThread, + }); + if (!result.deduped) written += 1; + } + return written; +} diff --git a/content/develop/use-cases/agent-memory/nodejs/sessionStore.js b/content/develop/use-cases/agent-memory/nodejs/sessionStore.js new file mode 100644 index 0000000000..fea5cf9a1f --- /dev/null +++ b/content/develop/use-cases/agent-memory/nodejs/sessionStore.js @@ -0,0 +1,210 @@ +// Working-memory store for an agent session, backed by a Redis Hash. +// +// Each session is one Hash document at `agent:session:{threadId}`. +// The hash holds the running scratchpad, the current goal, a rolling +// window of recent turns (serialised as a JSON list to fit in one +// field), and a few audit fields. One `HGETALL` returns the whole +// session in a single round trip on every step of the agent loop. +// +// Every write refreshes the key's TTL with `EXPIRE`, so idle sessions +// fall off without a separate cleanup job and active sessions stay +// alive as long as the agent keeps touching them. A separate +// `LongTermMemory` (see `longTermMemory.js`) is what survives beyond +// a session's TTL. +// +// The turn window is bounded to `maxTurns` in application code; the +// hash itself doesn't grow, so the working set per thread stays +// constant regardless of how long the agent has been running. + +import { randomUUID } from 'node:crypto'; + +// How many recent turns to keep inline on the session hash. Older +// turns flow through the event log (see `eventLog.js`) and the +// long-term memory store (see `longTermMemory.js`). +export const MAX_TURNS = 20; + +export class AgentSession { + constructor({ + client, + keyPrefix = 'agent:session:', + defaultTtlSeconds = 3600, + maxTurns = MAX_TURNS, + }) { + this.client = client; + this.keyPrefix = keyPrefix; + this.defaultTtlSeconds = defaultTtlSeconds; + this.maxTurns = maxTurns; + } + + sessionKey(threadId) { + return `${this.keyPrefix}${threadId}`; + } + + newThreadId() { + return randomUUID().replace(/-/g, '').slice(0, 12); + } + + // Create a fresh working memory for a thread. Overwrites any + // existing session at the same key. The agent normally calls this + // once per thread at the first turn and relies on `load` / + // `appendTurn` for subsequent steps. + async start(threadId, { + user = 'default', + agent = 'default', + goal = '', + ttlSeconds, + } = {}) { + const ttl = ttlSeconds !== undefined ? ttlSeconds : this.defaultTtlSeconds; + const now = Date.now() / 1000; + const state = { + thread_id: threadId, + user, + agent, + goal, + scratchpad: '', + turn_count: 0, + created_ts: now, + last_active_ts: now, + recent_turns: [], + ttl_seconds: ttl, + }; + await this._write(state, ttl); + return state; + } + + // Return the session state, or `null` if it has expired. + async load(threadId) { + const key = this.sessionKey(threadId); + const raw = await this.client.hGetAll(key); + if (!raw || Object.keys(raw).length === 0) return null; + const ttl = await this.client.ttl(key); + let turns = []; + try { + turns = JSON.parse(raw.recent_turns || '[]'); + } catch { + turns = []; + } + return { + thread_id: threadId, + user: raw.user || 'default', + agent: raw.agent || 'default', + goal: raw.goal || '', + scratchpad: raw.scratchpad || '', + turn_count: parseInt(raw.turn_count || '0', 10) || 0, + created_ts: parseFloat(raw.created_ts || '0') || 0, + last_active_ts: parseFloat(raw.last_active_ts || '0') || 0, + recent_turns: turns, + ttl_seconds: ttl > 0 ? ttl : 0, + }; + } + + // Append a turn, bound the rolling window, refresh the TTL. + // + // `user` and `agent` are only consulted when the session does not + // yet exist — they seed the auto-created session so the + // working-memory hash matches the user the caller is operating + // against. On an existing session they're ignored; the original + // `start` values stand. + // + // Read-modify-write here is last-writer-wins on the turn list if + // two concurrent turns reach the same thread; the demo never + // triggers that race in practice (one browser, one turn at a + // time) but a multi-worker agent that shares a thread id would + // wrap this in `WATCH` / `MULTI` / `EXEC` or a Lua script that + // does the append atomically server-side. + async appendTurn(threadId, { + role, content, user, agent, ttlSeconds, + } = {}) { + let state = await this.load(threadId); + if (!state) { + state = await this.start(threadId, { + user: user ?? 'default', + agent: agent ?? 'default', + ttlSeconds, + }); + } + state.recent_turns.push({ role, content, ts: Date.now() / 1000 }); + if (state.recent_turns.length > this.maxTurns) { + state.recent_turns = state.recent_turns.slice(-this.maxTurns); + } + state.turn_count += 1; + state.last_active_ts = Date.now() / 1000; + const ttl = ttlSeconds !== undefined ? ttlSeconds : this.defaultTtlSeconds; + state.ttl_seconds = ttl; + await this._write(state, ttl); + return state; + } + + async setScratchpad(threadId, text, ttlSeconds) { + const state = await this.load(threadId); + if (!state) return null; + state.scratchpad = text; + state.last_active_ts = Date.now() / 1000; + const ttl = ttlSeconds !== undefined ? ttlSeconds : this.defaultTtlSeconds; + state.ttl_seconds = ttl; + await this._write(state, ttl); + return state; + } + + // Update the goal field without touching turns or the scratchpad. + // Creates the session if it doesn't exist yet — setting a goal on + // a fresh thread is a sensible first step in the agent loop, so + // this method covers both the "rename the goal mid-session" and + // the "start a thread with this goal" cases. + async setGoal(threadId, text, { user, agent, ttlSeconds } = {}) { + const state = await this.load(threadId); + if (!state) { + return this.start(threadId, { + user: user ?? 'default', + agent: agent ?? 'default', + goal: text, + ttlSeconds, + }); + } + state.goal = text; + state.last_active_ts = Date.now() / 1000; + const ttl = ttlSeconds !== undefined ? ttlSeconds : this.defaultTtlSeconds; + state.ttl_seconds = ttl; + await this._write(state, ttl); + return state; + } + + async delete(threadId) { + return (await this.client.del(this.sessionKey(threadId))) > 0; + } + + // Return active thread ids (for the demo's thread switcher). + async listThreads(limit = 100) { + const out = []; + for await (const key of this.client.scanIterator({ + MATCH: `${this.keyPrefix}*`, COUNT: 200, + })) { + const threadId = String(key).slice(this.keyPrefix.length); + out.push(threadId); + if (out.length >= limit) break; + } + return out; + } + + async _write(state, ttl) { + const key = this.sessionKey(state.thread_id); + const mapping = { + thread_id: state.thread_id, + user: state.user, + agent: state.agent, + goal: state.goal, + scratchpad: state.scratchpad, + turn_count: String(state.turn_count), + created_ts: String(state.created_ts), + last_active_ts: String(state.last_active_ts), + recent_turns: JSON.stringify(state.recent_turns), + }; + // MULTI / EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would otherwise + // leave the session without a TTL. + await this.client.multi() + .hSet(key, mapping) + .expire(key, ttl) + .exec(); + } +} diff --git a/content/develop/use-cases/agent-memory/php/.gitignore b/content/develop/use-cases/agent-memory/php/.gitignore new file mode 100644 index 0000000000..c2c0649b99 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/.gitignore @@ -0,0 +1,5 @@ +vendor/ +.transformers-cache/ +models/ +*.log +.DS_Store diff --git a/content/develop/use-cases/agent-memory/php/_index.md b/content/develop/use-cases/agent-memory/php/_index.md new file mode 100644 index 0000000000..085384dff2 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/_index.md @@ -0,0 +1,379 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in PHP with Predis, TransformersPHP, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: Predis example (PHP) +title: Redis agent memory with Predis +weight: 8 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in PHP with [Predis]({{< relref "/develop/clients/php" >}}) and the [TransformersPHP](https://transformers.codewithkyrian.com/) library, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built on PHP's `stream_socket_server` so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [TransformersPHP](https://transformers.codewithkyrian.com/) running the ONNX-exported [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) model through ONNX Runtime via FFI, which is the same encoder the [Python example]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) uses. Embeddings produced by the two implementations are numerically very close — paraphrase distances drift by less than 0.02 — so a memory written by one demo can be recalled by the other against the same Redis instance, and the distance bands the Python walkthrough quotes carry over to this one without recalibration. One quirk worth flagging up front: TransformersPHP 0.6 accepts a `normalize: true` keyword on the `feature-extraction` / `embeddings` pipeline but silently returns un-normalized vectors anyway, so the `Embedder` wrapper L2-normalizes in PHP code before handing the vector to recall or dedup — see [Embedder.php](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/php/src/Embedder.php) for the workaround. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. ([`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) itself is one round trip; the helper also issues one [`TTL`]({{< relref "/commands/ttl" >}}) call per returned row to populate `ttl_seconds` for the admin panel.) +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `$embedder->encodeOne($text)` to turn the incoming turn into a 384-element `array` of floats, L2-normalized in PHP code (see the TransformersPHP normalization quirk noted above). +2. `$session->appendTurn($threadId, role: ..., content: ...)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI/EXEC`]({{< relref "/commands/multi" >}}). The session TTL refreshes on every write so an active thread stays alive. +3. `$memory->recall(queryEmbedding: $vec, user: ..., namespace: ..., k: 5)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `$memory->remember(text: ..., embedding: $vec, user: ..., namespace: ..., kind: ...)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `$eventLog->record($threadId, $action, $detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/php/src/AgentSession.php)): + +```php + 'tcp', 'host' => 'localhost', 'port' => 6379]); +$session = new AgentSession( + client: $client, + keyPrefix: 'agent:session:', + defaultTtlSeconds: 3600, // one hour + maxTurns: 20, // rolling window per thread +); + +$threadId = $session->newThreadId(); +$session->start( + $threadId, + user: 'alice', + agent: 'demo-agent', + goal: "Plan next week's meetings.", +); +$session->appendTurn( + $threadId, + role: 'user', + content: 'Schedule a budget review with finance.', +); +$state = $session->load($threadId); +echo $state['turn_count'] . ' ' + . count($state['recent_turns']) . ' ' + . $state['ttl_seconds'] . "\n"; +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `appendTurn`, `setScratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/php/src/LongTermMemory.php)): + +```php +createIndex(); // idempotent + +// Write a memory. The same KNN that powers recall also runs here at +// a tighter threshold so paraphrases of the same fact collapse. +$vec = $embedder->encodeOne('The user prefers light mode in editors.'); +$result = $memory->remember( + text: 'The user prefers light mode in editors.', + embedding: $vec, + user: 'alice', + namespace: 'default', + kind: 'semantic', + sourceThread: '9f3d2a4b8c61', +); +printf("deduped=%s id=%s existing_distance=%s\n", + $result['deduped'] ? 'true' : 'false', + $result['id'], + var_export($result['existing_distance'], true)); + +// Recall against a later question. +$q = $embedder->encodeOne('Which theme does this user like?'); +$hits = $memory->recall( + queryEmbedding: $q, + user: 'alice', + namespace: 'default', + k: 5, +); +foreach ($hits as $h) { + printf("%.3f [%s] %s\n", $h['distance'], $h['kind'], $h['text']); +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (the demo packs them with PHP's [`pack('g*', ...)`](https://www.php.net/manual/en/function.pack.php), which emits little-endian float32), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `AS` aliases so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write with `ttlSeconds: ...` on `remember`, or pass a different `ttlByKind: [...]` to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/php/src/AgentEventLog.php)): + +```php +record( + $threadId, 'turn_appended:user', + 'Schedule a budget review with finance.', +); +$events->record( + $threadId, 'memory_written', + 'wrote 7c3f8a1b9e02 as semantic', +); + +foreach ($events->recent($threadId, 20) as $event) { + echo $event['action'] . ' ' . $event['detail'] . "\n"; +} +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarizers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession::appendTurn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory::remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `currentThreadId` field on a `DemoState` object that `/new_thread` and `/reset` mutate; `handleTurn` reads it without coordination, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +* **`Embedder` is shared.** TransformersPHP's `pipeline` object is constructed once and reused across every request; the underlying ONNX session is *not* documented as thread-safe. PHP's request model usually makes this a non-issue (one process, one in-flight request at a time on the blocking demo server), but if you front this code with a thread-pool runtime like Swoole or ReactPHP, fence the embedder behind a mutex or per-worker instance. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `SeedMemory.php` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/php/src/SeedMemory.php)): + +```php +createIndex(); +SeedMemory::seed($memory, $embedder, user: 'default', namespace: 'default'); +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`demo_server.php` runs a [`stream_socket_server`](https://www.php.net/manual/en/function.stream-socket-server.php) HTTP/1.1 loop on port 8090. PHP's built-in `php -S` dev server re-runs the entry script for each request and discards process state in between, which would force the ~80 MB ONNX model to reload on every turn; the single-process socket loop keeps the embedder, the Predis connection, and the demo state alive between requests. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `Embedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a single field on a `DemoState` object that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the + example directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/php + ``` + +2. Install the dependencies. TransformersPHP ships an installer plugin that downloads + the ONNX Runtime shared library for your platform; the `composer.json` in this + directory already allow-lists it. + + ```bash + composer install + ``` + + PHP 8.1 or later is required, with the [`ffi`](https://www.php.net/manual/en/book.ffi.php) + extension enabled (it ships with the official PHP builds and the Homebrew formula on macOS). + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Pre-fetch the embedding model (optional). The demo will lazy-download the + [`Xenova/all-MiniLM-L6-v2`](https://huggingface.co/Xenova/all-MiniLM-L6-v2) ONNX + weights (around 80 MB) on the first request, but the wait is more obvious when + it lands inside the first turn rather than at startup: + + ```bash + composer run download-model + ``` + +5. Start the demo server: + + ```bash + php demo_server.php + ``` + +6. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.16 (the ONNX + Runtime FFI bindings TransformersPHP uses sit a hair behind the + PyTorch arithmetic the Python demo runs, so paraphrase distances are + a couple of hundredths higher than in the Python walkthrough), with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing recalls + above the threshold (every seed lands above 0.8), and the new memory is + written as `episodic` (or `semantic`, depending on the dropdown) under a + fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + `Xenova/all-MiniLM-L6-v2` puts a faithful paraphrase in the 0.15 – 0.50 + cosine-distance range, a loose paraphrase or related topic in the 0.50 – 0.80 + range, and unrelated queries above 0.8 — which is what motivates the 0.55 + default recall threshold and the 0.20 default dedup threshold. A stricter + embedding model (or a domain-tuned one) would let you tighten both; a + noisier one would push them up. The right thresholds are always a function + of the model, the corpus, and how conservative the agent needs to be about + accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags: + +* `--host` — HTTP bind host (default `127.0.0.1`). +* `--port` — HTTP bind port (default `8090`). +* `--redis-host`, `--redis-port` — point the demo at a non-local Redis. +* `--mem-index-name`, `--mem-key-prefix`, `--session-key-prefix`, `--event-key-prefix` — change the default key namespacing. +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/php/composer.json b/content/develop/use-cases/agent-memory/php/composer.json new file mode 100644 index 0000000000..8b2d256571 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/composer.json @@ -0,0 +1,25 @@ +{ + "name": "redis/agent-memory-demo-php", + "description": "Redis agent-memory demo (PHP) — working memory in a Hash, long-term memory as JSON+Search, event log in a Stream.", + "license": "MIT", + "type": "project", + "require": { + "php": "^8.1", + "ext-ffi": "*", + "predis/predis": "^3.5", + "codewithkyrian/transformers": "^0.6.2" + }, + "config": { + "allow-plugins": { + "codewithkyrian/platform-package-installer": true + } + }, + "autoload": { + "psr-4": { + "Redis\\AgentMemory\\": "src/" + } + }, + "scripts": { + "download-model": "@php vendor/bin/transformers download-model Xenova/all-MiniLM-L6-v2" + } +} diff --git a/content/develop/use-cases/agent-memory/php/demo_server.php b/content/develop/use-cases/agent-memory/php/demo_server.php new file mode 100644 index 0000000000..dd7488d971 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/demo_server.php @@ -0,0 +1,562 @@ + 'tcp', 'host' => $redisHost, 'port' => $redisPort], + // No automatic key prefix — every helper passes its own prefix. +); +try { + $client->ping(); +} catch (Throwable $exc) { + fwrite(STDERR, "Error: cannot reach Redis at $redisHost:$redisPort\n"); + fwrite(STDERR, " (" . $exc->getMessage() . ")\n"); + exit(1); +} + +$sessionStore = new AgentSession( + client: $client, + keyPrefix: $sessionKeyPrefix, + defaultTtlSeconds: $sessionTtlSeconds, +); +$memory = new LongTermMemory( + client: $client, + indexName: $memIndexName, + keyPrefix: $memKeyPrefix, + dedupThreshold: $dedupThreshold, + recallThreshold: $recallThreshold, +); +$memory->createIndex(); +$eventLog = new AgentEventLog( + client: $client, + keyPrefix: $eventKeyPrefix, +); + +fwrite(STDOUT, "Loading embedding model (first run downloads ~80 MB)...\n"); +$embedder = new Embedder(); + +// One mutable container for the current thread id. The demo never +// races itself in practice — single browser, blocking server — so +// this can be a class instead of a process-wide lock. +final class DemoState +{ + public string $currentThreadId; + public function __construct(string $threadId) { $this->currentThreadId = $threadId; } +} +$demo = new DemoState($sessionStore->newThreadId()); + +// Load and template the shared HTML once. The four `__*__` tokens +// are replaced with the configured key prefixes and index name so +// the lede paragraph shows the actual values in use. +$htmlPage = strtr( + file_get_contents(__DIR__ . '/index.html'), + [ + '__SESSION_PREFIX__' => $sessionKeyPrefix, + '__MEM_PREFIX__' => $memKeyPrefix, + '__MEM_INDEX__' => $memIndexName, + '__EVENT_PREFIX__' => $eventKeyPrefix, + ], +); + +$stackLabel = 'predis + TransformersPHP + pure-PHP stream_socket_server'; + +// ---- Demo helpers ---------------------------------------------------- + +$reseed = function (string $user, string $namespace) use ( + $memory, $sessionStore, $eventLog, $embedder, $demo, +): int { + $memory->clear(); + $sessionStore->delete($demo->currentThreadId); + $eventLog->clear($demo->currentThreadId); + $written = SeedMemory::seed( + $memory, $embedder, $user, $namespace, 'seed', + ); + $demo->currentThreadId = $sessionStore->newThreadId(); + return $written; +}; + +$newThread = function (string $user, string $namespace) use ( + $sessionStore, $eventLog, $demo, +): string { + $eventLog->clear($demo->currentThreadId); + $demo->currentThreadId = $sessionStore->newThreadId(); + $sessionStore->start( + $demo->currentThreadId, user: $user, agent: 'demo-agent', goal: '', + ); + $eventLog->record( + $demo->currentThreadId, 'thread_started', + "user=$user namespace=$namespace", + ); + return $demo->currentThreadId; +}; + +$handleTurn = function (array $params) use ( + $sessionStore, $memory, $eventLog, $embedder, $demo, $recallThreshold, +): array { + $text = trim((string) ($params['text'] ?? '')); + if ($text === '') { + return ['__http_status' => 400, 'error' => 'text is required']; + } + $user = ($params['user'] ?? '') !== '' ? $params['user'] : 'default'; + $namespace = ($params['namespace'] ?? '') !== '' + ? $params['namespace'] : 'default'; + $kind = ($params['kind'] ?? '') !== '' ? $params['kind'] : 'episodic'; + $role = ($params['role'] ?? '') !== '' ? $params['role'] : 'user'; + $action = ($params['action'] ?? '') !== '' ? $params['action'] : 'turn'; + + // Missing/blank threshold falls back to the configured + // `--recall-threshold` rather than a hard-coded constant, so the + // server-wide flag actually drives the default. + $thresholdRaw = $params['threshold'] ?? ''; + $threshold = $thresholdRaw === '' ? $recallThreshold : (float) $thresholdRaw; + // `(float)` parses "nan"/"inf" as NaN/INF; either would silently + // turn recall into "every memory" or "nothing". Clamp to the + // meaningful cosine-distance range so a malformed POST can't + // override the threshold semantics. + if (!is_finite($threshold)) { + $threshold = $recallThreshold; + } + $threshold = max(0.0, min(2.0, $threshold)); + + $threadId = $demo->currentThreadId; + + $t0 = microtime(true); + $vec = $embedder->encodeOne($text); + $embedMs = (microtime(true) - $t0) * 1000; + + // `setGoal` only touches the goal field so existing turns aren't + // wiped; `appendTurn` carries the request `user` through to the + // auto-create path so a first turn for a new thread doesn't land + // under the default user. + if ($action === 'goal') { + $sessionStore->setGoal( + $threadId, $text, user: $user, agent: 'demo-agent', + ); + $sessionAction = 'goal_set'; + } else { + $sessionStore->appendTurn( + $threadId, role: $role, content: $text, + user: $user, agent: 'demo-agent', + ); + $sessionAction = "turn_appended:$role"; + } + + $t1 = microtime(true); + $recalled = $memory->recall( + queryEmbedding: $vec, + user: $user, + namespace: $namespace, + k: 5, + distanceThreshold: $threshold, + ); + $recallMs = (microtime(true) - $t1) * 1000; + + $writeSkipped = ($kind === 'skip' || $action === 'goal'); + $writeResult = null; + $writeMs = 0.0; + if (!$writeSkipped) { + $t2 = microtime(true); + $writeResult = $memory->remember( + text: $text, + embedding: $vec, + user: $user, + namespace: $namespace, + kind: $kind, + sourceThread: $threadId, + ); + $writeMs = (microtime(true) - $t2) * 1000; + } + + if ($writeResult !== null) { + $eventDetail = $writeResult['deduped'] + ? "deduped onto {$writeResult['id']}" + : "wrote {$writeResult['id']} as $kind"; + $eventLog->record($threadId, $sessionAction, $eventDetail); + } else { + $eventLog->record($threadId, $sessionAction, ''); + } + + return [ + 'thread_id' => $threadId, + 'write_skipped' => $writeSkipped, + 'memory_id' => $writeResult['id'] ?? null, + 'deduped' => $writeResult['deduped'] ?? false, + 'existing_distance' => $writeResult['existing_distance'] ?? null, + 'kind' => $writeSkipped ? null : $kind, + 'recalled' => $recalled, + 'embed_ms' => $embedMs, + 'recall_ms' => $recallMs, + 'write_ms' => $writeMs, + ]; +}; + +$buildState = function (string $user, string $namespace) use ( + $memory, $sessionStore, $eventLog, $embedder, $demo, $stackLabel, + $memIndexName, $sessionTtlSeconds, $dedupThreshold, $recallThreshold, +): array { + $info = $memory->indexInfo(); + $threadId = $demo->currentThreadId; + $session = $sessionStore->load($threadId); + $memories = $memory->listMemories( + user: $user, namespace: $namespace, limit: 200, + ); + $events = $eventLog->recent($threadId, 20); + return [ + 'index' => [ + 'num_docs' => $info['num_docs'], + 'indexing_failures' => $info['indexing_failures'], + 'index_name' => $memIndexName, + 'model' => $embedder->modelName, + 'session_ttl_seconds' => $sessionTtlSeconds, + 'dedup_threshold' => $dedupThreshold, + 'default_recall_threshold' => $recallThreshold, + 'stack_label' => $stackLabel, + ], + 'thread_id' => $threadId, + 'session' => $session, + 'memories' => $memories, + 'events' => $events, + // `recalled` is populated by /turn; on plain /state reads + // the UI keeps showing the last turn's result, which is the + // useful behavior for an "agent" panel. + 'recalled' => [], + ]; +}; + +// ---- HTTP loop ------------------------------------------------------- + +// Cap POST bodies so a runaway client (or a `curl --data-binary +// @big-file` by mistake) can't accumulate unbounded memory before the +// handler runs. The demo's largest legitimate body is a few hundred +// bytes of form-encoded query fields; 1 MiB is a generous ceiling +// matching the Node, .NET, Rust, Go, and Java demos. +const MAX_BODY_BYTES = 1 * 1024 * 1024; + +function sendResponse($conn, int $status, string $contentType, string $body): void +{ + $reasons = [ + 200 => 'OK', + 400 => 'Bad Request', + 404 => 'Not Found', + 405 => 'Method Not Allowed', + 413 => 'Payload Too Large', + 500 => 'Internal Server Error', + ]; + $reason = $reasons[$status] ?? 'Internal Server Error'; + $headers = "HTTP/1.1 $status $reason\r\n" + . "Content-Type: $contentType\r\n" + . 'Content-Length: ' . strlen($body) . "\r\n" + . "Connection: close\r\n\r\n"; + fwrite($conn, $headers . $body); +} + +function sendJson($conn, mixed $payload, int $status = 200): void +{ + sendResponse( + $conn, $status, 'application/json', + json_encode($payload, JSON_UNESCAPED_SLASHES | JSON_PRESERVE_ZERO_FRACTION), + ); +} + +function sendHtml($conn, string $html, int $status = 200): void +{ + sendResponse($conn, $status, 'text/html; charset=utf-8', $html); +} + +function readRequest($conn): ?array +{ + // Parse request line + headers until \r\n\r\n. + $head = ''; + while (!feof($conn)) { + $chunk = fgets($conn); + if ($chunk === false) { + return null; + } + $head .= $chunk; + if (str_ends_with($head, "\r\n\r\n") || $chunk === "\r\n") { + break; + } + if (strlen($head) > 64 * 1024) { + return null; // header section too big + } + } + $lines = preg_split('/\r\n/', rtrim($head, "\r\n")); + if (empty($lines)) { + return null; + } + $requestLine = array_shift($lines); + if (!preg_match('#^(\S+)\s+(\S+)\s+HTTP/\S+#', $requestLine, $m)) { + return null; + } + [, $method, $target] = $m; + $headers = []; + foreach ($lines as $line) { + $pos = strpos($line, ':'); + if ($pos === false) { + continue; + } + $name = strtolower(trim(substr($line, 0, $pos))); + $value = trim(substr($line, $pos + 1)); + $headers[$name] = $value; + } + $contentLength = (int) ($headers['content-length'] ?? 0); + $body = ''; + if ($contentLength > 0) { + if ($contentLength > MAX_BODY_BYTES) { + // Drain a bounded chunk so the connection state stays + // sane, then signal the cap to the caller. + $drained = fread($conn, MAX_BODY_BYTES); + return [ + 'method' => $method, 'target' => $target, + 'headers' => $headers, 'body' => $drained, + 'body_too_large' => true, + ]; + } + $remaining = $contentLength; + while ($remaining > 0 && !feof($conn)) { + $chunk = fread($conn, $remaining); + if ($chunk === false || $chunk === '') { + break; + } + $body .= $chunk; + $remaining -= strlen($chunk); + } + } + return [ + 'method' => $method, 'target' => $target, + 'headers' => $headers, 'body' => $body, + 'body_too_large' => false, + ]; +} + +function parseForm(string $body): array +{ + $params = []; + parse_str($body, $params); + return is_array($params) ? $params : []; +} + +$endpoint = "tcp://$host:$port"; +$serverContext = stream_context_create([ + 'socket' => ['so_reuseport' => true, 'so_reuseaddr' => true], +]); +$server = @stream_socket_server( + $endpoint, $errno, $errstr, + STREAM_SERVER_BIND | STREAM_SERVER_LISTEN, + $serverContext, +); +if ($server === false) { + fwrite(STDERR, "Error: cannot bind $endpoint: $errstr ($errno)\n"); + exit(1); +} + +if ($resetOnStart) { + fwrite(STDOUT, + "Dropping any existing memories under '$memKeyPrefix*' and re-seeding " + . "from the sample memory list (pass --no-reset to keep).\n" + ); + $seeded = $reseed('default', 'default'); + fwrite(STDOUT, "Seeded $seeded memories.\n"); +} + +fwrite(STDOUT, "Redis agent memory demo listening on http://$host:$port\n"); +fwrite(STDOUT, + "Using Redis at $redisHost:$redisPort with memory index '$memIndexName'\n", +); + +// SIGINT / SIGTERM clean up the listener so a re-run can bind the +// same port immediately. Requires ext-pcntl; without it Ctrl+C still +// works, the kernel just reclaims the socket on process exit. +if (function_exists('pcntl_async_signals')) { + pcntl_async_signals(true); + $shutdown = function () use ($server) { + fwrite(STDOUT, "\nShutting down...\n"); + @fclose($server); + exit(0); + }; + pcntl_signal(SIGINT, $shutdown); + pcntl_signal(SIGTERM, $shutdown); +} + +while (true) { + $conn = @stream_socket_accept($server, -1); + if ($conn === false) { + continue; + } + try { + $req = readRequest($conn); + if ($req === null) { + sendJson($conn, ['error' => 'malformed request'], 400); + continue; + } + if (!empty($req['body_too_large'])) { + sendJson($conn, [ + 'error' => 'request body exceeds ' . MAX_BODY_BYTES . ' bytes', + ], 413); + continue; + } + + $url = parse_url($req['target']); + $path = $url['path'] ?? '/'; + $query = []; + if (isset($url['query'])) { + parse_str($url['query'], $query); + } + + if ($req['method'] === 'GET' && in_array($path, ['/', '/index.html'], true)) { + sendHtml($conn, $htmlPage); + } elseif ($req['method'] === 'GET' && $path === '/state') { + $user = ($query['user'] ?? '') !== '' ? $query['user'] : 'default'; + $namespace = ($query['namespace'] ?? '') !== '' + ? $query['namespace'] : 'default'; + sendJson($conn, $buildState($user, $namespace)); + } elseif ($req['method'] === 'POST') { + $params = parseForm($req['body']); + switch ($path) { + case '/turn': + $payload = $handleTurn($params); + $status = $payload['__http_status'] ?? 200; + unset($payload['__http_status']); + sendJson($conn, $payload, $status); + break; + case '/new_thread': + $threadId = $newThread( + ($params['user'] ?? '') !== '' ? $params['user'] : 'default', + ($params['namespace'] ?? '') !== '' ? $params['namespace'] : 'default', + ); + sendJson($conn, ['thread_id' => $threadId]); + break; + case '/reset': + $seeded = $reseed( + ($params['user'] ?? '') !== '' ? $params['user'] : 'default', + ($params['namespace'] ?? '') !== '' ? $params['namespace'] : 'default', + ); + sendJson($conn, ['seeded' => $seeded]); + break; + case '/drop_memory': + $memoryId = trim((string) ($params['memory_id'] ?? '')); + if ($memoryId === '') { + sendJson($conn, ['error' => 'memory_id is required'], 400); + break; + } + $deleted = $memory->deleteMemory($memoryId); + sendJson($conn, ['deleted' => $deleted, 'memory_id' => $memoryId]); + break; + default: + sendJson($conn, ['error' => 'not found'], 404); + } + } else { + sendJson($conn, ['error' => 'not found'], 404); + } + } catch (Throwable $exc) { + // Without this catch, an exception escapes the loop and the + // client's `await res.json()` blows up on an opaque parse + // error instead of seeing what went wrong. + fwrite(STDERR, "[demo] handler error: " . $exc->getMessage() . "\n"); + try { + sendJson($conn, [ + 'error' => $exc->getMessage(), + 'type' => $exc::class, + ], 500); + } catch (Throwable) { + // Headers may already be partially flushed; nothing left to do. + } + } finally { + @fclose($conn); + } +} diff --git a/content/develop/use-cases/agent-memory/php/index.html b/content/develop/use-cases/agent-memory/php/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/php/src/AgentEventLog.php b/content/develop/use-cases/agent-memory/php/src/AgentEventLog.php new file mode 100644 index 0000000000..81cf6b6f1b --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/src/AgentEventLog.php @@ -0,0 +1,117 @@ +client = $client; + $this->keyPrefix = $keyPrefix; + $this->maxLen = $maxLen; + } + + public function streamKey(string $threadId): string + { + return $this->keyPrefix . $threadId; + } + + /** + * Append one event and return its stream id. + * + * `MAXLEN ~ N` keeps the stream bounded with near-zero overhead; + * an exact bound (`MAXLEN N` without the tilde) forces a scan + * and is rarely worth the cost. + */ + public function record( + string $threadId, + string $action, + string $detail = '', + ): string { + return (string) $this->client->xadd( + $this->streamKey($threadId), + [ + 'action' => $action, + 'detail' => $detail, + 'ts' => (string) microtime(true), + ], + '*', + ['trim' => ['MAXLEN', '~', $this->maxLen]], + ); + } + + /** + * Return the most recent events, newest first. + * + * @return list> + */ + public function recent(string $threadId, int $count = 20): array + { + $rows = $this->client->xrevrange( + $this->streamKey($threadId), '+', '-', $count, + ); + $out = []; + foreach ($rows as $entryId => $fields) { + // Predis returns an associative array of [streamId => + // [field => value, ...]] for XREVRANGE. Older releases + // emit the wire-format pair instead; handle either. + if (is_array($fields) && is_string($entryId)) { + $data = $fields; + $id = $entryId; + } else { + [$id, $data] = $fields; + } + $out[] = [ + 'event_id' => $id, + 'thread_id' => $threadId, + 'action' => $data['action'] ?? '', + 'detail' => $data['detail'] ?? '', + 'ts' => (float) ($data['ts'] ?? 0), + ]; + } + return $out; + } + + public function length(string $threadId): int + { + return (int) $this->client->xlen($this->streamKey($threadId)); + } + + public function clear(string $threadId): bool + { + return ((int) $this->client->del($this->streamKey($threadId))) > 0; + } +} diff --git a/content/develop/use-cases/agent-memory/php/src/AgentSession.php b/content/develop/use-cases/agent-memory/php/src/AgentSession.php new file mode 100644 index 0000000000..a9c5260314 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/src/AgentSession.php @@ -0,0 +1,274 @@ +client = $client; + $this->keyPrefix = $keyPrefix; + $this->defaultTtlSeconds = $defaultTtlSeconds; + $this->maxTurns = $maxTurns; + } + + public function sessionKey(string $threadId): string + { + return $this->keyPrefix . $threadId; + } + + public function newThreadId(): string + { + return substr(bin2hex(random_bytes(6)), 0, 12); + } + + /** + * Create a fresh working memory for a thread. Overwrites any + * existing session at the same key. The agent normally calls + * this once per thread at the first turn and relies on `load` / + * `appendTurn` for subsequent steps. + * + * @return array + */ + public function start( + string $threadId, + string $user = 'default', + string $agent = 'default', + string $goal = '', + ?int $ttlSeconds = null, + ): array { + $ttl = $ttlSeconds ?? $this->defaultTtlSeconds; + $now = microtime(true); + $state = [ + 'thread_id' => $threadId, + 'user' => $user, + 'agent' => $agent, + 'goal' => $goal, + 'scratchpad' => '', + 'turn_count' => 0, + 'created_ts' => $now, + 'last_active_ts' => $now, + 'recent_turns' => [], + 'ttl_seconds' => $ttl, + ]; + $this->write($state, $ttl); + return $state; + } + + /** + * Return the session state, or `null` if it has expired. + * + * @return array|null + */ + public function load(string $threadId): ?array + { + $key = $this->sessionKey($threadId); + $raw = $this->client->hgetall($key); + if (!$raw) { + return null; + } + $ttl = (int) $this->client->ttl($key); + $turnsBlob = $raw['recent_turns'] ?? '[]'; + $turns = json_decode($turnsBlob, true); + if (!is_array($turns)) { + $turns = []; + } + return [ + 'thread_id' => $threadId, + 'user' => $raw['user'] ?? 'default', + 'agent' => $raw['agent'] ?? 'default', + 'goal' => $raw['goal'] ?? '', + 'scratchpad' => $raw['scratchpad'] ?? '', + 'turn_count' => (int) ($raw['turn_count'] ?? 0), + 'created_ts' => (float) ($raw['created_ts'] ?? 0), + 'last_active_ts' => (float) ($raw['last_active_ts'] ?? 0), + 'recent_turns' => $turns, + 'ttl_seconds' => $ttl > 0 ? $ttl : 0, + ]; + } + + /** + * Append a turn, bound the rolling window, refresh the TTL. + * + * `user` and `agent` are only consulted when the session does + * not yet exist — they seed the auto-created session so the + * working-memory hash matches the user the caller is operating + * against. On an existing session they're ignored; the original + * `start` values stand. + * + * Read-modify-write here is last-writer-wins on the turn list if + * two concurrent turns reach the same thread; the demo never + * triggers that race in practice (one browser, one turn at a + * time) but a multi-worker agent that shares a thread id would + * wrap this in `WATCH` / `MULTI` / `EXEC` or a Lua script that + * does the append atomically server-side. + * + * @return array + */ + public function appendTurn( + string $threadId, + string $role, + string $content, + ?string $user = null, + ?string $agent = null, + ?int $ttlSeconds = null, + ): array { + $state = $this->load($threadId); + if ($state === null) { + $state = $this->start( + $threadId, + user: $user ?? 'default', + agent: $agent ?? 'default', + ttlSeconds: $ttlSeconds, + ); + } + $state['recent_turns'][] = [ + 'role' => $role, + 'content' => $content, + 'ts' => microtime(true), + ]; + if (count($state['recent_turns']) > $this->maxTurns) { + $state['recent_turns'] = array_slice( + $state['recent_turns'], -$this->maxTurns, + ); + } + $state['turn_count'] = ((int) $state['turn_count']) + 1; + $state['last_active_ts'] = microtime(true); + $ttl = $ttlSeconds ?? $this->defaultTtlSeconds; + $state['ttl_seconds'] = $ttl; + $this->write($state, $ttl); + return $state; + } + + /** + * Update the agent's running scratchpad and refresh TTL. + * + * @return array|null + */ + public function setScratchpad( + string $threadId, + string $text, + ?int $ttlSeconds = null, + ): ?array { + $state = $this->load($threadId); + if ($state === null) { + return null; + } + $state['scratchpad'] = $text; + $state['last_active_ts'] = microtime(true); + $ttl = $ttlSeconds ?? $this->defaultTtlSeconds; + $state['ttl_seconds'] = $ttl; + $this->write($state, $ttl); + return $state; + } + + /** + * Update the goal field without touching turns or the scratchpad. + * + * Creates the session if it doesn't exist yet — setting a goal + * on a fresh thread is a sensible first step in the agent loop, + * so this method covers both the "rename the goal mid-session" + * and the "start a thread with this goal" cases. + * + * @return array + */ + public function setGoal( + string $threadId, + string $text, + ?string $user = null, + ?string $agent = null, + ?int $ttlSeconds = null, + ): array { + $state = $this->load($threadId); + if ($state === null) { + return $this->start( + $threadId, + user: $user ?? 'default', + agent: $agent ?? 'default', + goal: $text, + ttlSeconds: $ttlSeconds, + ); + } + $state['goal'] = $text; + $state['last_active_ts'] = microtime(true); + $ttl = $ttlSeconds ?? $this->defaultTtlSeconds; + $state['ttl_seconds'] = $ttl; + $this->write($state, $ttl); + return $state; + } + + public function delete(string $threadId): bool + { + return ((int) $this->client->del($this->sessionKey($threadId))) > 0; + } + + /** + * @param array $state + */ + private function write(array $state, int $ttl): void + { + $key = $this->sessionKey($state['thread_id']); + $mapping = [ + 'thread_id' => $state['thread_id'], + 'user' => $state['user'], + 'agent' => $state['agent'], + 'goal' => $state['goal'], + 'scratchpad' => $state['scratchpad'], + 'turn_count' => (string) $state['turn_count'], + 'created_ts' => (string) $state['created_ts'], + 'last_active_ts' => (string) $state['last_active_ts'], + 'recent_turns' => json_encode( + $state['recent_turns'], JSON_UNESCAPED_SLASHES, + ), + ]; + // Predis HSET takes positional field/value arguments, not + // an associative array — flatten the mapping into the order + // Redis expects on the wire. + $hsetArgs = [$key]; + foreach ($mapping as $k => $v) { + $hsetArgs[] = $k; + $hsetArgs[] = $v; + } + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would + // otherwise leave the session without a TTL. + $this->client->transaction( + function ($tx) use ($hsetArgs, $key, $ttl) { + $tx->hset(...$hsetArgs); + $tx->expire($key, $ttl); + }, + ); + } +} diff --git a/content/develop/use-cases/agent-memory/php/src/Embedder.php b/content/develop/use-cases/agent-memory/php/src/Embedder.php new file mode 100644 index 0000000000..2ebee2aa3f --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/src/Embedder.php @@ -0,0 +1,122 @@ +modelName = $modelName; + $this->extractor = pipeline('embeddings', $modelName); + // Probe the output shape once so callers can compare against + // the index's expected vector dimension before doing any + // inserts. LongTermMemory also checks length on every + // remember / recall, so a model swap that produces wrong-dim + // vectors fails at the call site with a clear error. + $probe = $this->encodeOne('dimension probe'); + $this->dim = count($probe); + } + + /** + * Encode a single string. Returns an array of `dim` floats, + * L2-normalized so the dot product equals the cosine similarity. + * + * @return list + */ + public function encodeOne(string $text): array + { + return $this->encodeMany([$text])[0]; + } + + /** + * Encode several strings in one pipeline call. Returns an array + * of L2-normalized float vectors, one per input. + * + * @param list $texts + * @return list> + */ + public function encodeMany(array $texts): array + { + // TransformersPHP's `feature-extraction` / `embeddings` + // pipeline returns shape [batch, dim] when called with + // `pooling: 'mean'`. Without the pooling argument we'd get + // [batch, n_tokens, dim] back and would have to pool + // ourselves; we always pool to keep one code path. + $raw = ($this->extractor)($texts, pooling: 'mean'); + + // Some library versions hand back a single nested array when + // the input is also a single string; normalize to [batch, dim]. + if (!empty($raw) && !is_array($raw[0])) { + $raw = [$raw]; + } + + $out = []; + foreach ($raw as $row) { + $out[] = self::_normalize($row); + } + return $out; + } + + /** + * Pack a float vector into the little-endian float32 byte string + * Redis Search expects as a vector PARAM value. + */ + public static function toBytes(array $vector): string + { + // `pack('g*', ...)` writes a packed array of little-endian + // floats — the same byte layout `float32` ONNX outputs use + // on every platform we care about (x86_64 / arm64, both LE). + return pack('g*', ...$vector); + } + + /** + * @param list $vector + * @return list + */ + private static function _normalize(array $vector): array + { + $sum = 0.0; + foreach ($vector as $v) { + $sum += $v * $v; + } + if ($sum <= 0.0) { + return $vector; + } + $inv = 1.0 / sqrt($sum); + foreach ($vector as $i => $v) { + $vector[$i] = $v * $inv; + } + return $vector; + } +} diff --git a/content/develop/use-cases/agent-memory/php/src/LongTermMemory.php b/content/develop/use-cases/agent-memory/php/src/LongTermMemory.php new file mode 100644 index 0000000000..18be1d0d45 --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/src/LongTermMemory.php @@ -0,0 +1,513 @@ +`. The +// document holds the memory text, its embedding vector, and a small +// metadata block — user, namespace, kind, source thread, timestamps — +// that lets the recall query scope results without falling back to +// application-side filtering. +// +// A single Redis Search index covers the embedding plus every metadata +// field, so one `FT.SEARCH` call performs approximate-nearest- +// neighbor over the in-scope subset and returns the top-k memories +// ranked by cosine distance. The same KNN check runs at *write* time +// to deduplicate near-identical memories before they enter the store, +// which keeps the index from filling with paraphrases of the same fact +// as the agent reasons over similar topics across sessions. +// +// Memories carry one of two kinds: +// +// * `episodic` — "what happened" snapshots from a specific thread, +// written with a medium TTL so old session detail decays naturally. +// * `semantic` — distilled facts and preferences the agent should +// carry forward indefinitely. Written with no TTL by default. +// +// The split is enforced as a TAG on the index, so the recall query +// can ask for one kind or both with a filter — no separate keyspaces. + +declare(strict_types=1); + +namespace Redis\AgentMemory; + +use Predis\Client; +use Predis\Command\Argument\Search\CreateArguments; +use Predis\Command\Argument\Search\DropArguments; +use Predis\Command\Argument\Search\SchemaFields\NumericField; +use Predis\Command\Argument\Search\SchemaFields\TagField; +use Predis\Command\Argument\Search\SchemaFields\TextField; +use Predis\Command\Argument\Search\SchemaFields\VectorField; +use Predis\Command\Argument\Search\SearchArguments; +use Predis\Response\ServerException; + +class LongTermMemory +{ + public const VECTOR_DIM_DEFAULT = 384; + + // How close (cosine distance) a candidate must be to an existing + // memory to count as a duplicate at write time. Smaller = stricter. + // 0.20 is calibrated to the `all-MiniLM-L6-v2` embedding model used + // in the demo, where a paraphrase of an existing memory lands in + // the 0.10 – 0.20 range and a distinct memory lands above 0.50. + public const DEFAULT_DEDUP_THRESHOLD = 0.20; + + // How close (cosine distance) a candidate must be to count as a + // relevant recall result. Larger than the dedup threshold so the + // agent gets a wider net at read time than at write time. + public const DEFAULT_RECALL_THRESHOLD = 0.55; + + // TTL tiers, in seconds. `null` means "no TTL" — the memory + // persists until explicitly deleted or evicted under memory + // pressure. + public const TTL_BY_KIND = [ + 'episodic' => 7 * 24 * 3600, + 'semantic' => null, + ]; + + public readonly Client $client; + public readonly string $indexName; + public readonly string $keyPrefix; + public readonly int $vectorDim; + public readonly float $dedupThreshold; + public readonly float $recallThreshold; + public readonly array $ttlByKind; + + public function __construct( + Client $client, + string $indexName = 'agentmem:idx', + string $keyPrefix = 'agent:mem:', + int $vectorDim = self::VECTOR_DIM_DEFAULT, + float $dedupThreshold = self::DEFAULT_DEDUP_THRESHOLD, + float $recallThreshold = self::DEFAULT_RECALL_THRESHOLD, + ?array $ttlByKind = null, + ) { + $this->client = $client; + $this->indexName = $indexName; + $this->keyPrefix = $keyPrefix; + $this->vectorDim = $vectorDim; + $this->dedupThreshold = $dedupThreshold; + $this->recallThreshold = $recallThreshold; + $this->ttlByKind = $ttlByKind ?? self::TTL_BY_KIND; + } + + // -- Keys and index -------------------------------------------------- + + public function memoryKey(string $memoryId): string + { + return $this->keyPrefix . $memoryId; + } + + /** + * Create the Redis Search index if it doesn't already exist. + * + * The index is declared on the JSON document type, with a + * `$.embedding` path holding the vector and TAG fields for + * `user`, `namespace`, `kind`, and `source_thread`. One + * `FT.SEARCH` can therefore pre-filter by any combination of + * those tags and KNN-rank the matching memories in one pass. + */ + public function createIndex(): void + { + $schema = [ + new TextField('$.text', 'text'), + new TagField('$.user', 'user'), + new TagField('$.namespace', 'namespace'), + new TagField('$.kind', 'kind'), + new TagField('$.source_thread', 'source_thread'), + new NumericField('$.created_ts', 'created_ts', NumericField::SORTABLE), + new NumericField('$.hit_count', 'hit_count', NumericField::SORTABLE), + new VectorField( + '$.embedding', + 'HNSW', + [ + 'TYPE', 'FLOAT32', + 'DIM', $this->vectorDim, + 'DISTANCE_METRIC', 'COSINE', + ], + 'embedding', + ), + ]; + try { + $this->client->ftcreate( + $this->indexName, + $schema, + (new CreateArguments()) + ->on('JSON') + ->prefix([$this->keyPrefix]), + ); + } catch (ServerException $exc) { + if (!str_contains($exc->getMessage(), 'Index already exists')) { + throw $exc; + } + } + } + + public function dropIndex(bool $deleteDocuments = false): void + { + try { + $args = new DropArguments(); + if ($deleteDocuments) { + $args->dd(); + } + $this->client->ftdropindex($this->indexName, $args); + } catch (ServerException $exc) { + $msg = strtolower($exc->getMessage()); + if (!str_contains($msg, 'no such index') + && !str_contains($msg, 'unknown index name')) { + throw $exc; + } + } + } + + // -- Write ----------------------------------------------------------- + + /** + * Write a new memory, deduplicating against existing entries. + * + * Runs one in-scope KNN(1) against the index first. If the + * nearest existing memory is within `dedupThreshold`, the new + * memory is skipped (its content is already represented) and the + * existing memory's `hit_count` is bumped. Otherwise a fresh JSON + * document is written under a new id with a TTL derived from the + * memory's `kind`. + * + * The KNN-then-write sequence is not atomic; two workers that + * remember the same fact at the same time can both miss each + * other's in-flight write and insert duplicate memories. See the + * walkthrough's "Concurrency caveats" section for the production + * fix (periodic background consolidator that merges + * near-duplicates). + * + * @param list $embedding + * @return array + */ + public function remember( + string $text, + array $embedding, + string $user = 'default', + string $namespace = 'default', + string $kind = 'episodic', + string $sourceThread = '', + int|false $ttlSeconds = false, + ): array { + if (count($embedding) !== $this->vectorDim) { + throw new \InvalidArgumentException(sprintf( + 'embedding length is %d; index expects %d', + count($embedding), $this->vectorDim, + )); + } + + $nearest = $this->nearest($embedding, $user, $namespace, $kind, 1); + $nearestDistance = $nearest[0]['distance'] ?? null; + if (!empty($nearest) + && $nearest[0]['distance'] !== null + && $nearest[0]['distance'] <= $this->dedupThreshold) { + // Duplicate. Bump the hit count on the existing memory so + // the admin UI can show how often it's been re-derived. + $this->bumpHitCount($nearest[0]['id']); + return [ + 'id' => $nearest[0]['id'], + 'deduped' => true, + 'existing_distance' => $nearestDistance, + ]; + } + + $memoryId = substr(bin2hex(random_bytes(6)), 0, 12); + $key = $this->memoryKey($memoryId); + $now = microtime(true); + $doc = [ + 'id' => $memoryId, + 'user' => $user, + 'namespace' => $namespace, + 'kind' => $kind, + 'source_thread' => $sourceThread, + 'text' => $text, + 'embedding' => $embedding, + 'created_ts' => $now, + 'hit_count' => 0, + ]; + $ttl = $this->resolveTtl($kind, $ttlSeconds); + + // MULTI/EXEC so the JSON document and its TTL apply together. + // A connection drop between the JSON.SET and EXPIRE would + // otherwise leave the memory without an expiry. + $this->client->transaction( + function ($tx) use ($key, $doc, $ttl) { + $tx->jsonset($key, '$', json_encode($doc, JSON_THROW_ON_ERROR)); + if ($ttl !== null) { + $tx->expire($key, $ttl); + } + }, + ); + return [ + 'id' => $memoryId, + 'deduped' => false, + 'existing_distance' => $nearestDistance, + ]; + } + + // -- Recall ---------------------------------------------------------- + + /** + * Return the top-k in-scope memories ranked by similarity. + * + * Memories beyond `distanceThreshold` (or the instance default) + * are dropped — the index always returns *something* for KNN, so + * a recall result on an unrelated query would otherwise be a + * confidently-wrong false positive. + * + * @param list $queryEmbedding + * @return list> + */ + public function recall( + array $queryEmbedding, + string $user = 'default', + ?string $namespace = 'default', + ?string $kind = null, + int $k = 5, + float|null $distanceThreshold = null, + ): array { + $threshold = $distanceThreshold ?? $this->recallThreshold; + $candidates = $this->nearest($queryEmbedding, $user, $namespace, $kind, $k); + return array_values(array_filter( + $candidates, + fn(array $c) => $c['distance'] !== null && $c['distance'] <= $threshold, + )); + } + + // -- Internals ------------------------------------------------------- + + /** + * @param list $embedding + * @return list> + */ + private function nearest( + array $embedding, + ?string $user, + ?string $namespace, + ?string $kind, + int $k, + ): array { + if (count($embedding) !== $this->vectorDim) { + throw new \InvalidArgumentException(sprintf( + 'embedding length is %d; index expects %d', + count($embedding), $this->vectorDim, + )); + } + $filterClause = self::buildFilterClause($user, $namespace, $kind); + $knnQuery = sprintf( + '%s=>[KNN %d @embedding $vec AS distance]', + $filterClause, $k, + ); + $args = (new SearchArguments()) + ->params(['vec', Embedder::toBytes($embedding)]) + ->dialect('2') + ->sortBy('distance', 'asc') + ->limit(0, $k) + ->addReturn( + 8, + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', 'distance', + ); + $result = $this->client->ftsearch($this->indexName, $knnQuery, $args); + return $this->parseSearchResult($result); + } + + private function bumpHitCount(string $memoryId): void + { + try { + $this->client->jsonnumincrby( + $this->memoryKey($memoryId), '$.hit_count', 1, + ); + } catch (ServerException) { + // The doc may have expired between recall and bump — fine, + // we just lose the hit count update. + } + } + + private function resolveTtl(string $kind, int|false $override): ?int + { + if ($override === false) { + return $this->ttlByKind[$kind] ?? null; + } + return $override; + } + + private function stripPrefix(string $rawKey): string + { + if (str_starts_with($rawKey, $this->keyPrefix)) { + return substr($rawKey, strlen($this->keyPrefix)); + } + return $rawKey; + } + + // Characters Redis Search treats as syntax inside a TAG value; any + // of them in a user-supplied filter must be backslash-escaped or + // the surrounding `{...}` block won't parse correctly. + private const TAG_SPECIAL = '\\,.<>{}[]"\':;!@#$%^&*()-+=~| '; + + private static function escapeTagValue(string $value): string + { + $out = ''; + $specials = str_split(self::TAG_SPECIAL); + foreach (mb_str_split($value) as $ch) { + $out .= in_array($ch, $specials, true) ? '\\' . $ch : $ch; + } + return $out; + } + + private static function buildFilterClause( + ?string $user, + ?string $namespace, + ?string $kind, + ): string { + // Truthy-check would drop `"0"` as falsy and silently broaden + // the scope — only `null` and `""` mean "no filter". + $clauses = []; + if ($user !== null && $user !== '') { + $clauses[] = '@user:{' . self::escapeTagValue($user) . '}'; + } + if ($namespace !== null && $namespace !== '') { + $clauses[] = '@namespace:{' . self::escapeTagValue($namespace) . '}'; + } + if ($kind !== null && $kind !== '') { + $clauses[] = '@kind:{' . self::escapeTagValue($kind) . '}'; + } + return empty($clauses) ? '(*)' : '(' . implode(' ', $clauses) . ')'; + } + + /** + * Parse Predis's flat FT.SEARCH result into structured records. + * + * The wire format under DIALECT 2 is: `[total, key1, [k1, v1, + * k2, v2, ...], key2, [...], ...]`. We translate the + * alternating field-pair arrays into associative arrays, strip + * the key prefix off the document id, and look up each + * document's TTL so the admin panel can show it. + * + * @param mixed $result + * @return list> + */ + private function parseSearchResult(mixed $result): array + { + if (!is_array($result) || count($result) < 2) { + return []; + } + $count = (int) $result[0]; + $out = []; + for ($i = 0; $i < $count; $i++) { + $idIdx = 1 + ($i * 2); + $fieldsIdx = $idIdx + 1; + if (!isset($result[$idIdx], $result[$fieldsIdx])) { + break; + } + $rawKey = (string) $result[$idIdx]; + $memoryId = $this->stripPrefix($rawKey); + $flat = $result[$fieldsIdx]; + $fields = is_array($flat) ? self::flatToAssoc($flat) : []; + $ttl = (int) $this->client->ttl($this->memoryKey($memoryId)); + $out[] = [ + 'id' => $memoryId, + 'user' => $fields['user'] ?? '', + 'namespace' => $fields['namespace'] ?? '', + 'kind' => $fields['kind'] ?? '', + 'source_thread' => $fields['source_thread'] ?? '', + 'text' => $fields['text'] ?? '', + 'created_ts' => (float) ($fields['created_ts'] ?? 0), + 'hit_count' => (int) ($fields['hit_count'] ?? 0), + 'distance' => isset($fields['distance']) + ? (float) $fields['distance'] + : null, + 'ttl_seconds' => $ttl > 0 ? $ttl : null, + ]; + } + return $out; + } + + /** @param list $flat */ + private static function flatToAssoc(array $flat): array + { + $out = []; + for ($i = 0, $n = count($flat) - 1; $i < $n; $i += 2) { + $out[(string) $flat[$i]] = $flat[$i + 1]; + } + return $out; + } + + // -- Admin / inspection --------------------------------------------- + + /** + * @return array{num_docs:int,indexing_failures:int} + */ + public function indexInfo(): array + { + try { + $info = $this->client->ftinfo($this->indexName); + } catch (ServerException) { + return ['num_docs' => 0, 'indexing_failures' => 0]; + } + // Predis returns FT.INFO as a flat alternating key/value array. + $assoc = self::flatToAssoc($info); + return [ + 'num_docs' => (int) ($assoc['num_docs'] ?? 0), + 'indexing_failures' => (int) ( + $assoc['hash_indexing_failures'] ?? 0 + ), + ]; + } + + /** + * Return memories matching the filters, newest first. + * + * @return list> + */ + public function listMemories( + ?string $user = null, + ?string $namespace = null, + ?string $kind = null, + int $limit = 100, + ): array { + $filterClause = self::buildFilterClause($user, $namespace, $kind); + $args = (new SearchArguments()) + ->dialect('2') + ->sortBy('created_ts', 'desc') + ->limit(0, $limit) + ->addReturn( + 7, + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', + ); + try { + $result = $this->client->ftsearch( + $this->indexName, $filterClause, $args, + ); + } catch (ServerException) { + return []; + } + $out = $this->parseSearchResult($result); + // No `distance` is requested by listMemories, so drop the + // null distance field rather than confuse the UI. + foreach ($out as &$row) { + unset($row['distance']); + } + return $out; + } + + public function deleteMemory(string $memoryId): bool + { + return ((int) $this->client->del($this->memoryKey($memoryId))) > 0; + } + + /** + * Drop the index and every memory document. Returns the count of + * documents that were removed. In production the equivalent is + * `FLUSHDB` on a dedicated memory database, or letting TTLs and + * eviction expire entries naturally. + */ + public function clear(): int + { + $before = $this->indexInfo()['num_docs']; + $this->dropIndex(deleteDocuments: true); + $this->createIndex(); + return $before; + } +} diff --git a/content/develop/use-cases/agent-memory/php/src/SeedMemory.php b/content/develop/use-cases/agent-memory/php/src/SeedMemory.php new file mode 100644 index 0000000000..4ad585221d --- /dev/null +++ b/content/develop/use-cases/agent-memory/php/src/SeedMemory.php @@ -0,0 +1,101 @@ + 'The user prefers concise answers without filler phrases.', + 'kind' => 'semantic', + ], + [ + 'text' => 'The user is a Python developer working on a logistics platform.', + 'kind' => 'semantic', + ], + [ + 'text' => 'The user lives in Berlin and works in the Europe/Berlin time zone.', + 'kind' => 'semantic', + ], + [ + 'text' => + 'The user dislikes dark mode and prefers a high-contrast ' + . 'light theme in editors and dashboards.', + 'kind' => 'semantic', + ], + [ + 'text' => + 'The user is allergic to peanuts; any restaurant suggestion ' + . 'must avoid dishes that commonly contain them.', + 'kind' => 'semantic', + ], + [ + 'text' => + 'Last Tuesday the user asked the agent to draft a postmortem ' + . 'for the order-routing outage. The agent produced a ' + . 'five-section draft and the user approved sections 1, 2, ' + . 'and 4 with minor edits.', + 'kind' => 'episodic', + ], + [ + 'text' => + 'In a previous session the user asked for help debugging a ' + . 'flaky test in the inventory service. The fix turned out ' + . 'to be a race condition in the warehouse webhook handler.', + 'kind' => 'episodic', + ], + [ + 'text' => + 'Two weeks ago the user mentioned they were planning to ' + . 'migrate the analytics warehouse from Snowflake to ' + . 'BigQuery in Q3.', + 'kind' => 'episodic', + ], + ]; + + /** + * Embed and write the seed memories. Returns the count actually + * written (entries that dedup against existing memories don't count). + */ + public static function seed( + LongTermMemory $memory, + Embedder $embedder, + string $user = 'default', + string $namespace = 'default', + string $sourceThread = 'seed', + ): int { + $texts = array_map(fn($m) => $m['text'], self::SEED_MEMORIES); + $vectors = $embedder->encodeMany($texts); + $written = 0; + foreach (self::SEED_MEMORIES as $i => $entry) { + $result = $memory->remember( + text: $entry['text'], + embedding: $vectors[$i], + user: $user, + namespace: $namespace, + kind: $entry['kind'], + sourceThread: $sourceThread, + ); + if (!$result['deduped']) { + $written++; + } + } + return $written; + } +} diff --git a/content/develop/use-cases/agent-memory/redis-py/_index.md b/content/develop/use-cases/agent-memory/redis-py/_index.md new file mode 100644 index 0000000000..222074c2af --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/_index.md @@ -0,0 +1,331 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Python with redis-py, sentence-transformers, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: redis-py example (Python) +title: Redis agent memory with redis-py +weight: 1 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Python with [`redis-py`]({{< relref "/develop/clients/redis-py" >}}) and the [`sentence-transformers`](https://www.sbert.net/) library, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with the Python standard library so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate the per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encode_one(text)` to turn the incoming turn into a 384-dimensional `float32` vector. +2. `session.append_turn(thread_id, role, content)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with an [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) pipeline. The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall(vec, user=..., namespace=..., k=5)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember(text, vec, user=..., namespace=..., kind=...)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `event_log.record(thread_id, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/redis-py/session_store.py)): + +```python +import redis +from session_store import AgentSession + +r = redis.Redis(host="localhost", port=6379, decode_responses=False) +session = AgentSession( + redis_client=r, + key_prefix="agent:session:", + default_ttl_seconds=3600, # one hour + max_turns=20, # rolling window per thread +) + +thread_id = session.new_thread_id() +session.start(thread_id, user="alice", agent="demo-agent", + goal="Plan next week's meetings.") +session.append_turn(thread_id, role="user", + content="Schedule a budget review with finance.") +state = session.load(thread_id) +print(state.turn_count, len(state.recent_turns), state.ttl_seconds) +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `append_turn`, `set_scratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a `MULTI` / `EXEC` block, so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/redis-py/long_term_memory.py)): + +```python +import numpy as np +from long_term_memory import LongTermMemory +from embeddings import LocalEmbedder + +memory = LongTermMemory( + redis_client=r, + index_name="agentmem:idx", + key_prefix="agent:mem:", + dedup_threshold=0.20, # cosine distance — tight at write time + recall_threshold=0.55, # looser at read time +) +embedder = LocalEmbedder() +memory.create_index() # idempotent + +# Write a memory. The same KNN that powers recall also runs here +# at a tighter threshold so paraphrases of the same fact collapse. +vec = embedder.encode_one("The user prefers light mode in editors.") +result = memory.remember( + text="The user prefers light mode in editors.", + embedding=np.asarray(vec, dtype=np.float32), + user="alice", + namespace="default", + kind="semantic", + source_thread="9f3d2a4b8c61", +) +print(result.deduped, result.id, result.existing_distance) + +# Recall against a later question. +q = embedder.encode_one("Which theme does this user like?") +hits = memory.recall( + query_embedding=np.asarray(q, dtype=np.float32), + user="alice", + namespace="default", + k=5, +) +for h in hits: + print(f"{h.distance:.3f} [{h.kind}] {h.text}") +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes, regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `as_name` aliases so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup are the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write with `ttl_seconds=...` on `remember`, or pass a different `ttl_by_kind={...}` map to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/redis-py/event_log.py)): + +```python +from event_log import AgentEventLog + +events = AgentEventLog(redis_client=r, max_len=1000) +events.record(thread_id, action="turn_appended:user", + detail="Schedule a budget review with finance.") +events.record(thread_id, action="memory_written", + detail="wrote 7c3f8a1b9e02 as semantic") + +for event in events.recent(thread_id, count=20): + print(event.action, event.detail) +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `maxlen=~1000`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarizers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession.append_turn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` list in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory.remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `current_thread_id` that `/new_thread` and `/reset` mutate under a lock; `handle_turn` reads it outside that lock, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `seed_memory.py` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/redis-py/seed_memory.py)): + +```python +from seed_memory import seed +from long_term_memory import LongTermMemory +from embeddings import LocalEmbedder + +memory = LongTermMemory() +embedder = LocalEmbedder() +memory.create_index() +seed(memory, embedder, user="default", namespace="default") +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`demo_server.py` runs a [`ThreadingHTTPServer`](https://docs.python.org/3/library/http.server.html) on port 8086. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a class attribute that the **New thread** button rotates — every browser session inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/redis-py + ``` + +2. Install the dependencies: + + ```bash + pip install redis sentence-transformers numpy + ``` + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) ships both, + or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the Search and JSON modules + enabled. + +4. Start the demo server. The first run downloads the `all-MiniLM-L6-v2` model + (~80 MB) into the local Hugging Face cache: + + ```bash + python demo_server.py + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + `all-MiniLM-L6-v2` puts a faithful paraphrase in the 0.15 – 0.50 + cosine-distance range, a loose paraphrase or related topic in the + 0.50 – 0.80 range, and unrelated queries above 0.8 — which is what + motivates the 0.55 default recall threshold and the 0.20 default + dedup threshold. A stricter embedding model (or a domain-tuned one) + would let you tighten both; a noisier one would push them up. The + right thresholds are always a function of the model, the corpus, + and how conservative the agent needs to be about accepting a memory + as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags: + +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/redis-py/demo_server.py b/content/develop/use-cases/agent-memory/redis-py/demo_server.py new file mode 100644 index 0000000000..b1f44b88f2 --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/demo_server.py @@ -0,0 +1,1105 @@ +#!/usr/bin/env python3 +""" +Redis agent-memory demo server. + +Run this file and visit http://localhost:8086 to drive a small +agent-memory demo backed by Redis Hashes, JSON, Search, and Streams. +The UI lets you: + +* Type a turn as the user (or paste a goal / scratchpad note). The + server appends the turn to the per-thread working-memory hash, + embeds the turn, recalls the top-k semantically nearest long-term + memories, optionally writes the turn back as a new memory with + write-time deduplication, and appends an event to the per-thread + stream. +* Watch the three memory tiers update in place: working memory in + one Hash, long-term memories as JSON documents under one index, + and the event log in one Stream. +* Switch user, namespace, kind, and recall threshold to see how + scoping changes which memories the agent sees. +* Inspect every long-term memory (including remaining TTL and total + hit count) and drop individual memories to simulate eviction. + +The server holds a single ``LocalEmbedder``, one ``AgentSession`` +(working memory), one ``LongTermMemory`` (semantic recall + dedup), +and one ``AgentEventLog`` (event stream) for the lifetime of the +process. The first run downloads the embedding model (~80 MB) into +the local Hugging Face cache; everything after is local. +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from urllib.parse import parse_qs, urlparse + +import numpy as np + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +try: + import redis + + from embeddings import LocalEmbedder + from event_log import AgentEventLog + from long_term_memory import LongTermMemory + from seed_memory import seed + from session_store import AgentSession +except ImportError as exc: + print(f"Error: {exc}") + print( + "Make sure the required packages are installed:\n" + " pip install redis sentence-transformers numpy" + ) + sys.exit(1) + + +HTML_TEMPLATE = """ + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + +""" + + +class AgentMemoryDemo: + """Demo state: working memory, long-term memory, event log.""" + + def __init__( + self, + session_store: AgentSession, + memory: LongTermMemory, + event_log: AgentEventLog, + embedder: LocalEmbedder, + default_user: str = "default", + default_namespace: str = "default", + ) -> None: + self.session_store = session_store + self.memory = memory + self.event_log = event_log + self.embedder = embedder + self.default_user = default_user + self.default_namespace = default_namespace + self.current_thread_id: str = session_store.new_thread_id() + + # ``seed`` / ``new_thread`` / ``handle_turn`` all touch + # ``current_thread_id`` without coordination — see the walkthrough's + # "Concurrency caveats" section. The demo is single-user in + # practice, so the race never triggers; a multi-user agent would + # carry the thread id on each request instead of holding it as + # shared server state. + + def seed(self, user: str, namespace: str) -> int: + """Drop everything in scope and pre-populate with seed memories.""" + self.memory.clear() + self.session_store.delete(self.current_thread_id) + self.event_log.clear(self.current_thread_id) + written = seed( + self.memory, + self.embedder, + user=user, + namespace=namespace, + source_thread="seed", + ) + self.current_thread_id = self.session_store.new_thread_id() + return written + + def new_thread(self, user: str, namespace: str) -> str: + """Start a fresh thread. Long-term memory is unaffected.""" + self.event_log.clear(self.current_thread_id) + self.current_thread_id = self.session_store.new_thread_id() + self.session_store.start( + self.current_thread_id, user=user, + agent="demo-agent", goal="", + ) + self.event_log.record( + self.current_thread_id, "thread_started", + f"user={user} namespace={namespace}", + ) + return self.current_thread_id + + def handle_turn( + self, + text: str, + user: str, + namespace: str, + kind: str, + role: str, + threshold: float, + action: str, + ) -> dict: + """One pass through the agent loop: append, recall, remember, log. + + The order matters. We embed once and reuse the vector for + both the recall and (if asked) the remember step — no point + encoding the same text twice. Recall runs *before* the + remember write so the agent doesn't see its own just-written + turn as a recalled memory; some agent designs do the + opposite, but the demo defaults to the more useful one. + """ + thread_id = self.current_thread_id + + t0 = time.perf_counter() + vec = self.embedder.encode_one(text) + embed_ms = (time.perf_counter() - t0) * 1000 + + # Append to working memory or update the goal, depending on + # which button the user pressed. ``set_goal`` only touches the + # goal field, so existing turns aren't wiped; ``append_turn`` + # carries the request ``user`` through to the auto-create path + # so a first turn for a new thread doesn't land under the + # default user. + if action == "goal": + self.session_store.set_goal( + thread_id, text, user=user, agent="demo-agent", + ) + session_action = "goal_set" + else: + self.session_store.append_turn( + thread_id, + role=role, + content=text, + user=user, + agent="demo-agent", + ) + session_action = f"turn_appended:{role}" + + # Recall before write. + t1 = time.perf_counter() + recalled = self.memory.recall( + query_embedding=np.asarray(vec, dtype=np.float32), + user=user, + namespace=namespace, + k=5, + distance_threshold=threshold, + ) + recall_ms = (time.perf_counter() - t1) * 1000 + + # Optionally remember the turn as a long-term memory. + write_skipped = (kind == "skip" or action == "goal") + write_result = None + if not write_skipped: + t2 = time.perf_counter() + write_result = self.memory.remember( + text=text, + embedding=np.asarray(vec, dtype=np.float32), + user=user, + namespace=namespace, + kind=kind, + source_thread=thread_id, + ) + write_ms = (time.perf_counter() - t2) * 1000 + else: + write_ms = 0.0 + + # Append to event log so the audit trail shows what happened. + if write_result is not None: + event_detail = ( + f"deduped onto {write_result.id}" + if write_result.deduped else f"wrote {write_result.id} as {kind}" + ) + self.event_log.record(thread_id, session_action, event_detail) + else: + self.event_log.record(thread_id, session_action, "") + + return { + "thread_id": thread_id, + "write_skipped": write_skipped, + "memory_id": write_result.id if write_result else None, + "deduped": write_result.deduped if write_result else False, + "existing_distance": + write_result.existing_distance if write_result else None, + "kind": kind if not write_skipped else None, + "recalled": [m.to_dict() for m in recalled], + "embed_ms": embed_ms, + "recall_ms": recall_ms, + "write_ms": write_ms, + } + + +class AgentMemoryHandler(BaseHTTPRequestHandler): + """HTTP handler. Server-state lives on class attributes.""" + + session_store: AgentSession | None = None + memory: LongTermMemory | None = None + event_log: AgentEventLog | None = None + embedder: LocalEmbedder | None = None + demo: AgentMemoryDemo | None = None + + # ------------------------------------------------------------------ + # GET + # ------------------------------------------------------------------ + + def do_GET(self) -> None: + try: + parsed = urlparse(self.path) + if parsed.path in {"/", "/index.html"}: + self._send_html(self._html_page()) + return + if parsed.path == "/state": + params = parse_qs(parsed.query) + user = (params.get("user", ["default"])[0] + or self.demo.default_user) + namespace = (params.get("namespace", ["default"])[0] + or self.demo.default_namespace) + self._send_json(self._build_state(user, namespace), 200) + return + self.send_error(404) + except Exception as exc: + self._send_error_json(exc) + + # ------------------------------------------------------------------ + # POST + # ------------------------------------------------------------------ + + def do_POST(self) -> None: + try: + parsed = urlparse(self.path) + if parsed.path == "/turn": + self._handle_turn() + return + if parsed.path == "/new_thread": + self._handle_new_thread() + return + if parsed.path == "/reset": + self._handle_reset() + return + if parsed.path == "/drop_memory": + self._handle_drop_memory() + return + self.send_error(404) + except Exception as exc: + self._send_error_json(exc) + + def _send_error_json(self, exc: Exception) -> None: + """Return a JSON 500 so the client's ``res.json()`` works. + + Without this wrapper, an exception in a handler escapes to + ``BaseHTTPRequestHandler`` which writes a plain-text 500 page; + the demo's ``fetch().then(r => r.json())`` then explodes with + an opaque JSON parse error instead of surfacing what went wrong. + """ + sys.stderr.write(f"[demo] handler error: {type(exc).__name__}: {exc}\n") + try: + self._send_json( + {"error": str(exc), "type": type(exc).__name__}, 500, + ) + except Exception: + pass + + # ---- handlers --------------------------------------------------- + + def _handle_turn(self) -> None: + params = self._read_form() + text = params.get("text", [""])[0].strip() + if not text: + self._send_json({"error": "text is required"}, 400) + return + try: + threshold = float(params.get("threshold", ["0.55"])[0]) + except ValueError: + threshold = 0.55 + # ``float()`` happily parses "nan"/"inf"; either would silently + # turn recall into "every memory" or "nothing". Clamp to the + # meaningful cosine-distance range. + import math + if not math.isfinite(threshold): + threshold = 0.55 + threshold = max(0.0, min(2.0, threshold)) + payload = self.demo.handle_turn( + text=text, + user=params.get("user", ["default"])[0] or "default", + namespace=params.get("namespace", ["default"])[0] or "default", + kind=params.get("kind", ["episodic"])[0] or "episodic", + role=params.get("role", ["user"])[0] or "user", + threshold=threshold, + action=params.get("action", ["turn"])[0] or "turn", + ) + self._send_json(payload, 200) + + def _handle_new_thread(self) -> None: + params = self._read_form() + thread_id = self.demo.new_thread( + user=params.get("user", ["default"])[0] or "default", + namespace=params.get("namespace", ["default"])[0] or "default", + ) + self._send_json({"thread_id": thread_id}, 200) + + def _handle_reset(self) -> None: + params = self._read_form() + seeded = self.demo.seed( + user=params.get("user", ["default"])[0] or "default", + namespace=params.get("namespace", ["default"])[0] or "default", + ) + self._send_json({"seeded": seeded}, 200) + + def _handle_drop_memory(self) -> None: + params = self._read_form() + memory_id = params.get("memory_id", [""])[0].strip() + if not memory_id: + self._send_json({"error": "memory_id is required"}, 400) + return + deleted = self.memory.delete_memory(memory_id) + self._send_json({"deleted": deleted, "memory_id": memory_id}, 200) + + # ---- state assembly --------------------------------------------- + + def _build_state(self, user: str, namespace: str) -> dict: + info = self.memory.index_info() + info["index_name"] = self.memory.index_name + info["model"] = self.embedder.model_name + info["session_ttl_seconds"] = self.session_store.default_ttl_seconds + info["dedup_threshold"] = self.memory.dedup_threshold + info["default_recall_threshold"] = self.memory.recall_threshold + info["stack_label"] = ( + "redis-py + sentence-transformers + " + "Python standard library HTTP server" + ) + thread_id = self.demo.current_thread_id + session = self.session_store.load(thread_id) + memories = self.memory.list_memories( + user=user, namespace=namespace, limit=200, + ) + events = self.event_log.recent(thread_id, count=20) + return { + "index": info, + "thread_id": thread_id, + "session": session.to_dict() if session else None, + "memories": [m.to_dict() for m in memories], + "events": [e.to_dict() for e in events], + # ``recalled`` is populated by /turn; on plain /state reads + # the UI keeps showing the last turn's result, which is + # the useful behavior for an "agent" panel. + "recalled": [], + } + + # ---- HTTP plumbing ---------------------------------------------- + + # Cap POST bodies so a runaway client (or a ``curl --data-binary + # @big-file`` by mistake) can't make the server buffer unbounded + # data before the handler runs. The demo's largest legitimate + # body is a few hundred bytes of form-encoded query fields; 1 MiB + # is a generous ceiling matching the Node, .NET, Rust, and Go + # demos. + _MAX_BODY_BYTES = 1 * 1024 * 1024 + + def _read_form(self) -> dict[str, list[str]]: + length = int(self.headers.get("Content-Length", "0")) + if length < 0: + length = 0 + if length > self._MAX_BODY_BYTES: + # Read and discard so the connection stays in a sane state, + # then raise — the handler's wrapper converts this into a + # 500 with a JSON body the demo UI can render. + self.rfile.read(min(length, self._MAX_BODY_BYTES)) + raise ValueError( + f"request body exceeds {self._MAX_BODY_BYTES} bytes" + ) + raw = self.rfile.read(length).decode("utf-8") if length else "" + return parse_qs(raw) + + def _send_html(self, html: str, status: int = 200) -> None: + self.send_response(status) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(html.encode("utf-8")) + + def _send_json(self, payload: dict, status: int) -> None: + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(payload, default=_json_default).encode("utf-8")) + + def log_message(self, format: str, *args) -> None: # noqa: A002 + sys.stderr.write(f"[demo] {format % args}\n") + + def _html_page(self) -> str: + return ( + HTML_TEMPLATE + .replace("__SESSION_PREFIX__", self.session_store.key_prefix) + .replace("__MEM_PREFIX__", self.memory.key_prefix) + .replace("__MEM_INDEX__", self.memory.index_name) + .replace("__EVENT_PREFIX__", self.event_log.key_prefix) + ) + + +def _json_default(value): + if isinstance(value, np.floating): + return float(value) + if isinstance(value, np.integer): + return int(value) + if isinstance(value, np.ndarray): + return value.tolist() + raise TypeError(f"unserializable: {type(value).__name__}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run the Redis agent-memory demo server.", + ) + parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host") + parser.add_argument("--port", type=int, default=8086, help="HTTP bind port") + parser.add_argument("--redis-host", default="localhost", help="Redis host") + parser.add_argument("--redis-port", type=int, default=6379, help="Redis port") + parser.add_argument( + "--mem-index-name", default="agentmem:idx", + help="Redis Search index name for long-term memories", + ) + parser.add_argument( + "--mem-key-prefix", default="agent:mem:", + help="JSON key prefix for long-term memories", + ) + parser.add_argument( + "--session-key-prefix", default="agent:session:", + help="Hash key prefix for working memory", + ) + parser.add_argument( + "--event-key-prefix", default="agent:events:", + help="Stream key prefix for the agent event log", + ) + parser.add_argument( + "--session-ttl-seconds", type=int, default=3600, + help="TTL applied to working-memory hashes on every write", + ) + parser.add_argument( + "--dedup-threshold", type=float, default=0.20, + help="Cosine-distance threshold for write-time deduplication", + ) + parser.add_argument( + "--recall-threshold", type=float, default=0.55, + help="Default cosine-distance threshold for recall results", + ) + parser.add_argument( + "--no-reset", dest="reset_on_start", action="store_false", + help=( + "Keep any existing memories instead of dropping and re-seeding" + " on startup." + ), + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + redis_client = redis.Redis( + host=args.redis_host, + port=args.redis_port, + decode_responses=False, + ) + try: + redis_client.ping() + except redis.ConnectionError as exc: + print(f"Error: cannot reach Redis at {args.redis_host}:{args.redis_port}") + print(f" ({exc})") + sys.exit(1) + + session_store = AgentSession( + redis_client=redis_client, + key_prefix=args.session_key_prefix, + default_ttl_seconds=args.session_ttl_seconds, + ) + memory = LongTermMemory( + redis_client=redis_client, + index_name=args.mem_index_name, + key_prefix=args.mem_key_prefix, + dedup_threshold=args.dedup_threshold, + recall_threshold=args.recall_threshold, + ) + memory.create_index() + event_log = AgentEventLog( + redis_client=redis_client, + key_prefix=args.event_key_prefix, + ) + + print("Loading embedding model (first run downloads ~80 MB)...") + embedder = LocalEmbedder() + + demo = AgentMemoryDemo( + session_store=session_store, + memory=memory, + event_log=event_log, + embedder=embedder, + ) + + if args.reset_on_start: + print( + f"Dropping any existing memories under '{args.mem_key_prefix}*' and" + " re-seeding from the sample memory list (pass --no-reset to keep)." + ) + seeded = demo.seed(user="default", namespace="default") + print(f"Seeded {seeded} memories.") + + AgentMemoryHandler.session_store = session_store + AgentMemoryHandler.memory = memory + AgentMemoryHandler.event_log = event_log + AgentMemoryHandler.embedder = embedder + AgentMemoryHandler.demo = demo + + print( + f"Redis agent memory demo listening on " + f"http://{args.host}:{args.port}" + ) + print( + f"Using Redis at {args.redis_host}:{args.redis_port}" + f" with memory index '{args.mem_index_name}'" + ) + + server = ThreadingHTTPServer((args.host, args.port), AgentMemoryHandler) + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/content/develop/use-cases/agent-memory/redis-py/embeddings.py b/content/develop/use-cases/agent-memory/redis-py/embeddings.py new file mode 100644 index 0000000000..0cfbe6784f --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/embeddings.py @@ -0,0 +1,54 @@ +""" +Local text-embedding helper backed by sentence-transformers. + +This is a thin wrapper around the ``sentence-transformers`` model +``all-MiniLM-L6-v2``: a 384-dimensional encoder that runs on CPU, +needs no API key, and has a small footprint (~80 MB). On the first +call the model is downloaded into the local Hugging Face cache; every +later call runs locally. + +Vectors are L2-normalized on output so a Redis Search index declared +with ``DISTANCE_METRIC COSINE`` returns scores that are directly +comparable across entries. +""" + +from __future__ import annotations + +from typing import Iterable + +import numpy as np + + +_DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2" + + +class LocalEmbedder: + """Encode short strings into normalized float32 vectors. + + A single instance loads the model once and reuses it for every + call. The demo server keeps one ``LocalEmbedder`` around for the + lifetime of the process so each turn embeds the new memory and + the recall query with the same model in two encode calls. + """ + + def __init__(self, model_name: str = _DEFAULT_MODEL) -> None: + from sentence_transformers import SentenceTransformer + + self.model_name = model_name + self.model = SentenceTransformer(model_name) + self.dim = int(self.model.get_sentence_embedding_dimension()) + + def encode_one(self, text: str) -> np.ndarray: + """Encode a single string. Returns a 1-D ``float32`` array.""" + return self.encode_many([text])[0] + + def encode_many(self, texts: Iterable[str]) -> np.ndarray: + """Encode a batch. Returns an ``(N, dim) float32`` array.""" + batch = list(texts) + vectors = self.model.encode( + batch, + batch_size=32, + normalize_embeddings=True, + convert_to_numpy=True, + ) + return vectors.astype(np.float32, copy=False) diff --git a/content/develop/use-cases/agent-memory/redis-py/event_log.py b/content/develop/use-cases/agent-memory/redis-py/event_log.py new file mode 100644 index 0000000000..c8d066efd4 --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/event_log.py @@ -0,0 +1,121 @@ +""" +Append-only event log for an agent thread, backed by a Redis Stream. + +Each thread gets a stream at ``agent:events:{thread_id}``. Every +action the agent takes (a user turn arriving, a memory being +recalled, a memory being written, a tool being called) is one +``XADD`` to that stream. Replay with ``XREVRANGE`` for the most +recent N events; bound retention with ``XTRIM MAXLEN ~`` so the log +stays cheap regardless of how long the thread has been running. + +The stream is independent of the session hash (``session_store.py``) +and the long-term memory store (``long_term_memory.py``): it answers +the "what just happened" question without competing with either of +those for indexing or memory budget. Consumer groups (not used in +this demo) would let downstream workers — summarizers, +consolidators, audit pipelines — replay the log without losing +position. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Optional + +import redis + + +# Approximate cap on stream length. ``MAXLEN ~`` lets Redis trim in +# whole-node units instead of exactly-N units, which is much cheaper +# at the cost of overshooting the bound by up to a node's worth. +DEFAULT_MAXLEN = 1000 + + +@dataclass +class AgentEvent: + event_id: str + thread_id: str + action: str + detail: str + ts: float + + def to_dict(self) -> dict: + return { + "event_id": self.event_id, + "thread_id": self.thread_id, + "action": self.action, + "detail": self.detail, + "ts": self.ts, + } + + +class AgentEventLog: + """Append, replay, and bound the per-thread event stream.""" + + def __init__( + self, + redis_client: Optional[redis.Redis] = None, + key_prefix: str = "agent:events:", + max_len: int = DEFAULT_MAXLEN, + ) -> None: + self.redis = redis_client or redis.Redis( + host="localhost", port=6379, decode_responses=False, + ) + self.key_prefix = key_prefix + self.max_len = max_len + + def stream_key(self, thread_id: str) -> str: + return f"{self.key_prefix}{thread_id}" + + def record( + self, + thread_id: str, + action: str, + detail: str = "", + ) -> str: + """Append one event and return its stream id. + + ``maxlen=~N`` keeps the stream bounded with near-zero + overhead; an exact bound (``maxlen=N`` without the tilde) + forces a scan and is rarely worth the cost. + """ + return _d( + self.redis.xadd( + self.stream_key(thread_id), + {"action": action, "detail": detail, "ts": repr(time.time())}, + maxlen=self.max_len, + approximate=True, + ) + ) + + def recent(self, thread_id: str, count: int = 20) -> list[AgentEvent]: + """Return the most recent events, newest first.""" + rows = self.redis.xrevrange( + self.stream_key(thread_id), count=count, + ) + out: list[AgentEvent] = [] + for entry_id, fields in rows: + data = {_d(k): _d(v) for k, v in fields.items()} + out.append(AgentEvent( + event_id=_d(entry_id), + thread_id=thread_id, + action=data.get("action", ""), + detail=data.get("detail", ""), + ts=float(data.get("ts", "0") or 0), + )) + return out + + def length(self, thread_id: str) -> int: + return int(self.redis.xlen(self.stream_key(thread_id))) + + def clear(self, thread_id: str) -> bool: + return bool(self.redis.delete(self.stream_key(thread_id))) + + +def _d(value) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8") + return value diff --git a/content/develop/use-cases/agent-memory/redis-py/long_term_memory.py b/content/develop/use-cases/agent-memory/redis-py/long_term_memory.py new file mode 100644 index 0000000000..faf5ce6c53 --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/long_term_memory.py @@ -0,0 +1,499 @@ +""" +Long-term memory store for an agent, backed by Redis JSON and Search. + +Each memory lives as one JSON document at ``agent:mem:``. The +document holds the memory text, its embedding vector, and a small +metadata block — user, namespace, kind, source thread, timestamps — +that lets the recall query scope results without falling back to +application-side filtering. + +A single Redis Search index covers the embedding plus every metadata +field, so one ``FT.SEARCH`` call performs approximate-nearest- +neighbour over the in-scope subset and returns the top-k memories +ranked by cosine distance. The same KNN check runs at *write* time +to deduplicate near-identical memories before they enter the store, +which keeps the index from filling with paraphrases of the same fact +as the agent reasons over similar topics across sessions. + +Memories carry one of two kinds: + +* ``episodic`` — "what happened" snapshots from a specific thread, + written with a medium TTL so old session detail decays naturally. +* ``semantic`` — distilled facts and preferences the agent should + carry forward indefinitely. Written with no TTL by default. + +This split is enforced as a TAG on the index, so the recall query +can ask for one kind or both with a filter — no separate keyspaces. + +The Redis client used here is constructed with +``decode_responses=False``. JSON.GET responses are JSON bytes which +this module decodes explicitly; the vector parameter to FT.SEARCH is +binary float32 regardless of how the document stores the embedding. +""" + +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import redis +from redis.commands.search.field import ( + NumericField, + TagField, + TextField, + VectorField, +) +from redis.commands.search.index_definition import IndexDefinition, IndexType +from redis.commands.search.query import Query + + +VECTOR_DIM_DEFAULT = 384 + +# How close (cosine distance) a candidate must be to an existing +# memory to count as a duplicate at write time. Smaller = stricter. +# 0.20 is calibrated to the ``all-MiniLM-L6-v2`` embedding model used +# in the demo, where a paraphrase of an existing memory lands in the +# 0.10 – 0.20 range and a distinct memory lands above 0.50. +DEFAULT_DEDUP_THRESHOLD = 0.20 + +# How close (cosine distance) a candidate must be to count as a +# relevant recall result. Larger than the dedup threshold so the +# agent gets a wider net at read time than at write time. +DEFAULT_RECALL_THRESHOLD = 0.55 + +# TTL tiers, in seconds. ``None`` means "no TTL" — the memory +# persists until explicitly deleted or evicted under memory pressure. +TTL_BY_KIND: dict[str, Optional[int]] = { + "episodic": 7 * 24 * 3600, + "semantic": None, +} + + +@dataclass +class MemoryRecord: + """A single memory document returned from the store.""" + + id: str + user: str + namespace: str + kind: str + source_thread: str + text: str + created_ts: float + hit_count: int + distance: Optional[float] = None + ttl_seconds: Optional[int] = None + + def to_dict(self) -> dict: + return { + "id": self.id, + "user": self.user, + "namespace": self.namespace, + "kind": self.kind, + "source_thread": self.source_thread, + "text": self.text, + "created_ts": self.created_ts, + "hit_count": self.hit_count, + "distance": + round(self.distance, 4) if self.distance is not None else None, + "ttl_seconds": self.ttl_seconds, + } + + +@dataclass +class WriteResult: + """Outcome of a ``remember`` call. + + ``deduped`` is ``True`` when the write skipped because a similar + memory already existed; ``id`` is then the existing memory's id. + ``existing_distance`` is the cosine distance to that nearest + memory regardless of which branch was taken — useful for tracing. + """ + + id: str + deduped: bool + existing_distance: Optional[float] + + def to_dict(self) -> dict: + return { + "id": self.id, + "deduped": self.deduped, + "existing_distance": + round(self.existing_distance, 4) + if self.existing_distance is not None else None, + } + + +class LongTermMemory: + """Index, dedupe, recall, and bound long-term agent memories.""" + + def __init__( + self, + redis_client: Optional[redis.Redis] = None, + index_name: str = "agentmem:idx", + key_prefix: str = "agent:mem:", + vector_dim: int = VECTOR_DIM_DEFAULT, + dedup_threshold: float = DEFAULT_DEDUP_THRESHOLD, + recall_threshold: float = DEFAULT_RECALL_THRESHOLD, + ttl_by_kind: Optional[dict[str, Optional[int]]] = None, + ) -> None: + self.redis = redis_client or redis.Redis( + host="localhost", port=6379, decode_responses=False, + ) + self.index_name = index_name + self.key_prefix = key_prefix + self.vector_dim = vector_dim + self.dedup_threshold = dedup_threshold + self.recall_threshold = recall_threshold + self.ttl_by_kind = ttl_by_kind or dict(TTL_BY_KIND) + + # ------------------------------------------------------------------ + # Keys and index + # ------------------------------------------------------------------ + + def memory_key(self, memory_id: str) -> str: + return f"{self.key_prefix}{memory_id}" + + def create_index(self) -> None: + """Create the Redis Search index if it doesn't already exist. + + The index is declared on the JSON document type, with a + ``$.embedding`` path holding the vector and tag fields for + ``user``, ``namespace``, ``kind``, and ``source_thread``. One + ``FT.SEARCH`` can therefore pre-filter by any combination of + those tags and KNN-rank the matching memories in one pass. + """ + schema = ( + TextField("$.text", as_name="text"), + TagField("$.user", as_name="user"), + TagField("$.namespace", as_name="namespace"), + TagField("$.kind", as_name="kind"), + TagField("$.source_thread", as_name="source_thread"), + NumericField("$.created_ts", as_name="created_ts", sortable=True), + NumericField("$.hit_count", as_name="hit_count", sortable=True), + VectorField( + "$.embedding", + "HNSW", + { + "TYPE": "FLOAT32", + "DIM": self.vector_dim, + "DISTANCE_METRIC": "COSINE", + }, + as_name="embedding", + ), + ) + definition = IndexDefinition( + prefix=[self.key_prefix], index_type=IndexType.JSON, + ) + try: + self.redis.ft(self.index_name).create_index( + fields=schema, definition=definition, + ) + except redis.ResponseError as exc: + if "Index already exists" not in str(exc): + raise + + def drop_index(self, delete_documents: bool = False) -> None: + """Drop the search index. Optionally also delete the JSON docs.""" + try: + self.redis.ft(self.index_name).dropindex( + delete_documents=delete_documents, + ) + except redis.ResponseError as exc: + message = str(exc).lower() + if "no such index" not in message \ + and "unknown index name" not in message: + raise + + # ------------------------------------------------------------------ + # Write + # ------------------------------------------------------------------ + + def remember( + self, + text: str, + embedding: np.ndarray, + user: str = "default", + namespace: str = "default", + kind: str = "episodic", + source_thread: str = "", + ttl_seconds: int | None | object = ..., + ) -> WriteResult: + """Write a new memory, deduplicating against existing entries. + + Runs one in-scope KNN(1) against the index first. If the + nearest existing memory is within ``dedup_threshold``, the + new memory is skipped (its content is already represented) + and the existing memory's ``hit_count`` is bumped. Otherwise + a fresh JSON document is written under a new id with a TTL + derived from the memory's ``kind``. + """ + if embedding.shape != (self.vector_dim,): + raise ValueError( + f"embedding has shape {embedding.shape}; " + f"index expects ({self.vector_dim},)" + ) + if embedding.dtype != np.float32: + embedding = embedding.astype(np.float32, copy=False) + + nearest = self._nearest( + embedding, user=user, namespace=namespace, kind=kind, k=1, + ) + nearest_distance = nearest[0].distance if nearest else None + if nearest and nearest[0].distance is not None \ + and nearest[0].distance <= self.dedup_threshold: + # Duplicate. Bump the hit count on the existing memory so + # the admin UI can show how often it's been re-derived. + self._bump_hit_count(nearest[0].id) + return WriteResult( + id=nearest[0].id, + deduped=True, + existing_distance=nearest_distance, + ) + + memory_id = uuid.uuid4().hex[:12] + key = self.memory_key(memory_id) + now = time.time() + doc = { + "id": memory_id, + "user": user, + "namespace": namespace, + "kind": kind, + "source_thread": source_thread, + "text": text, + "embedding": embedding.tolist(), + "created_ts": now, + "hit_count": 0, + } + ttl = self._resolve_ttl(kind, ttl_seconds) + + # MULTI/EXEC so the document and its TTL apply together. + pipe = self.redis.pipeline(transaction=True) + pipe.json().set(key, "$", doc) + if ttl is not None: + pipe.expire(key, ttl) + pipe.execute() + return WriteResult( + id=memory_id, + deduped=False, + existing_distance=nearest_distance, + ) + + # ------------------------------------------------------------------ + # Recall + # ------------------------------------------------------------------ + + def recall( + self, + query_embedding: np.ndarray, + user: str = "default", + namespace: str | None = "default", + kind: str | None = None, + k: int = 5, + distance_threshold: float | None = None, + ) -> list[MemoryRecord]: + """Return the top-k in-scope memories ranked by similarity. + + Memories beyond ``distance_threshold`` (or the instance + default) are dropped — the index always returns *something* + for KNN, so a recall result on an unrelated query would + otherwise be a confidently-wrong false positive. + """ + threshold = ( + distance_threshold if distance_threshold is not None + else self.recall_threshold + ) + candidates = self._nearest( + query_embedding, + user=user, namespace=namespace, kind=kind, k=k, + ) + return [c for c in candidates + if c.distance is not None and c.distance <= threshold] + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _nearest( + self, + embedding: np.ndarray, + user: str | None, + namespace: str | None, + kind: str | None, + k: int, + ) -> list[MemoryRecord]: + if embedding.shape != (self.vector_dim,): + raise ValueError( + f"embedding has shape {embedding.shape}; " + f"index expects ({self.vector_dim},)" + ) + filter_clause = self._build_filter_clause( + user=user, namespace=namespace, kind=kind, + ) + knn_query = ( + f"{filter_clause}=>[KNN {k} @embedding $vec AS distance]" + ) + q = ( + Query(knn_query) + .sort_by("distance") + .return_fields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count", "distance", + ) + .paging(0, k) + .dialect(2) + ) + result = self.redis.ft(self.index_name).search( + q, + query_params={ + "vec": embedding.astype(np.float32).tobytes(), + }, + ) + out: list[MemoryRecord] = [] + for doc in result.docs: + # ``doc.id`` is the full Redis key (e.g. ``agent:mem:abc123``). + # Strip the prefix so the MemoryRecord exposes only the + # opaque id the UI and ``delete_memory`` work with. + memory_id = self._strip_prefix(_d(getattr(doc, "id", ""))) + ttl = self.redis.ttl(self.memory_key(memory_id)) + out.append(MemoryRecord( + id=memory_id, + user=_d(getattr(doc, "user", "")), + namespace=_d(getattr(doc, "namespace", "")), + kind=_d(getattr(doc, "kind", "")), + source_thread=_d(getattr(doc, "source_thread", "")), + text=_d(getattr(doc, "text", "")), + created_ts=float(_d(getattr(doc, "created_ts", "0")) or 0), + hit_count=int(_d(getattr(doc, "hit_count", "0")) or 0), + distance=float(_d(getattr(doc, "distance", "0")) or 0), + ttl_seconds=int(ttl) if ttl and ttl > 0 else None, + )) + return out + + def _bump_hit_count(self, memory_id: str) -> None: + key = self.memory_key(memory_id) + try: + self.redis.json().numincrby(key, "$.hit_count", 1) + except redis.ResponseError: + # The doc may have expired between recall and bump — fine, + # we just lose the hit count update. + pass + + def _resolve_ttl(self, kind: str, override: object) -> int | None: + if override is ...: + return self.ttl_by_kind.get(kind) + return override # type: ignore[return-value] + + def _strip_prefix(self, raw_key: str) -> str: + if raw_key.startswith(self.key_prefix): + return raw_key[len(self.key_prefix):] + return raw_key + + # Characters Redis Search treats as syntax inside a TAG value; any + # of them in a user-supplied filter must be backslash-escaped or + # the surrounding ``{...}`` block won't parse correctly. + _TAG_SPECIAL = set("\\,.<>{}[]\"':;!@#$%^&*()-+=~| ") + + @classmethod + def _escape_tag_value(cls, value: str) -> str: + return "".join( + "\\" + ch if ch in cls._TAG_SPECIAL else ch for ch in value + ) + + @classmethod + def _build_filter_clause( + cls, + *, + user: str | None, + namespace: str | None, + kind: str | None, + ) -> str: + clauses: list[str] = [] + if user: + clauses.append(f"@user:{{{cls._escape_tag_value(user)}}}") + if namespace: + clauses.append(f"@namespace:{{{cls._escape_tag_value(namespace)}}}") + if kind: + clauses.append(f"@kind:{{{cls._escape_tag_value(kind)}}}") + return "(" + " ".join(clauses) + ")" if clauses else "(*)" + + # ------------------------------------------------------------------ + # Admin / inspection + # ------------------------------------------------------------------ + + def index_info(self) -> dict: + try: + info = self.redis.ft(self.index_name).info() + except redis.ResponseError: + return {"num_docs": 0, "indexing_failures": 0} + return { + "num_docs": int(info.get("num_docs", 0)), + "indexing_failures": int(info.get("hash_indexing_failures", 0)), + } + + def list_memories( + self, + user: str | None = None, + namespace: str | None = None, + kind: str | None = None, + limit: int = 100, + ) -> list[MemoryRecord]: + """Return memories matching the filters, newest first.""" + filter_clause = self._build_filter_clause( + user=user, namespace=namespace, kind=kind, + ) + q = ( + Query(filter_clause) + .return_fields( + "user", "namespace", "kind", "source_thread", + "text", "created_ts", "hit_count", + ) + .paging(0, limit) + .sort_by("created_ts", asc=False) + .dialect(2) + ) + result = self.redis.ft(self.index_name).search(q) + out: list[MemoryRecord] = [] + for doc in result.docs: + memory_id = self._strip_prefix(_d(getattr(doc, "id", ""))) + ttl = self.redis.ttl(self.memory_key(memory_id)) + out.append(MemoryRecord( + id=memory_id, + user=_d(getattr(doc, "user", "")), + namespace=_d(getattr(doc, "namespace", "")), + kind=_d(getattr(doc, "kind", "")), + source_thread=_d(getattr(doc, "source_thread", "")), + text=_d(getattr(doc, "text", "")), + created_ts=float(_d(getattr(doc, "created_ts", "0")) or 0), + hit_count=int(_d(getattr(doc, "hit_count", "0")) or 0), + ttl_seconds=int(ttl) if ttl and ttl > 0 else None, + )) + return out + + def delete_memory(self, memory_id: str) -> bool: + return bool(self.redis.delete(self.memory_key(memory_id))) + + def clear(self) -> int: + """Drop the index and every memory document. + + Returns the count of documents that were removed. In + production the equivalent is ``FLUSHDB`` on a dedicated + memory database, or letting TTLs and eviction expire entries + naturally. + """ + before = self.index_info()["num_docs"] + self.drop_index(delete_documents=True) + self.create_index() + return before + + +def _d(value) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8") + return value diff --git a/content/develop/use-cases/agent-memory/redis-py/seed_memory.py b/content/develop/use-cases/agent-memory/redis-py/seed_memory.py new file mode 100644 index 0000000000..35824af3ab --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/seed_memory.py @@ -0,0 +1,96 @@ +""" +Pre-seed the long-term memory store with sample memories. + +In a real deployment the memory store fills up organically as the +agent reasons over user turns: each turn produces zero or more +memories (preferences, facts, episodic summaries) that flow into the +store with deduplication. To make the demo immediately useful — so +the first recall query lands on relevant results instead of an empty +list — we seed a small set of canonical memories for a default user +at startup. + +The seed list mixes ``semantic`` memories (long-lived preferences +and facts) with ``episodic`` memories (snapshots of past sessions), +so the demo can show how the ``kind`` filter scopes recall. +""" + +from __future__ import annotations + +import numpy as np + +from embeddings import LocalEmbedder +from long_term_memory import LongTermMemory + + +SEED_MEMORIES: list[dict] = [ + { + "text": "The user prefers concise answers without filler phrases.", + "kind": "semantic", + }, + { + "text": "The user is a Python developer working on a logistics platform.", + "kind": "semantic", + }, + { + "text": "The user lives in Berlin and works in the Europe/Berlin time zone.", + "kind": "semantic", + }, + { + "text": + "The user dislikes dark mode and prefers a high-contrast light " + "theme in editors and dashboards.", + "kind": "semantic", + }, + { + "text": + "The user is allergic to peanuts; any restaurant suggestion must " + "avoid dishes that commonly contain them.", + "kind": "semantic", + }, + { + "text": + "Last Tuesday the user asked the agent to draft a postmortem for " + "the order-routing outage. The agent produced a five-section " + "draft and the user approved sections 1, 2, and 4 with minor " + "edits.", + "kind": "episodic", + }, + { + "text": + "In a previous session the user asked for help debugging a flaky " + "test in the inventory service. The fix turned out to be a race " + "condition in the warehouse webhook handler.", + "kind": "episodic", + }, + { + "text": + "Two weeks ago the user mentioned they were planning to migrate " + "the analytics warehouse from Snowflake to BigQuery in Q3.", + "kind": "episodic", + }, +] + + +def seed( + memory: LongTermMemory, + embedder: LocalEmbedder, + user: str = "default", + namespace: str = "default", + source_thread: str = "seed", +) -> int: + """Embed and write the seed memories. Returns the count actually written.""" + prompts = [m["text"] for m in SEED_MEMORIES] + vectors = embedder.encode_many(prompts) + written = 0 + for entry, vec in zip(SEED_MEMORIES, vectors): + result = memory.remember( + text=entry["text"], + embedding=np.asarray(vec, dtype=np.float32), + user=user, + namespace=namespace, + kind=entry["kind"], + source_thread=source_thread, + ) + if not result.deduped: + written += 1 + return written diff --git a/content/develop/use-cases/agent-memory/redis-py/session_store.py b/content/develop/use-cases/agent-memory/redis-py/session_store.py new file mode 100644 index 0000000000..852baf46ee --- /dev/null +++ b/content/develop/use-cases/agent-memory/redis-py/session_store.py @@ -0,0 +1,297 @@ +""" +Working-memory store for an agent session, backed by a Redis Hash. + +Each session is one Hash document at ``agent:session:{thread_id}``. +The hash holds the running scratchpad, the current goal, a rolling +window of recent turns (serialized as a JSON list to fit in one +field), and a few audit fields. One ``HGETALL`` returns the whole +session in a single round trip on every step of the agent loop. + +Every write refreshes the key's TTL with ``EXPIRE``, so idle sessions +fall off without a separate cleanup job and active sessions stay +alive as long as the agent keeps touching them. A separate +``LongTermMemory`` (see ``long_term_memory.py``) is what survives +beyond a session's TTL. + +The turn window is bounded to ``MAX_TURNS`` in application code; the +hash itself doesn't grow, so the working set per thread stays +constant regardless of how long the agent has been running. +""" + +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass, field +from typing import Optional + +import redis + + +# How many recent turns to keep inline on the session hash. Older +# turns flow through the event log (see ``event_log.py``) and the +# long-term memory store (see ``long_term_memory.py``). +MAX_TURNS = 20 + + +@dataclass +class SessionTurn: + role: str # "user" | "assistant" | "tool" + content: str + ts: float + + def to_dict(self) -> dict: + return {"role": self.role, "content": self.content, "ts": self.ts} + + @classmethod + def from_dict(cls, data: dict) -> "SessionTurn": + return cls( + role=data.get("role", ""), + content=data.get("content", ""), + ts=float(data.get("ts", 0.0)), + ) + + +@dataclass +class SessionState: + thread_id: str + user: str = "default" + agent: str = "default" + goal: str = "" + scratchpad: str = "" + turn_count: int = 0 + created_ts: float = 0.0 + last_active_ts: float = 0.0 + recent_turns: list[SessionTurn] = field(default_factory=list) + ttl_seconds: int = 0 + + def to_dict(self) -> dict: + return { + "thread_id": self.thread_id, + "user": self.user, + "agent": self.agent, + "goal": self.goal, + "scratchpad": self.scratchpad, + "turn_count": self.turn_count, + "created_ts": self.created_ts, + "last_active_ts": self.last_active_ts, + "recent_turns": [t.to_dict() for t in self.recent_turns], + "ttl_seconds": self.ttl_seconds, + } + + +class AgentSession: + """Load, write, and bound the working-memory hash for one thread.""" + + def __init__( + self, + redis_client: Optional[redis.Redis] = None, + key_prefix: str = "agent:session:", + default_ttl_seconds: int = 3600, + max_turns: int = MAX_TURNS, + ) -> None: + self.redis = redis_client or redis.Redis( + host="localhost", port=6379, decode_responses=False, + ) + self.key_prefix = key_prefix + self.default_ttl_seconds = default_ttl_seconds + self.max_turns = max_turns + + def session_key(self, thread_id: str) -> str: + return f"{self.key_prefix}{thread_id}" + + def new_thread_id(self) -> str: + return uuid.uuid4().hex[:12] + + def start( + self, + thread_id: str, + user: str = "default", + agent: str = "default", + goal: str = "", + ttl_seconds: int | None = None, + ) -> SessionState: + """Create a fresh working memory for a thread. + + Overwrites any existing session at the same key. The agent + normally calls this once per thread at the first turn and + relies on ``load`` / ``append_turn`` for subsequent steps. + """ + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl_seconds + now = time.time() + state = SessionState( + thread_id=thread_id, + user=user, + agent=agent, + goal=goal, + scratchpad="", + turn_count=0, + created_ts=now, + last_active_ts=now, + recent_turns=[], + ttl_seconds=ttl, + ) + self._write(state, ttl) + return state + + def load(self, thread_id: str) -> SessionState | None: + """Return the session state, or ``None`` if it has expired.""" + key = self.session_key(thread_id) + raw = self.redis.hgetall(key) + if not raw: + return None + data = {_d(k): _d(v) for k, v in raw.items()} + ttl = self.redis.ttl(key) + turns_blob = data.get("recent_turns", "[]") + try: + turns = [SessionTurn.from_dict(t) for t in json.loads(turns_blob)] + except json.JSONDecodeError: + turns = [] + return SessionState( + thread_id=thread_id, + user=data.get("user", "default"), + agent=data.get("agent", "default"), + goal=data.get("goal", ""), + scratchpad=data.get("scratchpad", ""), + turn_count=int(data.get("turn_count", "0") or 0), + created_ts=float(data.get("created_ts", "0") or 0), + last_active_ts=float(data.get("last_active_ts", "0") or 0), + recent_turns=turns, + ttl_seconds=int(ttl) if ttl and ttl > 0 else 0, + ) + + def append_turn( + self, + thread_id: str, + role: str, + content: str, + user: str | None = None, + agent: str | None = None, + ttl_seconds: int | None = None, + ) -> SessionState: + """Append a turn, bound the rolling window, refresh the TTL. + + ``user`` and ``agent`` are only consulted when the session + does not yet exist — they seed the auto-created session so + the working-memory hash matches the user the caller is + operating against. On an existing session they're ignored; + the original ``start`` values stand. + + Read-modify-write here is last-writer-wins on the turn list + if two concurrent turns reach the same thread; the demo never + triggers that race in practice (one browser, one turn at a + time) but a multi-worker agent that shares a thread id would + wrap this in ``WATCH`` / ``MULTI`` / ``EXEC`` or a Lua script + that does the append atomically server-side. + """ + state = self.load(thread_id) + if state is None: + state = self.start( + thread_id, + user=user if user is not None else "default", + agent=agent if agent is not None else "default", + ttl_seconds=ttl_seconds, + ) + state.recent_turns.append( + SessionTurn(role=role, content=content, ts=time.time()) + ) + if len(state.recent_turns) > self.max_turns: + state.recent_turns = state.recent_turns[-self.max_turns:] + state.turn_count += 1 + state.last_active_ts = time.time() + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl_seconds + state.ttl_seconds = ttl + self._write(state, ttl) + return state + + def set_scratchpad( + self, + thread_id: str, + text: str, + ttl_seconds: int | None = None, + ) -> SessionState | None: + """Update the agent's running scratchpad and refresh TTL.""" + state = self.load(thread_id) + if state is None: + return None + state.scratchpad = text + state.last_active_ts = time.time() + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl_seconds + state.ttl_seconds = ttl + self._write(state, ttl) + return state + + def set_goal( + self, + thread_id: str, + text: str, + user: str | None = None, + agent: str | None = None, + ttl_seconds: int | None = None, + ) -> SessionState: + """Update the goal field without touching turns or the scratchpad. + + Creates the session if it doesn't exist yet — setting a goal + on a fresh thread is a sensible first step in the agent loop, + so this method covers both the "rename the goal mid-session" + and the "start a thread with this goal" cases. + """ + state = self.load(thread_id) + if state is None: + return self.start( + thread_id, + user=user if user is not None else "default", + agent=agent if agent is not None else "default", + goal=text, + ttl_seconds=ttl_seconds, + ) + state.goal = text + state.last_active_ts = time.time() + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl_seconds + state.ttl_seconds = ttl + self._write(state, ttl) + return state + + def delete(self, thread_id: str) -> bool: + """Drop the session immediately. Returns ``True`` if it existed.""" + return bool(self.redis.delete(self.session_key(thread_id))) + + def list_threads(self, limit: int = 100) -> list[str]: + """Return active thread ids (for the demo's thread switcher).""" + out: list[str] = [] + for key in self.redis.scan_iter(match=f"{self.key_prefix}*", count=200): + thread_id = _d(key)[len(self.key_prefix):] + out.append(thread_id) + if len(out) >= limit: + break + return out + + def _write(self, state: SessionState, ttl: int) -> None: + key = self.session_key(state.thread_id) + mapping = { + "thread_id": state.thread_id, + "user": state.user, + "agent": state.agent, + "goal": state.goal, + "scratchpad": state.scratchpad, + "turn_count": str(state.turn_count), + "created_ts": repr(state.created_ts), + "last_active_ts": repr(state.last_active_ts), + "recent_turns": json.dumps([t.to_dict() for t in state.recent_turns]), + } + # MULTI/EXEC so HSET and EXPIRE either both apply or neither + # does. A connection drop between the two writes would + # otherwise leave the session without a TTL. + pipe = self.redis.pipeline(transaction=True) + pipe.hset(key, mapping=mapping) + pipe.expire(key, ttl) + pipe.execute() + + +def _d(value) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8") + return value diff --git a/content/develop/use-cases/agent-memory/ruby/.gitignore b/content/develop/use-cases/agent-memory/ruby/.gitignore new file mode 100644 index 0000000000..b4677582b2 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/.gitignore @@ -0,0 +1,4 @@ +.bundle/ +vendor/bundle/ +*.log +.DS_Store diff --git a/content/develop/use-cases/agent-memory/ruby/Gemfile b/content/develop/use-cases/agent-memory/ruby/Gemfile new file mode 100644 index 0000000000..1bcf2cafa2 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/Gemfile @@ -0,0 +1,16 @@ +# Redis agent-memory demo (Ruby). +# +# Pinned to Ruby 3.2+ baseline. The runtime gems pull in: +# * `redis-client` (the lower-level transport under `redis`) +# * `onnxruntime` (the ONNX backend `informers` runs the encoder on) +# * `tokenizers` (the Hugging Face fast tokenizer used by `informers`) +# `webrick` was extracted from the stdlib in Ruby 3.0; declaring it here +# means `bundle install` resolves it on every supported Ruby version. + +source 'https://rubygems.org' + +ruby '>= 3.2' + +gem 'redis', '~> 5.4' +gem 'informers', '~> 1.3' +gem 'webrick', '~> 1.8' diff --git a/content/develop/use-cases/agent-memory/ruby/_index.md b/content/develop/use-cases/agent-memory/ruby/_index.md new file mode 100644 index 0000000000..b3567e948c --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/_index.md @@ -0,0 +1,349 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Ruby with redis-rb, informers, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: redis-rb example (Ruby) +title: Redis agent memory with redis-rb +weight: 9 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Ruby with [`redis-rb`]({{< relref "/develop/clients/ruby" >}}) and the [`informers`](https://github.com/ankane/informers) gem, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with the standard-library [`WEBrick`](https://github.com/ruby/webrick) HTTP server so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [`informers`](https://github.com/ankane/informers), Ankane's Ruby port of Hugging Face transformers, running the ONNX-exported [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) encoder through the `onnxruntime` gem — the same 384-d model the [Python example]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) and every other port use. The Ruby ONNX path produces vectors that match the Python PyTorch reference closely enough that paraphrase distances land on the same numbers down to the fourth decimal place; a memory written by one demo can be recalled by the other against the same Redis instance, and the distance bands the Python walkthrough quotes carry over to this one without recalibration. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. ([`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) itself is one round trip; the helper also issues one [`TTL`]({{< relref "/commands/ttl" >}}) call per returned row to populate `ttl_seconds` for the admin panel.) +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encode_one(text)` to turn the incoming turn into a 384-element `Array`. `informers` runs L2 normalization inside the ONNX graph itself, so the returned vector is already a unit vector. +2. `session.append_turn(thread_id, role:, content:)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI/EXEC`]({{< relref "/commands/multi" >}}). The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall(query_vec, user:, namespace:, k: 5)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember(text:, embedding:, user:, namespace:, kind:)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `event_log.record(thread_id, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/ruby/lib/session_store.rb)): + +```ruby +require 'redis' +require_relative 'lib/session_store' + +client = Redis.new(host: 'localhost', port: 6379) +session = AgentMemory::AgentSession.new( + redis_client: client, + key_prefix: 'agent:session:', + default_ttl_seconds: 3600, # one hour + max_turns: 20 # rolling window per thread +) + +thread_id = session.new_thread_id +session.start( + thread_id, + user: 'alice', + agent: 'demo-agent', + goal: "Plan next week's meetings." +) +session.append_turn( + thread_id, + role: 'user', + content: 'Schedule a budget review with finance.' +) +state = session.load(thread_id) +puts [state.turn_count, state.recent_turns.length, state.ttl_seconds].inspect +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `append_turn`, `set_scratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/ruby/lib/long_term_memory.rb)): + +```ruby +require_relative 'lib/embeddings' +require_relative 'lib/long_term_memory' + +memory = AgentMemory::LongTermMemory.new( + redis_client: client, + index_name: 'agentmem:idx', + key_prefix: 'agent:mem:', + dedup_threshold: 0.20, # cosine distance — tight at write time + recall_threshold: 0.55 # looser at read time +) +embedder = AgentMemory::LocalEmbedder.new +memory.create_index # idempotent + +# Write a memory. The same KNN that powers recall also runs here at +# a tighter threshold so paraphrases of the same fact collapse. +vec = embedder.encode_one('The user prefers light mode in editors.') +result = memory.remember( + text: 'The user prefers light mode in editors.', + embedding: vec, + user: 'alice', + namespace: 'default', + kind: 'semantic', + source_thread: '9f3d2a4b8c61' +) +puts result.deduped, result.id, result.existing_distance + +# Recall against a later question. +q = embedder.encode_one('Which theme does this user like?') +hits = memory.recall(q, user: 'alice', namespace: 'default', k: 5) +hits.each { |h| puts format('%.3f [%s] %s', h.distance, h.kind, h.text) } +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (the demo packs them with Ruby's [`Array#pack('e*')`](https://docs.ruby-lang.org/en/master/Array.html#method-i-pack), which emits little-endian float32), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `AS` aliases so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write with `ttl_seconds: ...` on `remember`, or pass a different `ttl_by_kind: { ... }` to the `LongTermMemory` constructor — for example, to give semantic memories a six-month TTL while leaving episodic memories at seven days. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/ruby/lib/event_log.rb)): + +```ruby +require_relative 'lib/event_log' + +events = AgentMemory::AgentEventLog.new(redis_client: client, max_len: 1000) +events.record(thread_id, 'turn_appended:user', + 'Schedule a budget review with finance.') +events.record(thread_id, 'memory_written', + 'wrote 7c3f8a1b9e02 as semantic') + +events.recent(thread_id, count: 20).each do |e| + puts "#{e.action} #{e.detail}" +end +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarizers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession#append_turn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` array in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. + +* **Long-term dedup is not atomic.** `LongTermMemory#remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `current_thread_id` accessor on the `AgentMemoryDemo` object that `/new_thread` and `/reset` mutate; `handle_turn` reads it without coordination, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +* **The embedder is shared and the demo server is multi-threaded.** `informers` constructs the ONNX session once and reuses it across every request, and `WEBrick::HTTPServer` runs each request in its own Ruby thread. The MRI GIL serializes Ruby bytecode but does not protect calls into `onnxruntime`'s C++ session, which is undocumented for thread-safety; the demo gets away with the unsynchronized sharing because a single browser sends one turn at a time. A multi-user agent — or anything that lets the browser fire requests in parallel — should wrap the embedder call in a `Mutex` or hand each worker its own `LocalEmbedder` instance. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `seed_memory.rb` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/ruby/lib/seed_memory.rb)): + +```ruby +require_relative 'lib/seed_memory' + +memory = AgentMemory::LongTermMemory.new(redis_client: client) +embedder = AgentMemory::LocalEmbedder.new +memory.create_index +AgentMemory::SeedMemory.seed(memory, embedder, + user: 'default', namespace: 'default') +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`demo_server.rb` runs a [`WEBrick::HTTPServer`](https://github.com/ruby/webrick) on port 8091. WEBrick was extracted from the stdlib in Ruby 3.0, so the `Gemfile` declares it as a runtime gem; the [semantic-cache Ruby example]({{< relref "/develop/use-cases/semantic-cache/ruby" >}}) uses the same setup. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process. The "current thread" is a single accessor on the demo object that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the + example directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/ruby + ``` + +2. Install the dependencies. Ruby 3.2 or later is required. + + ```bash + bundle install + ``` + + The committed `Gemfile.lock` was regenerated on a newer Ruby and Bundler + than the 3.2 floor (the `BUNDLED WITH` line records the Bundler that wrote + it); if `bundle install` complains about a Bundler version mismatch on an + older 3.2 install, either upgrade Bundler (`gem install bundler`) or delete + `Gemfile.lock` to regenerate it under your local Bundler. + + The `onnxruntime` gem ships a pre-built shared library for macOS (arm64 / x86_64) + and Linux; on those platforms `bundle install` is the whole install. On other + platforms see the [`onnxruntime` README](https://github.com/ankane/onnxruntime-ruby#installation). + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo server. The first run downloads the + [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) + ONNX weights (around 80 MB) into the local Hugging Face cache: + + ```bash + bundle exec ruby demo_server.rb + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing recalls + above the threshold (every seed lands above 0.8), and the new memory is + written as `episodic` (or `semantic`, depending on the dropdown) under a + fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + `sentence-transformers/all-MiniLM-L6-v2` puts a faithful paraphrase in the + 0.15 – 0.50 cosine-distance range, a loose paraphrase or related topic in the + 0.50 – 0.80 range, and unrelated queries above 0.8 — which is what motivates + the 0.55 default recall threshold and the 0.20 default dedup threshold. A + stricter embedding model (or a domain-tuned one) would let you tighten both; + a noisier one would push them up. The right thresholds are always a function + of the model, the corpus, and how conservative the agent needs to be about + accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags: + +* `--host` — HTTP bind host (default `127.0.0.1`). +* `--port` — HTTP bind port (default `8091`). +* `--redis-host`, `--redis-port` — point the demo at a non-local Redis. +* `--mem-index-name`, `--mem-key-prefix`, `--session-key-prefix`, `--event-key-prefix` — change the default key namespacing. +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/ruby/demo_server.rb b/content/develop/use-cases/agent-memory/ruby/demo_server.rb new file mode 100644 index 0000000000..1c3bc818f1 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/demo_server.rb @@ -0,0 +1,546 @@ +#!/usr/bin/env ruby +# frozen_string_literal: true + +# Redis agent-memory demo server (Ruby). +# +# Run this file and visit http://localhost:8091 to drive a small +# agent-memory demo backed by Redis Hashes, JSON, Search, and +# Streams. The UI lets you: +# +# * Type a turn as the user (or paste a goal). The server appends +# the turn to the per-thread working-memory hash, embeds the +# turn, recalls the top-k semantically nearest long-term +# memories, optionally writes the turn back as a new memory with +# write-time deduplication, and appends an event to the per-thread +# stream. +# * Watch the three memory tiers update in place: working memory in +# one Hash, long-term memories as JSON documents under one index, +# and the event log in one Stream. +# * Switch user, namespace, kind, and recall threshold to see how +# scoping changes which memories the agent sees. +# * Inspect every long-term memory and drop individual memories to +# simulate eviction. +# +# The server holds a single `LocalEmbedder`, one `AgentSession`, one +# `LongTermMemory`, and one `AgentEventLog` for the lifetime of the +# process. The first run downloads the ONNX-exported embedding model +# into the local Hugging Face cache; everything after is local. + +require 'json' +require 'optparse' +require 'uri' +require 'webrick' + +require 'redis' + +$LOAD_PATH.unshift(File.expand_path('lib', __dir__)) +require 'embeddings' +require 'event_log' +require 'long_term_memory' +require 'seed_memory' +require 'session_store' + +module AgentMemory + # AgentMemoryDemo owns the embedder, session store, long-term + # memory, and event log for the lifetime of the process. The + # handlers thread requests through `handle_turn` and the + # seed / new-thread endpoints reuse it so there is only one + # description of the demo lifecycle. + class AgentMemoryDemo + attr_reader :session_store, :memory, :event_log, :embedder + attr_accessor :current_thread_id + + def initialize(session_store:, memory:, event_log:, embedder:, + default_user: 'default', default_namespace: 'default') + @session_store = session_store + @memory = memory + @event_log = event_log + @embedder = embedder + @default_user = default_user + @default_namespace = default_namespace + @current_thread_id = session_store.new_thread_id + end + + # `seed` / `new_thread` / `handle_turn` all touch + # `current_thread_id` without coordination — see the walkthrough's + # "Concurrency caveats" section. The demo is single-user in + # practice, so the race never triggers; a multi-user agent would + # carry the thread id on each request instead of holding it as + # shared server state. + + def seed(user:, namespace:) + @memory.clear + @session_store.delete(@current_thread_id) + @event_log.clear(@current_thread_id) + written = SeedMemory.seed( + @memory, @embedder, + user: user, namespace: namespace, source_thread: 'seed' + ) + @current_thread_id = @session_store.new_thread_id + written + end + + def new_thread(user:, namespace:) + @event_log.clear(@current_thread_id) + @current_thread_id = @session_store.new_thread_id + @session_store.start( + @current_thread_id, user: user, agent: 'demo-agent', goal: '' + ) + @event_log.record( + @current_thread_id, 'thread_started', + "user=#{user} namespace=#{namespace}" + ) + @current_thread_id + end + + # One pass through the agent loop: append, recall, remember, log. + # + # The order matters. We embed once and reuse the vector for both + # the recall and (if asked) the remember step — no point encoding + # the same text twice. Recall runs *before* the remember write so + # the agent doesn't see its own just-written turn echoed back as + # a recalled memory. + def handle_turn(text:, user:, namespace:, kind:, role:, + threshold:, action:) + thread_id = @current_thread_id + + t0 = monotonic_ms + vec = @embedder.encode_one(text) + embed_ms = monotonic_ms - t0 + + if action == 'goal' + @session_store.set_goal( + thread_id, text, user: user, agent: 'demo-agent' + ) + session_action = 'goal_set' + else + @session_store.append_turn( + thread_id, role: role, content: text, + user: user, agent: 'demo-agent' + ) + session_action = "turn_appended:#{role}" + end + + t1 = monotonic_ms + recalled = @memory.recall( + vec, + user: user, namespace: namespace, k: 5, + distance_threshold: threshold + ) + recall_ms = monotonic_ms - t1 + + write_skipped = (kind == 'skip' || action == 'goal') + write_result = nil + write_ms = 0.0 + unless write_skipped + t2 = monotonic_ms + write_result = @memory.remember( + text: text, embedding: vec, + user: user, namespace: namespace, + kind: kind, source_thread: thread_id + ) + write_ms = monotonic_ms - t2 + end + + if write_result + event_detail = write_result.deduped \ + ? "deduped onto #{write_result.id}" \ + : "wrote #{write_result.id} as #{kind}" + @event_log.record(thread_id, session_action, event_detail) + else + @event_log.record(thread_id, session_action, '') + end + + { + thread_id: thread_id, + write_skipped: write_skipped, + memory_id: write_result&.id, + deduped: write_result ? write_result.deduped : false, + existing_distance: write_result&.existing_distance, + kind: write_skipped ? nil : kind, + recalled: recalled.map(&:to_h), + embed_ms: embed_ms, + recall_ms: recall_ms, + write_ms: write_ms + } + end + + private + + def monotonic_ms + Process.clock_gettime(Process::CLOCK_MONOTONIC) * 1000.0 + end + end + + # ---------------------------------------------------------------- + # HTTP plumbing + # ---------------------------------------------------------------- + + # Cap POST bodies so a runaway client (or, more realistically, a + # `curl --data-binary @big-file` by mistake) cannot accumulate + # unbounded memory before the handler runs. WEBrick has no + # built-in cap, so each POST handler calls `body_too_large?` + # before touching `req.body` and returns 413 if the request's + # `Content-Length` exceeds the limit. The demo's largest + # legitimate body is a few hundred bytes of form-encoded query + # fields; 1 MiB matches the Python / Node / Go / Java / PHP caps. + MAX_BODY_BYTES = 1 * 1024 * 1024 + + def self.body_too_large?(req) + req['Content-Length'].to_i > MAX_BODY_BYTES + end + + # Sanitize the threshold parameter from the form body. + # `Float()` happily handles "nan" → NaN and "inf" → +Inf. Either + # would silently turn recall into "every memory" or "nothing". + # Clamp to the meaningful cosine-distance range so a malformed POST + # cannot override the threshold semantics. Falls back to the + # configured `--recall-threshold` rather than a hard-coded constant + # so the server-wide flag actually drives the default. + def self.clamp_threshold(raw, fallback) + return fallback if raw.nil? || raw.empty? + parsed = Float(raw, exception: false) + return fallback if parsed.nil? || !parsed.finite? + [[parsed, 0.0].max, 2.0].min + end + + # Build the response shape /state serves. The Python / Node / + # other-language siblings serve the same shape so the shared HTML + # works without modification. + def self.build_state(deps, user:, namespace:) + memory = deps.fetch(:memory) + session_store = deps.fetch(:session_store) + event_log = deps.fetch(:event_log) + embedder = deps.fetch(:embedder) + demo = deps.fetch(:demo) + stack_label = deps.fetch(:stack_label) + + info = memory.index_info + thread_id = demo.current_thread_id + session = session_store.load(thread_id) + memories = memory.list_memories( + user: user, namespace: namespace, limit: 200 + ) + events = event_log.recent(thread_id, count: 20) + { + index: { + num_docs: info[:num_docs], + indexing_failures: info[:indexing_failures], + index_name: memory.index_name, + model: embedder.model_name, + session_ttl_seconds: session_store.default_ttl_seconds, + dedup_threshold: memory.dedup_threshold, + default_recall_threshold: memory.recall_threshold, + stack_label: stack_label + }, + thread_id: thread_id, + session: session&.to_h, + memories: memories.map(&:to_h), + events: events.map(&:to_h), + # `recalled` is populated by /turn; on plain /state reads the + # UI keeps showing the last turn's result, which is the useful + # behavior for an "agent" panel. + recalled: [] + } + end + + # Parse a URL-encoded form body into a plain Hash. + def self.parse_form(body) + URI.decode_www_form(body.to_s).to_h + rescue ArgumentError + {} + end + + # Wrap every handler so an uncaught exception lands as a JSON 500 + # rather than letting WEBrick render a plain-text stack trace. The + # demo's JS client always calls `await res.json()`, so a non-JSON + # body would surface as an opaque parse error. + def self.with_json_errors(response) + yield + rescue StandardError => e + warn("[demo] handler error: #{e.class}: #{e.message}") + warn(e.backtrace.first(8).join("\n")) + response.status = 500 + response['Content-Type'] = 'application/json' + response.body = JSON.generate(error: e.message, type: e.class.name) + end + + def self.send_json(response, payload, status: 200) + response.status = status + response['Content-Type'] = 'application/json' + response.body = JSON.generate(payload) + end + + def self.send_html(response, html, status: 200) + response.status = status + response['Content-Type'] = 'text/html; charset=utf-8' + response.body = html + end + + def self.empty_or(value, default) + value.nil? || value.empty? ? default : value + end + + # ---------------------------------------------------------------- + # Handlers + # ---------------------------------------------------------------- + + def self.install_handlers(server, deps) + memory = deps.fetch(:memory) + demo = deps.fetch(:demo) + html_page = deps.fetch(:html_page) + recall_threshold = deps.fetch(:recall_threshold) + + server.mount_proc '/' do |req, res| + with_json_errors(res) do + if req.path != '/' && req.path != '/index.html' + send_json(res, { error: 'not found' }, status: 404) + next + end + if req.request_method != 'GET' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + send_html(res, html_page) + end + end + + server.mount_proc '/state' do |req, res| + with_json_errors(res) do + if req.request_method != 'GET' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + query = req.query + send_json(res, build_state( + deps, + user: empty_or(query['user'], 'default'), + namespace: empty_or(query['namespace'], 'default') + )) + end + end + + server.mount_proc '/turn' do |req, res| + with_json_errors(res) do + if req.request_method != 'POST' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + if body_too_large?(req) + send_json(res, { + error: "request body exceeds #{MAX_BODY_BYTES} bytes" + }, status: 413) + next + end + params = parse_form(req.body) + text = (params['text'] || '').strip + if text.empty? + send_json(res, { error: 'text is required' }, status: 400) + next + end + payload = demo.handle_turn( + text: text, + user: empty_or(params['user'], 'default'), + namespace: empty_or(params['namespace'], 'default'), + kind: empty_or(params['kind'], 'episodic'), + role: empty_or(params['role'], 'user'), + threshold: clamp_threshold(params['threshold'], recall_threshold), + action: empty_or(params['action'], 'turn') + ) + send_json(res, payload) + end + end + + server.mount_proc '/new_thread' do |req, res| + with_json_errors(res) do + if req.request_method != 'POST' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + if body_too_large?(req) + send_json(res, { + error: "request body exceeds #{MAX_BODY_BYTES} bytes" + }, status: 413) + next + end + params = parse_form(req.body) + thread_id = demo.new_thread( + user: empty_or(params['user'], 'default'), + namespace: empty_or(params['namespace'], 'default') + ) + send_json(res, { thread_id: thread_id }) + end + end + + server.mount_proc '/reset' do |req, res| + with_json_errors(res) do + if req.request_method != 'POST' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + if body_too_large?(req) + send_json(res, { + error: "request body exceeds #{MAX_BODY_BYTES} bytes" + }, status: 413) + next + end + params = parse_form(req.body) + seeded = demo.seed( + user: empty_or(params['user'], 'default'), + namespace: empty_or(params['namespace'], 'default') + ) + send_json(res, { seeded: seeded }) + end + end + + server.mount_proc '/drop_memory' do |req, res| + with_json_errors(res) do + if req.request_method != 'POST' + send_json(res, { error: 'method not allowed' }, status: 405) + next + end + if body_too_large?(req) + send_json(res, { + error: "request body exceeds #{MAX_BODY_BYTES} bytes" + }, status: 413) + next + end + params = parse_form(req.body) + memory_id = (params['memory_id'] || '').strip + if memory_id.empty? + send_json(res, { error: 'memory_id is required' }, status: 400) + next + end + deleted = memory.delete_memory(memory_id) + send_json(res, { deleted: deleted, memory_id: memory_id }) + end + end + end + + # ---------------------------------------------------------------- + # Main + # ---------------------------------------------------------------- + + def self.parse_flags(argv) + options = { + host: '127.0.0.1', + port: 8091, + redis_host: 'localhost', + redis_port: 6379, + mem_index_name: 'agentmem:idx', + mem_key_prefix: 'agent:mem:', + session_key_prefix: 'agent:session:', + event_key_prefix: 'agent:events:', + session_ttl_seconds: 3600, + dedup_threshold: DEFAULT_DEDUP_THRESHOLD, + recall_threshold: DEFAULT_RECALL_THRESHOLD, + no_reset: false + } + OptionParser.new do |opts| + opts.banner = 'Usage: ruby demo_server.rb [options]' + opts.on('--host HOST', 'Interface to bind to') { |v| options[:host] = v } + opts.on('--port PORT', Integer, 'HTTP port') { |v| options[:port] = v } + opts.on('--redis-host HOST', 'Redis host') { |v| options[:redis_host] = v } + opts.on('--redis-port PORT', Integer, 'Redis port') { |v| options[:redis_port] = v } + opts.on('--mem-index-name NAME', 'Memory index name') { |v| options[:mem_index_name] = v } + opts.on('--mem-key-prefix PREFIX', 'JSON memory key prefix') { |v| options[:mem_key_prefix] = v } + opts.on('--session-key-prefix PREFIX', 'Session hash key prefix') { |v| options[:session_key_prefix] = v } + opts.on('--event-key-prefix PREFIX', 'Event stream key prefix') { |v| options[:event_key_prefix] = v } + opts.on('--session-ttl-seconds N', Integer, 'Working-memory TTL') { |v| options[:session_ttl_seconds] = v } + opts.on('--dedup-threshold F', Float, 'Dedup cosine-distance cutoff') { |v| options[:dedup_threshold] = v } + opts.on('--recall-threshold F', Float, 'Recall cosine-distance cutoff') { |v| options[:recall_threshold] = v } + opts.on('--no-reset', 'Keep existing memories on startup') { options[:no_reset] = true } + end.parse!(argv) + options + end + + def self.run!(argv = ARGV) + args = parse_flags(argv) + + client = Redis.new(host: args[:redis_host], port: args[:redis_port]) + begin + client.ping + rescue StandardError => e + warn("Error: cannot reach Redis at #{args[:redis_host]}:#{args[:redis_port]}") + warn(" (#{e.message})") + exit 1 + end + + session_store = AgentSession.new( + redis_client: client, + key_prefix: args[:session_key_prefix], + default_ttl_seconds: args[:session_ttl_seconds] + ) + memory = LongTermMemory.new( + redis_client: client, + index_name: args[:mem_index_name], + key_prefix: args[:mem_key_prefix], + dedup_threshold: args[:dedup_threshold], + recall_threshold: args[:recall_threshold] + ) + memory.create_index + event_log = AgentEventLog.new( + redis_client: client, + key_prefix: args[:event_key_prefix] + ) + + puts 'Loading embedding model (first run downloads the ONNX weights)...' + embedder = LocalEmbedder.new + + demo = AgentMemoryDemo.new( + session_store: session_store, memory: memory, + event_log: event_log, embedder: embedder + ) + + unless args[:no_reset] + puts "Dropping any existing memories under '#{args[:mem_key_prefix]}*' " \ + "and re-seeding from the sample memory list (pass --no-reset to keep)." + seeded = demo.seed(user: 'default', namespace: 'default') + puts "Seeded #{seeded} memories." + end + + # Load the HTML once and replace the template tokens with the + # configured key prefixes and index name so the lede shows the + # actual values in use rather than the default copies. + raw_html = File.read(File.expand_path('index.html', __dir__)) + html_page = raw_html + .gsub('__SESSION_PREFIX__', args[:session_key_prefix]) + .gsub('__MEM_PREFIX__', args[:mem_key_prefix]) + .gsub('__MEM_INDEX__', args[:mem_index_name]) + .gsub('__EVENT_PREFIX__', args[:event_key_prefix]) + + stack_label = 'redis-rb + informers + WEBrick' + + # WEBrick: turn down access logging so the console isn't a flood + # of GET / lines while the demo is running. + server = WEBrick::HTTPServer.new( + BindAddress: args[:host], + Port: args[:port], + Logger: WEBrick::Log.new($stderr, WEBrick::Log::WARN), + AccessLog: [] + ) + + install_handlers(server, { + session_store: session_store, + memory: memory, + event_log: event_log, + embedder: embedder, + demo: demo, + html_page: html_page, + stack_label: stack_label, + recall_threshold: args[:recall_threshold] + }) + + trap('INT') { server.shutdown } + trap('TERM') { server.shutdown } + + puts "Redis agent memory demo listening on http://#{args[:host]}:#{args[:port]}" + puts "Using Redis at #{args[:redis_host]}:#{args[:redis_port]} " \ + "with memory index '#{args[:mem_index_name]}'" + server.start + ensure + client&.close + end +end + +AgentMemory.run! if $PROGRAM_NAME == __FILE__ diff --git a/content/develop/use-cases/agent-memory/ruby/index.html b/content/develop/use-cases/agent-memory/ruby/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/ruby/lib/embeddings.rb b/content/develop/use-cases/agent-memory/ruby/lib/embeddings.rb new file mode 100644 index 0000000000..74d536c741 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/lib/embeddings.rb @@ -0,0 +1,90 @@ +# Local text-embedding helper backed by the `informers` gem. +# +# `informers` is a Ruby port of Hugging Face transformers that runs +# the ONNX-exported `sentence-transformers/all-MiniLM-L6-v2` encoder +# through the `onnxruntime` gem — same 384-d model the Python, Node.js, +# .NET, Rust, Go, Jedis, Lettuce, and PHP siblings use. Vectors are +# L2-normalized so a Redis Search index declared with +# `DISTANCE_METRIC COSINE` returns scores that are directly +# comparable across entries. +# +# Embeddings are numerically very close to the PyTorch reference (the +# Ruby ONNX path matches the Node.js Xenova ONNX path to ~0.01 in +# cosine distance); the model is downloaded into the local Hugging +# Face cache on the first call and every later call runs offline. + +require 'informers' + +module AgentMemory + # `informers` exposes a synchronous API, so the constructor does the + # model load directly. We probe the output shape once and record the + # dimension on the instance so callers can compare against the + # index's expected vector dimension before doing any inserts. + # LongTermMemory also checks length on every remember / recall, so a + # model swap that produces wrong-dim vectors fails at the call site + # with a clear error. + class LocalEmbedder + DEFAULT_MODEL = 'sentence-transformers/all-MiniLM-L6-v2' + + attr_reader :model_name, :dim + + def initialize(model_name: DEFAULT_MODEL) + @model_name = model_name + # `Informers.pipeline("embedding", ...)` returns a configured + # EmbeddingPipeline. The `call(text, pooling:, normalize:)` API + # mirrors @xenova/transformers' feature-extraction pipeline so + # the Node.js sibling's code looks structurally identical. + @model = Informers.pipeline('embedding', model_name) + probe = encode_one('dimension probe') + @dim = probe.length + end + + # Encode a single string. Returns a 384-element Array of Float + # (Ruby doubles; the values themselves are float32 round-trips + # from the ONNX session so the precision is the model's). + # + # We pass `normalize: true` to informers, which L2-normalizes in + # the ONNX graph itself — the result is already a unit vector, + # so a second pass through `l2_normalize` would be redundant. + def encode_one(text) + vec = @model.(text, pooling: 'mean', normalize: true) + # The pipeline returns a flat Array when the input is a single + # string and an Array when the input is an Array; we + # special-case below in encode_many. Defensive flatten in case + # a future release unifies the shapes. + vec = vec.first if vec.first.is_a?(Array) + validate_dim!(vec) + vec + end + + # Encode several strings in one pipeline call. Returns an + # Array of float values, one row per input string. Raises + # if the model produces a different number of rows than inputs — + # that would silently misalign the seed phase otherwise. + def encode_many(texts) + rows = @model.(texts, pooling: 'mean', normalize: true) + if rows.length != texts.length + raise "informers returned #{rows.length} vectors for #{texts.length} inputs" + end + rows.each { |row| validate_dim!(row) } + rows + end + + # Pack a Ruby Array of Float into the bytes Redis Search expects: + # raw little-endian float32, no header, exactly `dim * 4` bytes. + # Ruby's `Array#pack` directive `'e'` is little-endian single + # precision float; `'e*'` packs every element. This is the + # encoding RediSearch reads for a `VECTOR ... TYPE FLOAT32` field. + def self.to_bytes(vector) + vector.pack('e*') + end + + private + + def validate_dim!(vec) + return if @dim.nil? + return if vec.length == @dim + raise "encoder produced #{vec.length}-d vector; expected #{@dim}-d" + end + end +end diff --git a/content/develop/use-cases/agent-memory/ruby/lib/event_log.rb b/content/develop/use-cases/agent-memory/ruby/lib/event_log.rb new file mode 100644 index 0000000000..bc6a14bbf8 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/lib/event_log.rb @@ -0,0 +1,97 @@ +# Append-only event log for an agent thread, backed by a Redis Stream. +# +# Each thread gets a stream at `agent:events:{thread_id}`. Every +# action the agent takes (a user turn arriving, a memory being +# recalled, a memory being written, a tool being called) is one +# `XADD` to that stream. Replay with `XREVRANGE` for the most recent +# N events; bound retention with `XADD MAXLEN ~` so the log stays +# cheap regardless of how long the thread has been running. +# +# The stream is independent of the session Hash (`session_store.rb`) +# and the long-term memory store (`long_term_memory.rb`): it answers +# the "what just happened" question without competing with either of +# those for indexing or memory budget. Consumer groups (not used in +# this demo) would let downstream workers — summarizers, +# consolidators, audit pipelines — replay the log without losing +# position. + +require 'redis' + +module AgentMemory + # Approximate cap on stream length. `MAXLEN ~` lets Redis trim in + # whole-node units instead of exactly-N units, which is much + # cheaper at the cost of overshooting the bound by up to a node's + # worth. + DEFAULT_MAXLEN = 1000 + + AgentEvent = Struct.new( + :event_id, :thread_id, :action, :detail, :ts, + keyword_init: true + ) do + def to_h + { + event_id: event_id, + thread_id: thread_id, + action: action, + detail: detail, + ts: ts + } + end + end + + class AgentEventLog + attr_reader :redis, :key_prefix, :max_len + + def initialize(redis_client: nil, + key_prefix: 'agent:events:', + max_len: DEFAULT_MAXLEN) + @redis = redis_client || Redis.new(host: 'localhost', port: 6379) + @key_prefix = key_prefix + @max_len = max_len + end + + def stream_key(thread_id) + "#{@key_prefix}#{thread_id}" + end + + # Append one event and return its stream id. + # + # `MAXLEN ~ N` keeps the stream bounded with near-zero overhead; + # an exact bound (`MAXLEN N` without the tilde) forces a scan + # and is rarely worth the cost. + def record(thread_id, action, detail = '') + @redis.xadd( + stream_key(thread_id), + { + 'action' => action, + 'detail' => detail, + 'ts' => Time.now.to_f.to_s + }, + approximate: true, + maxlen: @max_len + ) + end + + # Return the most recent events, newest first. + def recent(thread_id, count: 20) + rows = @redis.xrevrange(stream_key(thread_id), count: count) + rows.map do |entry_id, fields| + AgentEvent.new( + event_id: entry_id, + thread_id: thread_id, + action: fields['action'] || '', + detail: fields['detail'] || '', + ts: (fields['ts'] || '0').to_f + ) + end + end + + def length(thread_id) + @redis.xlen(stream_key(thread_id)) + end + + def clear(thread_id) + @redis.del(stream_key(thread_id)).positive? + end + end +end diff --git a/content/develop/use-cases/agent-memory/ruby/lib/long_term_memory.rb b/content/develop/use-cases/agent-memory/ruby/lib/long_term_memory.rb new file mode 100644 index 0000000000..3c535881e5 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/lib/long_term_memory.rb @@ -0,0 +1,450 @@ +# Long-term memory store for an agent, backed by Redis JSON and Search. +# +# Each memory lives as one JSON document at `agent:mem:`. The +# document holds the memory text, its embedding vector, and a small +# metadata block — user, namespace, kind, source thread, timestamps — +# that lets the recall query scope results without falling back to +# application-side filtering. +# +# A single Redis Search index covers the embedding plus every metadata +# field, so one `FT.SEARCH` call performs approximate-nearest- +# neighbor over the in-scope subset and returns the top-k memories +# ranked by cosine distance. The same KNN check runs at *write* time +# to deduplicate near-identical memories before they enter the store, +# which keeps the index from filling with paraphrases of the same fact +# as the agent reasons over similar topics across sessions. +# +# Memories carry one of two kinds: +# +# * `episodic` — "what happened" snapshots from a specific thread, +# written with a medium TTL so old session detail decays naturally. +# * `semantic` — distilled facts and preferences the agent should +# carry forward indefinitely. Written with no TTL by default. +# +# The split is enforced as a TAG on the index, so the recall query +# can ask for one kind or both with a filter — no separate keyspaces. + +require 'json' +require 'redis' +require 'securerandom' +require 'set' + +require_relative 'embeddings' + +module AgentMemory + VECTOR_DIM_DEFAULT = 384 + + # How close (cosine distance) a candidate must be to an existing + # memory to count as a duplicate at write time. Smaller = stricter. + # 0.20 is calibrated to the `all-MiniLM-L6-v2` embedding model used + # in the demo, where a paraphrase of an existing memory lands in + # the 0.10 – 0.20 range and a distinct memory lands above 0.50. + DEFAULT_DEDUP_THRESHOLD = 0.20 + + # How close (cosine distance) a candidate must be to count as a + # relevant recall result. Larger than the dedup threshold so the + # agent gets a wider net at read time than at write time. + DEFAULT_RECALL_THRESHOLD = 0.55 + + # TTL tiers, in seconds. `nil` means "no TTL" — the memory persists + # until explicitly deleted or evicted under memory pressure. + TTL_BY_KIND = { + 'episodic' => 7 * 24 * 3600, + 'semantic' => nil + }.freeze + + # A single memory document returned from the store. + MemoryRecord = Struct.new( + :id, :user, :namespace, :kind, :source_thread, + :text, :created_ts, :hit_count, + :distance, :ttl_seconds, keyword_init: true + ) do + def to_h + { + id: id, + user: user, + namespace: namespace, + kind: kind, + source_thread: source_thread, + text: text, + created_ts: created_ts, + hit_count: hit_count, + distance: distance.nil? ? nil : distance.round(4), + ttl_seconds: ttl_seconds + } + end + end + + # Outcome of a `remember` call. `deduped` is `true` when the write + # skipped because a similar memory already existed; `id` is then + # the existing memory's id. `existing_distance` is the cosine + # distance to the nearest memory regardless of which branch was + # taken — useful for tracing. + WriteResult = Struct.new( + :id, :deduped, :existing_distance, keyword_init: true + ) do + def to_h + { + id: id, + deduped: deduped, + existing_distance: + existing_distance.nil? ? nil : existing_distance.round(4) + } + end + end + + class LongTermMemory + # Characters Redis Search treats as syntax inside a TAG value; any + # of them in a user-supplied filter must be backslash-escaped or + # the surrounding `{...}` block won't parse correctly. + TAG_SPECIAL = Set.new("\\,.<>{}[]\"':;!@#$%^&*()-+=~| ".chars).freeze + + attr_reader :redis, :index_name, :key_prefix, :vector_dim, + :dedup_threshold, :recall_threshold, :ttl_by_kind + + def initialize(redis_client: nil, + index_name: 'agentmem:idx', + key_prefix: 'agent:mem:', + vector_dim: VECTOR_DIM_DEFAULT, + dedup_threshold: DEFAULT_DEDUP_THRESHOLD, + recall_threshold: DEFAULT_RECALL_THRESHOLD, + ttl_by_kind: nil) + @redis = redis_client || Redis.new(host: 'localhost', port: 6379) + @index_name = index_name + @key_prefix = key_prefix + @vector_dim = vector_dim + @dedup_threshold = dedup_threshold + @recall_threshold = recall_threshold + @ttl_by_kind = ttl_by_kind || TTL_BY_KIND.dup + end + + # ---------------------------------------------------------------- + # Keys and index + # ---------------------------------------------------------------- + + def memory_key(memory_id) + "#{@key_prefix}#{memory_id}" + end + + # Create the Redis Search index if it doesn't already exist. + # + # The index is declared on the JSON document type, with a + # `$.embedding` path holding the vector and TAG fields for + # `user`, `namespace`, `kind`, and `source_thread`. One + # `FT.SEARCH` can therefore pre-filter by any combination of + # those tags and KNN-rank the matching memories in one pass. + def create_index + args = [ + 'FT.CREATE', @index_name, + 'ON', 'JSON', + 'PREFIX', '1', @key_prefix, + 'SCHEMA', + '$.text', 'AS', 'text', 'TEXT', + '$.user', 'AS', 'user', 'TAG', + '$.namespace', 'AS', 'namespace', 'TAG', + '$.kind', 'AS', 'kind', 'TAG', + '$.source_thread', 'AS', 'source_thread', 'TAG', + '$.created_ts', 'AS', 'created_ts', 'NUMERIC', 'SORTABLE', + '$.hit_count', 'AS', 'hit_count', 'NUMERIC', 'SORTABLE', + '$.embedding', 'AS', 'embedding', 'VECTOR', 'HNSW', '6', + 'TYPE', 'FLOAT32', + 'DIM', @vector_dim.to_s, + 'DISTANCE_METRIC', 'COSINE' + ] + @redis.call(*args) + rescue Redis::CommandError => e + raise unless e.message.include?('Index already exists') + end + + def drop_index(delete_documents: false) + args = ['FT.DROPINDEX', @index_name] + args << 'DD' if delete_documents + @redis.call(*args) + rescue Redis::CommandError => e + message = e.message.downcase + raise unless message.include?('no such index') || + message.include?('unknown index name') + end + + # ---------------------------------------------------------------- + # Write + # ---------------------------------------------------------------- + + # Write a new memory, deduplicating against existing entries. + # + # Runs one in-scope KNN(1) against the index first. If the + # nearest existing memory is within `dedup_threshold`, the new + # memory is skipped (its content is already represented) and the + # existing memory's `hit_count` is bumped. Otherwise a fresh JSON + # document is written under a new id with a TTL derived from the + # memory's `kind`. + # + # The KNN-then-write sequence is not atomic; two workers that + # remember the same fact at the same time can both miss each + # other's in-flight write and insert duplicate memories. See the + # walkthrough's "Concurrency caveats" section for the production + # fix (periodic background consolidator that merges + # near-duplicates). + def remember(text:, embedding:, + user: 'default', namespace: 'default', + kind: 'episodic', source_thread: '', + ttl_seconds: :default) + validate_dim!(embedding, 'embedding') + + nearest = nearest_neighbors( + embedding, user: user, namespace: namespace, kind: kind, k: 1 + ) + nearest_distance = nearest.first&.distance + if !nearest.empty? && + !nearest.first.distance.nil? && + nearest.first.distance <= @dedup_threshold + # Duplicate. Bump the hit count on the existing memory so the + # admin UI can show how often it's been re-derived. + bump_hit_count(nearest.first.id) + return WriteResult.new( + id: nearest.first.id, + deduped: true, + existing_distance: nearest_distance + ) + end + + memory_id = SecureRandom.hex(6) + key = memory_key(memory_id) + now = Time.now.to_f + doc = { + 'id' => memory_id, + 'user' => user, + 'namespace' => namespace, + 'kind' => kind, + 'source_thread' => source_thread, + 'text' => text, + 'embedding' => embedding, + 'created_ts' => now, + 'hit_count' => 0 + } + ttl = resolve_ttl(kind, ttl_seconds) + + # MULTI/EXEC so the JSON document and its TTL apply together. A + # connection drop between the JSON.SET and EXPIRE would + # otherwise leave the memory without an expiry. + @redis.multi do |m| + m.call('JSON.SET', key, '$', JSON.generate(doc)) + m.expire(key, ttl) unless ttl.nil? + end + WriteResult.new( + id: memory_id, + deduped: false, + existing_distance: nearest_distance + ) + end + + # ---------------------------------------------------------------- + # Recall + # ---------------------------------------------------------------- + + # Return the top-k in-scope memories ranked by similarity. + # + # Memories beyond `distance_threshold` (or the instance default) + # are dropped — the index always returns *something* for KNN, so + # a recall result on an unrelated query would otherwise be a + # confidently-wrong false positive. + def recall(query_embedding, + user: 'default', namespace: 'default', + kind: nil, k: 5, distance_threshold: nil) + threshold = distance_threshold || @recall_threshold + candidates = nearest_neighbors( + query_embedding, + user: user, namespace: namespace, kind: kind, k: k + ) + candidates.select do |c| + !c.distance.nil? && c.distance <= threshold + end + end + + # ---------------------------------------------------------------- + # Admin / inspection + # ---------------------------------------------------------------- + + # Subset of `FT.INFO` useful for the demo UI. + def index_info + raw = @redis.call('FT.INFO', @index_name) + info = ft_info_to_hash(raw) + { + num_docs: (info['num_docs'] || 0).to_i, + indexing_failures: (info['hash_indexing_failures'] || 0).to_i + } + rescue Redis::CommandError + { num_docs: 0, indexing_failures: 0 } + end + + # Return memories matching the filters, newest first. The + # admin-panel listing skips the distance column (no KNN ran). + def list_memories(user: nil, namespace: nil, kind: nil, limit: 100) + filter_clause = self.class.build_filter_clause( + user: user, namespace: namespace, kind: kind + ) + args = [ + 'FT.SEARCH', @index_name, filter_clause, + 'RETURN', '7', + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', + 'SORTBY', 'created_ts', 'DESC', + 'LIMIT', '0', limit.to_s, + 'DIALECT', '2' + ] + begin + result = @redis.call(*args) + rescue Redis::CommandError + return [] + end + parse_search_result(result).map do |doc| + raw_key = doc[:_key] + memory_id = strip_prefix(raw_key) + ttl = @redis.ttl(memory_key(memory_id)) + MemoryRecord.new( + id: memory_id, + user: doc[:user] || '', + namespace: doc[:namespace] || '', + kind: doc[:kind] || '', + source_thread: doc[:source_thread] || '', + text: doc[:text] || '', + created_ts: (doc[:created_ts] || '0').to_f, + hit_count: (doc[:hit_count] || '0').to_i, + distance: nil, + ttl_seconds: ttl && ttl.positive? ? ttl.to_i : nil + ) + end + end + + def delete_memory(memory_id) + @redis.del(memory_key(memory_id)).positive? + end + + # Drop the index and every memory document. Returns the count of + # documents that were removed. In production the equivalent is + # `FLUSHDB` on a dedicated memory database, or letting TTLs and + # eviction expire entries naturally. + def clear + before = index_info[:num_docs] + drop_index(delete_documents: true) + create_index + before + end + + # ---------------------------------------------------------------- + # Filter clause + # ---------------------------------------------------------------- + + def self.escape_tag_value(value) + value.each_char.map { |c| TAG_SPECIAL.include?(c) ? "\\#{c}" : c }.join + end + + def self.build_filter_clause(user:, namespace:, kind:) + # `nil` and `""` mean "no filter"; any other value, including + # `"0"`, must scope so a user named "0" isn't silently merged + # into the rest of the corpus. + clauses = [] + clauses << "@user:{#{escape_tag_value(user)}}" unless user.nil? || user.empty? + clauses << "@namespace:{#{escape_tag_value(namespace)}}" unless namespace.nil? || namespace.empty? + clauses << "@kind:{#{escape_tag_value(kind)}}" unless kind.nil? || kind.empty? + clauses.empty? ? '(*)' : "(#{clauses.join(' ')})" + end + + private + + def nearest_neighbors(embedding, user:, namespace:, kind:, k:) + validate_dim!(embedding, 'embedding') + filter_clause = self.class.build_filter_clause( + user: user, namespace: namespace, kind: kind + ) + knn_query = "#{filter_clause}=>[KNN #{k} @embedding $vec AS distance]" + vec_bytes = LocalEmbedder.to_bytes(embedding) + args = [ + 'FT.SEARCH', @index_name, knn_query, + 'PARAMS', '2', 'vec', vec_bytes, + 'SORTBY', 'distance', 'ASC', + 'RETURN', '8', + 'user', 'namespace', 'kind', 'source_thread', + 'text', 'created_ts', 'hit_count', 'distance', + 'LIMIT', '0', k.to_s, + 'DIALECT', '2' + ] + result = @redis.call(*args) + parse_search_result(result).map do |doc| + raw_key = doc[:_key] + memory_id = strip_prefix(raw_key) + ttl = @redis.ttl(memory_key(memory_id)) + MemoryRecord.new( + id: memory_id, + user: doc[:user] || '', + namespace: doc[:namespace] || '', + kind: doc[:kind] || '', + source_thread: doc[:source_thread] || '', + text: doc[:text] || '', + created_ts: (doc[:created_ts] || '0').to_f, + hit_count: (doc[:hit_count] || '0').to_i, + distance: doc[:distance].nil? ? nil : doc[:distance].to_f, + ttl_seconds: ttl && ttl.positive? ? ttl.to_i : nil + ) + end + end + + def bump_hit_count(memory_id) + @redis.call('JSON.NUMINCRBY', memory_key(memory_id), '$.hit_count', 1) + rescue Redis::CommandError + # The doc may have expired between recall and bump — fine, we + # just lose the hit count update. + end + + def resolve_ttl(kind, override) + return @ttl_by_kind[kind] if override == :default + override + end + + def strip_prefix(raw_key) + raw_key.start_with?(@key_prefix) ? raw_key[@key_prefix.length..] : raw_key + end + + def validate_dim!(vector, label) + unless vector.respond_to?(:length) && vector.length == @vector_dim + actual = vector.respond_to?(:length) ? vector.length : 'unknown' + raise ArgumentError, + "#{label} has length #{actual}; index expects #{@vector_dim}" + end + end + + # Parse the raw `FT.SEARCH` reply (RESP2 layout). The shape is: + # [ total, key1, [field1, value1, field2, value2, ...], key2, ... ] + # where each key is followed by a flat field/value array. + def parse_search_result(reply) + return [] unless reply.is_a?(Array) && !reply.empty? + docs = [] + i = 1 + while i < reply.length + key = reply[i] + fields = reply[i + 1] + i += 2 + next if fields.nil? + doc = { _key: key } + j = 0 + while j < fields.length + doc[fields[j].to_s.to_sym] = fields[j + 1] + j += 2 + end + docs << doc + end + docs + end + + def ft_info_to_hash(reply) + return {} unless reply.is_a?(Array) + out = {} + i = 0 + while i < reply.length + out[reply[i].to_s] = reply[i + 1] + i += 2 + end + out + end + end +end diff --git a/content/develop/use-cases/agent-memory/ruby/lib/seed_memory.rb b/content/develop/use-cases/agent-memory/ruby/lib/seed_memory.rb new file mode 100644 index 0000000000..a6a7433c71 --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/lib/seed_memory.rb @@ -0,0 +1,88 @@ +# Pre-seed the long-term memory store with sample memories. +# +# In a real deployment the memory store fills up organically as the +# agent reasons over user turns: each turn produces zero or more +# memories (preferences, facts, episodic summaries) that flow into the +# store with deduplication. To make the demo immediately useful — so +# the first recall query lands on relevant results instead of an empty +# list — we seed a small set of canonical memories for a default user +# at startup. +# +# The seed list mixes `semantic` memories (long-lived preferences and +# facts) with `episodic` memories (snapshots of past sessions), +# matching what every other agent-memory port seeds so the +# implementations behave identically regardless of which one you +# point at a given Redis instance. + +require_relative 'embeddings' +require_relative 'long_term_memory' + +module AgentMemory + module SeedMemory + SEED_MEMORIES = [ + { + text: 'The user prefers concise answers without filler phrases.', + kind: 'semantic' + }, + { + text: 'The user is a Python developer working on a logistics platform.', + kind: 'semantic' + }, + { + text: 'The user lives in Berlin and works in the Europe/Berlin time zone.', + kind: 'semantic' + }, + { + text: 'The user dislikes dark mode and prefers a high-contrast ' \ + 'light theme in editors and dashboards.', + kind: 'semantic' + }, + { + text: 'The user is allergic to peanuts; any restaurant suggestion ' \ + 'must avoid dishes that commonly contain them.', + kind: 'semantic' + }, + { + text: 'Last Tuesday the user asked the agent to draft a postmortem ' \ + 'for the order-routing outage. The agent produced a ' \ + 'five-section draft and the user approved sections 1, 2, ' \ + 'and 4 with minor edits.', + kind: 'episodic' + }, + { + text: 'In a previous session the user asked for help debugging a ' \ + 'flaky test in the inventory service. The fix turned out ' \ + 'to be a race condition in the warehouse webhook handler.', + kind: 'episodic' + }, + { + text: 'Two weeks ago the user mentioned they were planning to ' \ + 'migrate the analytics warehouse from Snowflake to ' \ + 'BigQuery in Q3.', + kind: 'episodic' + } + ].freeze + + # Embed and write the seed memories. Returns the count actually + # written (entries that dedup against existing memories don't + # count). + def self.seed(memory, embedder, + user: 'default', namespace: 'default', + source_thread: 'seed') + texts = SEED_MEMORIES.map { |m| m[:text] } + vectors = embedder.encode_many(texts) + written = 0 + SEED_MEMORIES.each_with_index do |entry, i| + result = memory.remember( + text: entry[:text], + embedding: vectors[i], + user: user, namespace: namespace, + kind: entry[:kind], + source_thread: source_thread + ) + written += 1 unless result.deduped + end + written + end + end +end diff --git a/content/develop/use-cases/agent-memory/ruby/lib/session_store.rb b/content/develop/use-cases/agent-memory/ruby/lib/session_store.rb new file mode 100644 index 0000000000..cb14bad71b --- /dev/null +++ b/content/develop/use-cases/agent-memory/ruby/lib/session_store.rb @@ -0,0 +1,243 @@ +# Working-memory store for an agent session, backed by a Redis Hash. +# +# Each session is one Hash document at `agent:session:{thread_id}`. +# The hash holds the running scratchpad, the current goal, a rolling +# window of recent turns (serialized as a JSON list to fit in one +# field), and a few audit fields. One `HGETALL` returns the whole +# session in a single round trip on every step of the agent loop. +# +# Every write refreshes the key's TTL with `EXPIRE`, so idle sessions +# fall off without a separate cleanup job and active sessions stay +# alive as long as the agent keeps touching them. A separate +# `LongTermMemory` (see `long_term_memory.rb`) is what survives beyond +# a session's TTL. +# +# The turn window is bounded to `max_turns` in application code; the +# hash itself doesn't grow, so the working set per thread stays +# constant regardless of how long the agent has been running. + +require 'json' +require 'redis' +require 'securerandom' + +module AgentMemory + MAX_TURNS = 20 + + # Loaded session state. `recent_turns` is an Array of Hashes + # (`role`, `content`, `ts`). `ttl_seconds` is the remaining lifetime + # at load time; `0` means no TTL or already expired. + SessionState = Struct.new( + :thread_id, :user, :agent, :goal, :scratchpad, + :turn_count, :created_ts, :last_active_ts, + :recent_turns, :ttl_seconds, keyword_init: true + ) do + def to_h + { + thread_id: thread_id, + user: user, + agent: agent, + goal: goal, + scratchpad: scratchpad, + turn_count: turn_count, + created_ts: created_ts, + last_active_ts: last_active_ts, + recent_turns: recent_turns, + ttl_seconds: ttl_seconds + } + end + end + + class AgentSession + attr_reader :redis, :key_prefix, :default_ttl_seconds, :max_turns + + def initialize(redis_client: nil, + key_prefix: 'agent:session:', + default_ttl_seconds: 3600, + max_turns: MAX_TURNS) + @redis = redis_client || Redis.new(host: 'localhost', port: 6379) + @key_prefix = key_prefix + @default_ttl_seconds = default_ttl_seconds + @max_turns = max_turns + end + + def session_key(thread_id) + "#{@key_prefix}#{thread_id}" + end + + def new_thread_id + SecureRandom.hex(6) # 12 hex chars, matches sibling demos + end + + # Create a fresh working memory for a thread. Overwrites any + # existing session at the same key. The agent normally calls this + # once per thread at the first turn and relies on `load` / + # `append_turn` for subsequent steps. + def start(thread_id, user: 'default', agent: 'default', + goal: '', ttl_seconds: nil) + ttl = ttl_seconds || @default_ttl_seconds + now = Time.now.to_f + state = SessionState.new( + thread_id: thread_id, + user: user, agent: agent, + goal: goal, scratchpad: '', + turn_count: 0, + created_ts: now, last_active_ts: now, + recent_turns: [], + ttl_seconds: ttl + ) + write_state(state, ttl) + state + end + + # Return the session state, or `nil` if it has expired. + def load(thread_id) + key = session_key(thread_id) + raw = @redis.hgetall(key) + return nil if raw.nil? || raw.empty? + ttl = @redis.ttl(key) + turns_blob = raw['recent_turns'] || '[]' + turns = begin + JSON.parse(turns_blob) + rescue JSON::ParserError + [] + end + SessionState.new( + thread_id: thread_id, + user: raw['user'] || 'default', + agent: raw['agent'] || 'default', + goal: raw['goal'] || '', + scratchpad: raw['scratchpad'] || '', + turn_count: (raw['turn_count'] || '0').to_i, + created_ts: (raw['created_ts'] || '0').to_f, + last_active_ts: (raw['last_active_ts'] || '0').to_f, + recent_turns: turns, + ttl_seconds: ttl && ttl.positive? ? ttl.to_i : 0 + ) + end + + # Append a turn, bound the rolling window, refresh the TTL. + # + # `user` and `agent` are only consulted when the session does not + # yet exist — they seed the auto-created session so the + # working-memory hash matches the user the caller is operating + # against. On an existing session they're ignored; the original + # `start` values stand. + # + # Read-modify-write here is last-writer-wins on the turn list if + # two concurrent turns reach the same thread; the demo never + # triggers that race in practice (one browser, one turn at a + # time) but a multi-worker agent that shares a thread id would + # wrap this in `WATCH` / `MULTI` / `EXEC` or a Lua script that + # does the append atomically server-side. + def append_turn(thread_id, role:, content:, + user: nil, agent: nil, ttl_seconds: nil) + state = load(thread_id) + if state.nil? + state = start( + thread_id, + user: user || 'default', + agent: agent || 'default', + ttl_seconds: ttl_seconds + ) + end + state.recent_turns << { + 'role' => role, + 'content' => content, + 'ts' => Time.now.to_f + } + if state.recent_turns.length > @max_turns + state.recent_turns = state.recent_turns.last(@max_turns) + end + state.turn_count += 1 + state.last_active_ts = Time.now.to_f + ttl = ttl_seconds || @default_ttl_seconds + state.ttl_seconds = ttl + write_state(state, ttl) + state + end + + # Update the agent's running scratchpad and refresh TTL. + def set_scratchpad(thread_id, text, ttl_seconds: nil) + state = load(thread_id) + return nil if state.nil? + state.scratchpad = text + state.last_active_ts = Time.now.to_f + ttl = ttl_seconds || @default_ttl_seconds + state.ttl_seconds = ttl + write_state(state, ttl) + state + end + + # Update the goal field without touching turns or the scratchpad. + # + # Creates the session if it doesn't exist yet — setting a goal on + # a fresh thread is a sensible first step in the agent loop, so + # this method covers both the "rename the goal mid-session" and + # the "start a thread with this goal" cases. + def set_goal(thread_id, text, + user: nil, agent: nil, ttl_seconds: nil) + state = load(thread_id) + if state.nil? + return start( + thread_id, + user: user || 'default', + agent: agent || 'default', + goal: text, + ttl_seconds: ttl_seconds + ) + end + state.goal = text + state.last_active_ts = Time.now.to_f + ttl = ttl_seconds || @default_ttl_seconds + state.ttl_seconds = ttl + write_state(state, ttl) + state + end + + # Drop the session immediately. Returns `true` if it existed. + def delete(thread_id) + @redis.del(session_key(thread_id)).positive? + end + + # Return active thread ids under this prefix (for a multi-thread + # switcher UI). `SCAN` is used so a busy instance with many other + # keys isn't blocked by a full `KEYS` sweep. + def list_threads(limit: 100) + out = [] + cursor = '0' + loop do + cursor, keys = @redis.scan(cursor, match: "#{@key_prefix}*", count: 200) + keys.each do |key| + out << (key.start_with?(@key_prefix) ? key[@key_prefix.length..] : key) + return out if out.length >= limit + end + break if cursor == '0' + end + out + end + + private + + def write_state(state, ttl) + key = session_key(state.thread_id) + mapping = { + 'thread_id' => state.thread_id, + 'user' => state.user, + 'agent' => state.agent, + 'goal' => state.goal, + 'scratchpad' => state.scratchpad, + 'turn_count' => state.turn_count.to_s, + 'created_ts' => state.created_ts.to_s, + 'last_active_ts' => state.last_active_ts.to_s, + 'recent_turns' => JSON.generate(state.recent_turns) + } + # MULTI/EXEC so HSET and EXPIRE either both apply or neither + # does. A connection drop between the two writes would + # otherwise leave the session without a TTL. + @redis.multi do |m| + m.hset(key, mapping) + m.expire(key, ttl) + end + end + end +end diff --git a/content/develop/use-cases/agent-memory/rust/Cargo.toml b/content/develop/use-cases/agent-memory/rust/Cargo.toml new file mode 100644 index 0000000000..f9442800b6 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "agent-memory-demo" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +name = "agent_memory_demo" +path = "src/lib.rs" + +[[bin]] +name = "agent-memory-demo" +path = "src/main.rs" + +# Optimise the release build. The candle-based embedder is CPU-bound +# on token throughput; without -O the first encode takes long enough +# to be noticeable in the demo, and the docs example tells readers to +# use --release for that reason. +[profile.release] +opt-level = 3 +lto = "thin" +codegen-units = 1 + +[dependencies] +redis = { version = "0.27", default-features = false } +tiny_http = "0.12" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +byteorder = "1" +url = "2" + +# Candle is the embedder. We need: +# - candle-core for the tensor type +# - candle-nn for module trait implementations BERT pulls in +# - candle-transformers for the BertModel + config types +# - tokenizers for the HuggingFace WordPiece tokenizer +# - hf-hub for fetching the model weights and tokenizer on first run +candle-core = "0.8" +candle-nn = "0.8" +candle-transformers = "0.8" +tokenizers = { version = "0.20", default-features = false, features = ["onig"] } +hf-hub = "0.3" + +# UUID-style ids. The Python, Node, and .NET demos use 12 hex +# characters; we match the shape by hex-encoding 6 random bytes +# read from /dev/urandom via getrandom. +getrandom = "0.2" + +[dev-dependencies] diff --git a/content/develop/use-cases/agent-memory/rust/_index.md b/content/develop/use-cases/agent-memory/rust/_index.md new file mode 100644 index 0000000000..2292064067 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/_index.md @@ -0,0 +1,344 @@ +--- +categories: +- docs +- develop +- stack +- oss +- rs +- rc +description: Build a Redis-backed agent memory layer in Rust with redis-rs, Candle, and standard Redis commands — working memory in a Hash, long-term semantic recall as JSON with a vector index, and an event log in a Stream. +linkTitle: redis-rs example (Rust) +title: Redis agent memory with redis-rs +weight: 4 +--- + +This guide shows you how to build a small Redis-backed agent memory layer in Rust with [`redis-rs`]({{< relref "/develop/clients/redis-rs" >}}) and the [Candle](https://github.com/huggingface/candle) inference framework, using only standard Redis commands — no agent-memory SDK, no managed service. It includes a local web server built with [`tiny_http`](https://docs.rs/tiny_http) so you can send turns at the agent, watch working memory update in place, see semantically similar long-term memories recalled in real time, watch the write-time deduplication skip near-duplicates, and inspect the per-thread event log. + +The embedder is [Candle](https://github.com/huggingface/candle) running the canonical [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) PyTorch checkpoint — the same weights the [Python]({{< relref "/develop/use-cases/agent-memory/redis-py" >}}) example loads. Candle reads the `.bin` weights directly, so the vectors produced here are bit-identical to the Python ones, and the distance bands the Python walkthrough quotes carry over to this demo without recalibration. A memory written by one demo can be recalled by the other against the same Redis instance. + +## Overview + +The memory layer splits across three Redis primitives, each handling one tier: + +* **Working memory** for the active session is a [Hash]({{< relref "/develop/data-types/hashes" >}}) at `agent:session:` holding the goal, scratchpad, a rolling window of recent turns (as a JSON list inside one field), and a few audit timestamps. One [`HGETALL`]({{< relref "/commands/hgetall" >}}) returns the whole session in a single round trip; every write refreshes the key's [`EXPIRE`]({{< relref "/commands/expire" >}}) so idle sessions decay on their own. +* **Long-term memory** is a set of [JSON]({{< relref "/develop/data-types/json" >}}) documents at `agent:mem:`, each carrying the memory text, a 384-dimensional embedding vector, and tag fields for user, namespace, kind (episodic / semantic), and source thread. A single [Redis Search]({{< relref "/develop/ai/search-and-query" >}}) index covers the [HNSW vector field]({{< relref "/develop/ai/search-and-query/vectors" >}}) and every metadata field, so one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call performs the KNN with the metadata pre-filter in the same round trip. Write-time deduplication runs the same KNN at insert time and skips a new memory whose nearest existing entry is within a tighter threshold. +* **Event log** for the agent's actions and observations is a [Stream]({{< relref "/develop/data-types/streams" >}}) at `agent:events:`, appended with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}) so retention stays bounded automatically, replayed with [`XREVRANGE`]({{< relref "/commands/xrevrange" >}}). + +That gives you: + +* A single round trip per tier: one [`HGETALL`]({{< relref "/commands/hgetall" >}}) for the session, one [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) for recall, one [`XADD`]({{< relref "/commands/xadd" >}}) for the event log. +* Sub-millisecond reads on every step of the agent loop, so the memory layer doesn't dominate per-step latency. +* Per-tier decay: short TTLs on working memory, longer on episodic memories, no TTL on semantic memories. Combined with a database-level [eviction policy]({{< relref "/develop/reference/eviction" >}}) (LFU is the common choice), memory stays bounded under pressure. +* Scoping enforced inside the query: a recall query for `user=alice` will never see `user=bob`'s memories, because the TAG filter goes into the same [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) call as the KNN. + +## How it works + +Each turn through the agent loop touches all three tiers in one pass: append to working memory, recall similar long-term memories, write the turn back as a new memory (with deduplication), and append one event to the log. + +### Per-turn flow + +1. The application calls `embedder.encode_one(text)` to turn the incoming turn into a 384-dimensional `Vec`. +2. `session.append_turn(thread_id, role, content, Some(user), ..., None)` reads the per-thread Hash with [`HGETALL`]({{< relref "/commands/hgetall" >}}), appends the new turn to the rolling window in application code, trims it back to the configured maximum, and writes the Hash back with an [`HSET`]({{< relref "/commands/hset" >}}) + [`EXPIRE`]({{< relref "/commands/expire" >}}) pipeline inside a [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}). The session TTL refreshes on every write so an active thread stays alive. +3. `memory.recall(&vec, user, Some(namespace), None, 5, None)` runs [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) with a TAG pre-filter and a `KNN 5` clause. Redis returns the closest matching memories together with their cosine distances; memories beyond the recall threshold are dropped before they reach the agent so an unrelated query doesn't surface confident-looking false positives. +4. `memory.remember(text, &vec, user, namespace, kind, source_thread, None)` runs the same KNN with a tighter dedup threshold. If an existing memory is within the threshold, the new write is skipped and the existing memory's `hit_count` is incremented with [`JSON.NUMINCRBY`]({{< relref "/commands/json.numincrby" >}}); otherwise a fresh JSON document is written with [`JSON.SET`]({{< relref "/commands/json.set" >}}) and a per-kind [`EXPIRE`]({{< relref "/commands/expire" >}}) — `episodic` defaults to seven days, `semantic` has no TTL by default. +5. `events.record(thread_id, action, detail)` appends one entry to the per-thread Stream with [`XADD MAXLEN ~`]({{< relref "/commands/xadd" >}}), bounding retention to roughly a thousand entries per thread without an explicit cleanup job. + +The embedding is computed once and reused for steps 3 and 4 — there's no point encoding the same text twice. Recall runs before the write, so the agent doesn't see its own just-written turn echoed back as a recalled memory. + +## The session store + +`AgentSession` wraps the working-memory Hash and the rolling turn window ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/rust/src/session_store.rs)): + +```rust +use agent_memory_demo::session_store::AgentSession; + +let client = redis::Client::open("redis://127.0.0.1/")?; +let session = AgentSession::new( + &client, + "agent:session:", + 3600, // one hour TTL on the working-memory hash + 20, // rolling window per thread +)?; + +let thread_id = session.new_thread_id(); +session.start(&thread_id, "alice", "demo-agent", + "Plan next week's meetings.", None)?; +session.append_turn( + &thread_id, "user", + "Schedule a budget review with finance.", + Some("alice"), Some("demo-agent"), None, +)?; +let state = session.load(&thread_id)?; +println!("{} {} {}", + state.as_ref().map(|s| s.turn_count).unwrap_or(0), + state.as_ref().map(|s| s.recent_turns.len()).unwrap_or(0), + state.as_ref().map(|s| s.ttl_seconds).unwrap_or(0)); +``` + +The data model is one Hash per thread. The rolling turn window is stored as a JSON string in a single field so the whole session loads in one [`HGETALL`]({{< relref "/commands/hgetall" >}}) — the hash never grows in size or field count as the conversation goes on. + +```text +agent:session:9f3d2a4b8c61 + thread_id=9f3d2a4b8c61 + user=alice + agent=demo-agent + goal=Plan next week's meetings. + scratchpad=Need to confirm finance's availability. + turn_count=4 + created_ts=1715990400.12 + last_active_ts=1715990650.83 + recent_turns=[{"role":"user","content":"...","ts":...}, ...] +``` + +Every write — `start`, `append_turn`, `set_goal`, `set_scratchpad` — runs the [`HSET`]({{< relref "/commands/hset" >}}) and [`EXPIRE`]({{< relref "/commands/expire" >}}) inside a `redis::pipe().atomic()` block so a connection drop between the two writes can't leave the session without a TTL. + +## The long-term memory store + +`LongTermMemory` owns the JSON documents, the vector index, the recall query, and the write-time deduplication ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/rust/src/long_term_memory.rs)): + +```rust +use agent_memory_demo::embeddings::LocalEmbedder; +use agent_memory_demo::long_term_memory::LongTermMemory; + +let memory = LongTermMemory::new( + &client, + "agentmem:idx", + "agent:mem:", + 384, // vector dimension + 0.20, // dedup threshold — tight at write time + 0.55, // recall threshold — looser at read time +)?; +let embedder = LocalEmbedder::new(None)?; +memory.create_index()?; // idempotent + +// Write a memory. The same KNN that powers recall also runs here +// at a tighter threshold so paraphrases of the same fact collapse. +let vec = embedder.encode_one("The user prefers light mode in editors.")?; +let result = memory.remember( + "The user prefers light mode in editors.", + &vec, + "alice", + "default", + "semantic", + "9f3d2a4b8c61", + None, +)?; +println!("deduped={} id={} dist={:?}", + result.deduped, result.id, result.existing_distance); + +// Recall against a later question. +let q = embedder.encode_one("Which theme does this user like?")?; +let hits = memory.recall(&q, "alice", Some("default"), None, 5, None)?; +for h in &hits { + println!("{:.3} [{}] {}", + h.distance.unwrap_or(0.0), h.kind, h.text); +} +``` + +### Data model + +Each memory is a JSON document at `agent:mem:`. The embedding is stored as a JSON array of floats so the document is human-readable from `redis-cli`; [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) still expects the *query* vector as raw `float32` bytes (`byteorder::LittleEndian` via `embeddings::floats_to_bytes`), regardless of how the indexed document stores it. + +```json +agent:mem:7c3f8a1b9e02 +{ + "id": "7c3f8a1b9e02", + "user": "alice", + "namespace": "default", + "kind": "semantic", + "source_thread": "9f3d2a4b8c61", + "text": "The user prefers light mode in editors.", + "embedding": [0.013, -0.041, ...], + "created_ts": 1715990400.12, + "hit_count": 0 +} +``` + +The Redis Search index is declared on the JSON document type with `AS` aliases so the query syntax stays compact: + +```text +FT.CREATE agentmem:idx + ON JSON PREFIX 1 agent:mem: + SCHEMA + $.text AS text TEXT + $.user AS user TAG + $.namespace AS namespace TAG + $.kind AS kind TAG + $.source_thread AS source_thread TAG + $.created_ts AS created_ts NUMERIC SORTABLE + $.hit_count AS hit_count NUMERIC SORTABLE + $.embedding AS embedding VECTOR HNSW 6 + TYPE FLOAT32 DIM 384 + DISTANCE_METRIC COSINE +``` + +### The query + +Both recall and dedup share the same hybrid query: a TAG pre-filter in parentheses followed by `=>[KNN k @embedding $vec]`. With `DIALECT 2`, Redis applies the filter first and KNN-ranks only the matching documents. + +```text +FT.SEARCH agentmem:idx + "(@user:{alice} @namespace:{default} @kind:{semantic}) + =>[KNN 5 @embedding $vec AS distance]" + PARAMS 2 vec <384-float32-bytes> + SORTBY distance + RETURN 8 user namespace kind source_thread text created_ts hit_count distance + DIALECT 2 +``` + +`distance` is the cosine *distance* (0 means identical, 2 means opposite). Recall and dedup share the same query shape; only the threshold differs — strict at write time so the index doesn't fill with paraphrases of the same fact, looser at read time so the agent gets a wider net of relevant memories. + +### Per-kind TTLs + +`remember` resolves the entry's TTL from the memory's `kind`: + +| Kind | Default TTL | When to use it | +|-----------|-------------|-------------------------------------------------------------| +| `episodic` | 7 days | Snapshots from a specific session that should decay. | +| `semantic` | none | Distilled facts and preferences the agent carries forward. | + +You can override per write by passing `Some(ttl)` as the last argument to `remember`. The defaults live in `long_term_memory::default_ttl_for_kind`; swap it out if you want a different tier map. + +## The event log + +`AgentEventLog` is a thin wrapper over a per-thread Redis Stream ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/rust/src/event_log.rs)): + +```rust +use agent_memory_demo::event_log::AgentEventLog; + +let events = AgentEventLog::new(&client, "agent:events:", 1000)?; +events.record(&thread_id, "turn_appended:user", + "Schedule a budget review with finance.")?; +events.record(&thread_id, "memory_written", + "wrote 7c3f8a1b9e02 as semantic")?; + +for e in events.recent(&thread_id, 20)? { + println!("{} {}", e.action, e.detail); +} +``` + +`record` calls [`XADD`]({{< relref "/commands/xadd" >}}) with `MAXLEN ~ 1000`. The tilde lets Redis trim in whole-node units instead of exactly-N units, which is much cheaper at the cost of overshooting the bound by up to a node's worth — the right tradeoff for an audit log where exact length doesn't matter. + +The Stream is independent of the session Hash and the long-term JSON documents: it answers "what just happened" without competing with either of those for indexing or memory budget. Consumer groups (not used in this demo) would let downstream workers — summarisers, consolidators, audit pipelines — replay the log without losing position. + +## Concurrency caveats + +The three helpers above trade correctness under heavy concurrency for clarity. Each is fine on a single-process demo, but lifting the code into a real multi-worker agent surfaces three races worth knowing about: + +* **Working memory is read-modify-write.** `AgentSession::append_turn` calls [`HGETALL`]({{< relref "/commands/hgetall" >}}), mutates the `recent_turns` vector in application code, and writes the Hash back with [`HSET`]({{< relref "/commands/hset" >}}). Two concurrent turns on the same thread can both read the same `recent_turns`, append different entries, and write back — last writer wins, the other turn is silently lost. The robust fix is either a [`WATCH`]({{< relref "/commands/watch" >}}) / [`MULTI`]({{< relref "/commands/multi" >}}) / [`EXEC`]({{< relref "/commands/exec" >}}) loop around the read-modify-write or a small [Lua script]({{< relref "/commands/eval" >}}) that does the append atomically server-side. (Each helper holds a `Mutex` that protects individual Redis commands against TCP-frame interleaving, but the lock is released between the `HGETALL` and the `HSET`, so it doesn't make the read-modify-write itself atomic.) + +* **Long-term dedup is not atomic.** `LongTermMemory::remember` runs a [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}) KNN lookup, decides whether the candidate is a duplicate, and (if not) calls [`JSON.SET`]({{< relref "/commands/json.set" >}}). Two workers seeing the same fact in flight can each fail to see the other's not-yet-committed write and both insert a new memory. The pragmatic fix is to accept that the index will occasionally hold near-duplicates and run a background consolidator that periodically scans for memory pairs within a tight distance and merges them, rather than trying to make the write itself atomic. + +* **The active thread is server state.** The demo server keeps a single `current_thread_id` (a `Mutex`) that `/new_thread` and `/reset` mutate. `handle_turn` reads it under the mutex but then drops the lock immediately, so a turn racing with a thread rotation can apply to the previous thread. This is cosmetic for a one-user browser demo. A multi-user agent would carry the thread id on the request itself rather than as shared server state. + +Those caveats are deliberate. A more conservative implementation would obscure the Redis-shaped parts of the pattern; the demo prioritizes a small, readable code path that maps directly onto the commands in the prose above. + +## Pre-seeding long-term memory + +In a real deployment the memory store fills up organically as the agent reasons over user turns: each turn produces zero or more memories that flow into the store, with deduplication catching repeats. For the demo, `src/seed_memory.rs` pre-loads a small set of mixed semantic and episodic memories so the very first recall query returns something useful ([source](https://github.com/redis/docs/blob/main/content/develop/use-cases/agent-memory/rust/src/seed_memory.rs)): + +```rust +use agent_memory_demo::seed_memory::seed; + +let memory = LongTermMemory::new(&client, "agentmem:idx", "agent:mem:", + 384, 0.20, 0.55)?; +let embedder = LocalEmbedder::new(None)?; +memory.create_index()?; +seed(&memory, &embedder, "default", "default", "seed")?; +``` + +The seed list mixes long-lived facts and preferences (`semantic`) with snapshots of past sessions (`episodic`), so the **Kind to write** control in the demo has something to switch between when a new turn is being remembered. + +## The interactive demo + +`src/main.rs` runs a [`tiny_http`](https://docs.rs/tiny_http) server on port 8094, spawning a worker thread per request. The HTML page exposes three live panels — working memory, recalled memories, event log — plus a memories table for admin actions. Endpoints: + +| Endpoint | What it does | +|---------------------|---------------------------------------------------------------------------------| +| `GET /state` | Index info, current session, in-scope long-term memories, and recent events. | +| `POST /turn` | Embed the text, append to working memory, recall similar memories, optionally write a new memory (with dedup), append an event. | +| `POST /new_thread` | Start a fresh thread; long-term memory and other threads are untouched. | +| `POST /reset` | Drop every long-term memory and re-seed the sample set. | +| `POST /drop_memory` | Delete a single long-term memory by id. | + +The server holds one `LocalEmbedder`, one `AgentSession`, one `LongTermMemory`, and one `AgentEventLog` for the lifetime of the process, wrapped in an `Arc` so the worker threads can share them without copying. The "current thread" is a `Mutex` that the **New thread** button rotates — every browser tab inherits the same thread until you explicitly start a new one. + +## Run the demo locally + +1. Clone the [`redis/docs`](https://github.com/redis/docs) repository and change into the example + directory: + + ```bash + git clone https://github.com/redis/docs.git + cd docs/content/develop/use-cases/agent-memory/rust + ``` + +2. Build the project. You'll need a recent [Rust toolchain](https://www.rust-lang.org/tools/install) + (1.75 or later): + + ```bash + cargo build --release + ``` + + Building in release mode matters here — Candle is CPU-bound on token throughput, + so the debug build is several times slower at the embedding step. + +3. Make sure a Redis instance with Redis Search and Redis JSON is running locally on + port 6379. [Redis Stack]({{< relref "/operate/oss_and_stack/install/install-stack" >}}) + ships both, or [Redis 8]({{< relref "/develop/ai/search-and-query" >}}) with the + Search and JSON modules enabled. + +4. Start the demo. The first run downloads the + [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) + weights, tokenizer, and config from the Hugging Face Hub into the local cache + (`~/.cache/huggingface/hub/` by default): + + ```bash + cargo run --release + ``` + +5. Open and try some turns: + + * **"Remind me which theme I prefer in editors."** — paraphrase of a seeded + semantic memory ("The user dislikes dark mode and prefers a high-contrast + light theme..."). You should see that memory recalled with a cosine + distance around 0.47, comfortably under the 0.55 default recall + threshold. + * **"What did we discuss about the order-routing outage?"** — paraphrase of + a seeded episodic memory; the postmortem memory should recall around + 0.44. Switch the **Kind to write** dropdown to `skip` so the question + itself doesn't enter long-term memory. + * **"I prefer concise answers without filler phrases."** — paraphrase of + a seeded *semantic* memory. Switch the **Kind to write** dropdown to + `semantic` so the dedup KNN runs in the same kind as the seed (dedup + is scoped per kind, on purpose, so an episodic write can't collapse + onto a semantic memory). You should then see the write **deduped** + onto the existing memory at a cosine distance around 0.15, with + `hit_count` ticking up in the memories table. + * **"My favorite color is teal."** — unrelated to any seed; nothing + recalls above the threshold (every seed lands above 0.8), and the new + memory is written as `episodic` (or `semantic`, depending on the + dropdown) under a fresh id. + * Switch the **User** field to `bob` and re-ask any of the above — recall + returns nothing because the seed memories live under `default`. That's + the TAG pre-filter at work inside [`FT.SEARCH`]({{< relref "/commands/ft.search" >}}). + * Slide the **Recall threshold** down to 0.30 to see borderline paraphrases + drop out of the recall set, then back up to 0.70 to watch them return. + + Candle reads the same PyTorch weights as the Python `sentence-transformers` + library, so distances here match the Python demo bit-for-bit. + `sentence-transformers/all-MiniLM-L6-v2` puts a faithful paraphrase in the + 0.15 – 0.50 cosine-distance range, a loose paraphrase or related topic in + the 0.50 – 0.80 range, and unrelated queries above 0.8 — which is what + motivates the 0.55 default recall threshold and the 0.20 default dedup + threshold. A stricter embedding model (or a domain-tuned one) would let + you tighten both; a noisier one would push them up. The right thresholds + are always a function of the model, the corpus, and how conservative the + agent needs to be about accepting a memory as a match. + +The server is read/write against your local Redis. The default memory index is `agentmem:idx`, JSON keys live under `agent:mem:`, session Hashes under `agent:session:`, and event Streams under `agent:events:`. Useful flags (pass them after `--`, for example `cargo run --release -- --no-reset`): + +* `--no-reset` — keep the existing long-term memories across restarts instead of dropping and re-seeding. +* `--session-ttl-seconds` — change the working-memory TTL (default 3600). +* `--dedup-threshold` — change the cosine-distance cutoff for write-time deduplication. +* `--recall-threshold` — change the default cosine-distance cutoff for recall. diff --git a/content/develop/use-cases/agent-memory/rust/index.html b/content/develop/use-cases/agent-memory/rust/index.html new file mode 100644 index 0000000000..0fa6d75825 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/index.html @@ -0,0 +1,550 @@ + + + + + + Redis Agent Memory Demo + + + +
+
loading…
+

Redis Agent Memory Demo

+

+ A small agent memory layer spread across three Redis primitives: + a per-thread Hash at __SESSION_PREFIX__<thread> + for working memory, JSON documents at + __MEM_PREFIX__<id> indexed by + __MEM_INDEX__ for long-term semantic recall (with + write-time deduplication), and a Stream at + __EVENT_PREFIX__<thread> for the time-ordered + action log. Send a turn and watch all three update in one + request. +

+ +
+ +
+

Send a turn

+

The server appends the turn to working memory, recalls the + top-k long-term memories by cosine similarity (scoped by the + user and namespace filter inside FT.SEARCH), + tries to write the turn back as a memory with deduplication + against existing entries of the same kind, and + appends one event to the stream.

+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + + 0.55 +
+

+ A memory is included in the recall result only when its + cosine distance from the turn is at or below this + threshold. Lower = stricter (fewer false positives); + higher = looser (more recall, more noise). +

+ + + + +

Last write

+
(no writes yet)
+
+ +
+

Working memory

+

The per-thread Hash. One HGETALL returns the + whole session in a single round trip; the rolling turn window + keeps the hash size bounded.

+
+
+ +
+

Recalled memories

+

Top-k long-term memories matching the last turn, scored by + cosine distance from the turn's embedding.

+
+
+ +
+

Event log

+

Most recent entries from the thread's Redis Stream.

+
+
+ +
+

Index state

+
+ +
+ +
+

All long-term memories

+

Every JSON memory document in scope for the current user + and namespace. hit_count is the running total + of times a write was deduplicated onto this memory; + ttl is the remaining lifetime in seconds, or + when the memory has no TTL.

+ + + + + + + + + + + + +
IDKindTextHitsTTL
+
+ +
+ +
+
+ + + + diff --git a/content/develop/use-cases/agent-memory/rust/src/embeddings.rs b/content/develop/use-cases/agent-memory/rust/src/embeddings.rs new file mode 100644 index 0000000000..466e733bbb --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/embeddings.rs @@ -0,0 +1,304 @@ +//! Local text-embedding helper backed by Candle. +//! +//! This is a thin wrapper around the sentence-transformers model +//! `sentence-transformers/all-MiniLM-L6-v2`: a 384-dimensional BERT +//! encoder that runs in-process on CPU through Candle's pure-Rust +//! tensor backend, needs no API key, and produces vectors numerically +//! equivalent to the equivalent PyTorch model from +//! sentence-transformers. +//! +//! Two things matter for parity with the Python / Node / Go / Jedis +//! demos: +//! +//! 1. **Mean pooling with the attention mask.** sentence-transformers +//! computes the sentence vector as the attention-mask-weighted +//! average of the per-token last-hidden-state vectors, *not* the +//! `[CLS]` vector. Doing CLS-only here would produce numerically +//! different vectors and the published distance benchmarks (0.30 +//! for "How fast is delivery?", 0.49 for "How do I return an +//! item?") would drift. +//! 2. **Explicit L2 normalisation.** With normalised vectors, cosine +//! distance reduces to `1 - dot product`, which is what Redis +//! Search reports for our `COSINE` HNSW field. Without +//! normalisation, the distances would be in a different range and +//! the 0.5 default threshold would be meaningless. +//! +//! The model weights are fetched from the Hugging Face Hub on first +//! run via `hf-hub`. Subsequent runs read from the local cache. + +use std::path::PathBuf; + +use candle_core::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; +use hf_hub::api::sync::Api; +use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams, TruncationStrategy}; + +pub const DEFAULT_EMBED_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2"; + +#[derive(Debug)] +pub enum EmbedError { + Hub(String), + Io(std::io::Error), + Candle(candle_core::Error), + Tokenizer(String), + BatchMismatch { expected: usize, got: usize }, + Empty, +} + +impl std::fmt::Display for EmbedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EmbedError::Hub(msg) => write!(f, "hugging face hub: {}", msg), + EmbedError::Io(e) => write!(f, "io: {}", e), + EmbedError::Candle(e) => write!(f, "candle: {}", e), + EmbedError::Tokenizer(msg) => write!(f, "tokenizer: {}", msg), + EmbedError::BatchMismatch { expected, got } => write!( + f, + "pipeline returned {} vectors for {} inputs", + got, expected + ), + EmbedError::Empty => write!(f, "pipeline returned no embeddings"), + } + } +} + +impl std::error::Error for EmbedError {} + +impl From for EmbedError { + fn from(e: candle_core::Error) -> Self { + EmbedError::Candle(e) + } +} + +impl From for EmbedError { + fn from(e: std::io::Error) -> Self { + EmbedError::Io(e) + } +} + +/// Wraps a Candle BertModel + a HuggingFace tokenizer for a sentence +/// transformer. +pub struct LocalEmbedder { + pub model_name: String, + // `dim` is exposed so callers can sanity-check the model output + // against the Redis Search index dimension; the demo HTTP layer + // doesn't read it (it's hard-coded against VECTOR_DIM_DEFAULT). + #[allow(dead_code)] + pub dim: usize, + model: BertModel, + tokenizer: Tokenizer, + device: Device, +} + +impl LocalEmbedder { + /// Load the MiniLM model + tokenizer. Downloads them on first run + /// from the Hugging Face Hub into the local hf cache; later runs + /// load from disk only. + pub fn new(model_name: Option<&str>) -> Result { + let model_name = model_name.unwrap_or(DEFAULT_EMBED_MODEL).to_string(); + + let api = Api::new().map_err(|e| EmbedError::Hub(e.to_string()))?; + let repo = api.model(model_name.clone()); + + // The sentence-transformers MiniLM repo ships pytorch_model.bin + // + config.json + tokenizer.json. We deliberately use the + // PyTorch weights via candle's pickle reader rather than the + // safetensors mirror because (a) the canonical repo only + // publishes .bin, and (b) staying on the canonical repo means + // we hit the same weights as redis-py / nodejs / go / jedis + // demos. + let config_path: PathBuf = repo + .get("config.json") + .map_err(|e| EmbedError::Hub(e.to_string()))?; + let tokenizer_path: PathBuf = repo + .get("tokenizer.json") + .map_err(|e| EmbedError::Hub(e.to_string()))?; + let weights_path: PathBuf = repo + .get("pytorch_model.bin") + .map_err(|e| EmbedError::Hub(e.to_string()))?; + + let config_bytes = std::fs::read(&config_path)?; + let mut config: Config = + serde_json::from_slice(&config_bytes).map_err(|e| EmbedError::Hub(e.to_string()))?; + // sentence-transformers configs sometimes ship with hidden_act + // = "gelu" as a JSON string Candle parses as Gelu. MiniLM ships + // hidden_act="gelu", which already matches; force it just in + // case a downstream repo ships a non-standard value. + config.hidden_act = HiddenAct::Gelu; + + let mut tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| EmbedError::Tokenizer(e.to_string()))?; + // Match sentence-transformers' default: pad to the longest in + // the batch, truncate to the model's max length. Without the + // padding configuration set, single-example calls work fine + // but multi-example batches fail because Candle needs a + // rectangular tensor. + let pad_id = tokenizer + .get_padding() + .map(|p| p.pad_id) + .unwrap_or(0); + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: tokenizers::PaddingDirection::Right, + pad_to_multiple_of: None, + pad_id, + pad_type_id: 0, + pad_token: "[PAD]".to_string(), + })); + tokenizer + .with_truncation(Some(TruncationParams { + max_length: 512, + strategy: TruncationStrategy::LongestFirst, + stride: 0, + direction: tokenizers::TruncationDirection::Right, + })) + .map_err(|e| EmbedError::Tokenizer(e.to_string()))?; + + let device = Device::Cpu; + let vb = VarBuilder::from_pth(&weights_path, DTYPE, &device)?; + let model = BertModel::load(vb, &config)?; + + let dim = config.hidden_size; + + Ok(Self { + model_name, + dim, + model, + tokenizer, + device, + }) + } + + /// Returns a `dim`-element float32 vector for the input string, + /// L2-normalised. + pub fn encode_one(&self, text: &str) -> Result, EmbedError> { + let mut out = self.encode_many(&[text])?; + if out.is_empty() { + return Err(EmbedError::Empty); + } + Ok(out.remove(0)) + } + + /// Batch-encodes several strings in one forward pass so the model + /// pays the kernel-launch overhead once. Returns one vector per + /// input in the same order. Each vector is L2-normalised. + pub fn encode_many(&self, texts: &[&str]) -> Result>, EmbedError> { + if texts.is_empty() { + return Ok(Vec::new()); + } + + let encodings = self + .tokenizer + .encode_batch(texts.to_vec(), true) + .map_err(|e| EmbedError::Tokenizer(e.to_string()))?; + if encodings.len() != texts.len() { + return Err(EmbedError::BatchMismatch { + expected: texts.len(), + got: encodings.len(), + }); + } + + let batch_size = encodings.len(); + let seq_len = encodings.iter().map(|e| e.get_ids().len()).max().unwrap_or(0); + + let mut input_ids = Vec::with_capacity(batch_size * seq_len); + let mut attention_mask = Vec::with_capacity(batch_size * seq_len); + let mut token_type_ids = Vec::with_capacity(batch_size * seq_len); + for enc in &encodings { + input_ids.extend_from_slice(enc.get_ids()); + attention_mask.extend_from_slice(enc.get_attention_mask()); + token_type_ids.extend_from_slice(enc.get_type_ids()); + } + + let input_ids_t = Tensor::from_vec(input_ids, (batch_size, seq_len), &self.device)?; + let token_type_ids_t = + Tensor::from_vec(token_type_ids, (batch_size, seq_len), &self.device)?; + let attn_mask_u32 = attention_mask.clone(); + let attn_mask_t = Tensor::from_vec(attn_mask_u32, (batch_size, seq_len), &self.device)?; + + // Forward pass. Candle's BertModel takes input_ids, + // token_type_ids, and an optional attention_mask. We pass the + // mask so the encoder ignores padded positions. + let hidden = self.model.forward( + &input_ids_t, + &token_type_ids_t, + Some(&attn_mask_t), + )?; + + // Mean-pool with the attention mask. sentence-transformers + // computes the sentence vector as the mask-weighted average of + // per-token last-hidden-state vectors. Pseudocode: + // + // sum = (hidden * mask_expanded).sum(dim=1) + // counts = mask_expanded.sum(dim=1).clamp(min=1e-9) + // pooled = sum / counts + // + // The mask comes in as u32 / DTYPE-incompatible; convert to + // the model's DType and broadcast it across the hidden dim. + let mask_f = attn_mask_t.to_dtype(DTYPE)?; // (B, T) + let mask_expanded = mask_f.unsqueeze(2)?; // (B, T, 1) + let mask_expanded = mask_expanded.broadcast_as(hidden.shape())?; // (B, T, H) + + let masked = hidden.broadcast_mul(&mask_expanded)?; // (B, T, H) + let summed = masked.sum(1)?; // (B, H) + let counts = mask_f.sum(1)?; // (B,) + // Clamp the counts so an all-pad row (shouldn't happen, but be + // defensive) doesn't divide by zero. + let counts = counts.maximum(&Tensor::new(1e-9f32, &self.device)?.broadcast_as(counts.shape())?)?; + let counts = counts.unsqueeze(1)?; // (B, 1) + let pooled = summed.broadcast_div(&counts)?; // (B, H) + + // Extract as Vec> and L2-normalise each row in + // user-space so the demo's normalisation is explicit and + // visible in source. (Candle's tensor normalize helpers also + // exist but doing it by hand makes the docs example legible + // without a Candle deep-dive.) + let pooled_f32 = pooled.to_dtype(DType::F32)?; + let rows: Vec> = pooled_f32.to_vec2::()?; + if rows.len() != texts.len() { + return Err(EmbedError::BatchMismatch { + expected: texts.len(), + got: rows.len(), + }); + } + + let mut out = Vec::with_capacity(rows.len()); + for mut row in rows { + normalize_in_place(&mut row); + out.push(row); + } + Ok(out) + } +} + +/// L2-normalises a vector so it has unit length. A zero vector is left +/// untouched (its cosine distance to anything is undefined, but at +/// least Redis won't reject the bytes). +fn normalize_in_place(v: &mut [f32]) { + let mut sum_sq: f64 = 0.0; + for &x in v.iter() { + sum_sq += (x as f64) * (x as f64); + } + if sum_sq == 0.0 { + return; + } + let inv = (1.0 / sum_sq.sqrt()) as f32; + for x in v.iter_mut() { + *x *= inv; + } +} + +/// Packs a `&[f32]` into the raw little-endian byte sequence Redis +/// Search expects for a FLOAT32 vector field. We use +/// `byteorder::LittleEndian` (via `write_f32`) rather than relying on +/// `f32::to_le_bytes` so the encoding contract is visible in source +/// and consistent with the Go demo's `binary.LittleEndian.PutUint32`. +pub fn floats_to_bytes(fs: &[f32]) -> Vec { + use byteorder::{LittleEndian, WriteBytesExt}; + let mut buf = Vec::with_capacity(fs.len() * 4); + for &f in fs { + buf.write_f32::(f).expect("Vec write never fails"); + } + buf +} diff --git a/content/develop/use-cases/agent-memory/rust/src/event_log.rs b/content/develop/use-cases/agent-memory/rust/src/event_log.rs new file mode 100644 index 0000000000..73c89bd5ce --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/event_log.rs @@ -0,0 +1,217 @@ +//! Append-only event log for an agent thread, backed by a Redis +//! Stream. +//! +//! Each thread gets a stream at `agent:events:{thread_id}`. Every +//! action the agent takes (a user turn arriving, a memory being +//! recalled, a memory being written, a tool being called) is one +//! `XADD` to that stream. Replay with `XREVRANGE` for the most recent +//! N events; bound retention with `XTRIM MAXLEN ~` so the log stays +//! cheap regardless of how long the thread has been running. +//! +//! The stream is independent of the session hash and the long-term +//! memory store: it answers the "what just happened" question +//! without competing with either of those for indexing or memory +//! budget. Consumer groups (not used in this demo) would let +//! downstream workers — summarisers, consolidators, audit pipelines — +//! replay the log without losing position. + +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +use redis::{Client, Connection, FromRedisValue, RedisError, Value}; +use serde::Serialize; + +/// Approximate cap on stream length. `MAXLEN ~` lets Redis trim in +/// whole-node units instead of exactly-N units, which is much cheaper +/// at the cost of overshooting the bound by up to a node's worth. +pub const DEFAULT_MAX_LEN: i64 = 1000; + +#[derive(Debug)] +pub enum EventLogError { + Redis(RedisError), +} + +impl std::fmt::Display for EventLogError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EventLogError::Redis(e) => write!(f, "redis: {}", e), + } + } +} + +impl std::error::Error for EventLogError {} + +impl From for EventLogError { + fn from(e: RedisError) -> Self { + EventLogError::Redis(e) + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct AgentEvent { + pub event_id: String, + pub thread_id: String, + pub action: String, + pub detail: String, + pub ts: f64, +} + +pub struct AgentEventLog { + conn: Mutex, + pub key_prefix: String, + pub max_len: i64, +} + +impl AgentEventLog { + pub fn new( + client: &Client, + key_prefix: impl Into, + max_len: i64, + ) -> Result { + let conn = client.get_connection()?; + Ok(Self { + conn: Mutex::new(conn), + key_prefix: key_prefix.into(), + max_len, + }) + } + + pub fn stream_key(&self, thread_id: &str) -> String { + format!("{}{}", self.key_prefix, thread_id) + } + + /// Append one event and return its stream id. + /// + /// `MAXLEN ~ N` keeps the stream bounded with near-zero overhead; + /// an exact bound (`MAXLEN N` without the tilde) forces a scan + /// and is rarely worth the cost. + pub fn record( + &self, + thread_id: &str, + action: &str, + detail: &str, + ) -> Result { + let ts = unix_secs(); + let id: String = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("XADD") + .arg(self.stream_key(thread_id)) + .arg("MAXLEN") + .arg("~") + .arg(self.max_len) + .arg("*") + .arg("action").arg(action) + .arg("detail").arg(detail) + .arg("ts").arg(format!("{}", ts)) + .query(&mut *con)? + }; + Ok(id) + } + + /// Return the most recent events, newest first. + pub fn recent( + &self, + thread_id: &str, + count: usize, + ) -> Result, EventLogError> { + let value: Value = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("XREVRANGE") + .arg(self.stream_key(thread_id)) + .arg("+") + .arg("-") + .arg("COUNT") + .arg(count as i64) + .query(&mut *con)? + }; + Ok(parse_xrange(&value, thread_id)) + } + + #[allow(dead_code)] + pub fn length(&self, thread_id: &str) -> Result { + let n: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("XLEN").arg(self.stream_key(thread_id)).query(&mut *con)? + }; + Ok(n) + } + + pub fn clear(&self, thread_id: &str) -> Result { + let n: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("DEL").arg(self.stream_key(thread_id)).query(&mut *con)? + }; + Ok(n > 0) + } +} + +fn parse_xrange(value: &Value, thread_id: &str) -> Vec { + // XRANGE / XREVRANGE shape: [[id, [field, val, field, val, ...]], ...] + let items = match value { + Value::Array(items) => items, + _ => return Vec::new(), + }; + let mut out = Vec::with_capacity(items.len()); + for item in items { + let pair = match item { + Value::Array(pair) => pair, + _ => continue, + }; + if pair.len() < 2 { + continue; + } + let id = match redis_value_to_string(&pair[0]) { + Some(s) => s, + None => continue, + }; + let fields = match &pair[1] { + Value::Array(fs) => fs, + _ => continue, + }; + let mut action = String::new(); + let mut detail = String::new(); + let mut ts: f64 = 0.0; + let mut iter = fields.iter(); + while let Some(k) = iter.next() { + let v = match iter.next() { + Some(v) => v, + None => break, + }; + let key = redis_value_to_string(k).unwrap_or_default(); + let val = redis_value_to_string(v).unwrap_or_default(); + match key.as_str() { + "action" => action = val, + "detail" => detail = val, + "ts" => ts = val.parse::().unwrap_or(0.0), + _ => {} + } + } + out.push(AgentEvent { + event_id: id, + thread_id: thread_id.to_string(), + action, + detail, + ts, + }); + } + out +} + +fn redis_value_to_string(v: &Value) -> Option { + match v { + Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()), + Value::SimpleString(s) => Some(s.clone()), + Value::VerbatimString { format: _, text } => Some(text.clone()), + Value::Int(n) => Some(n.to_string()), + Value::Double(d) => Some(d.to_string()), + Value::Nil => None, + _ => String::from_redis_value(v).ok(), + } +} + +fn unix_secs() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0) +} diff --git a/content/develop/use-cases/agent-memory/rust/src/lib.rs b/content/develop/use-cases/agent-memory/rust/src/lib.rs new file mode 100644 index 0000000000..1b19aa0327 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/lib.rs @@ -0,0 +1,12 @@ +//! Library entry point for the agent-memory demo. +//! +//! `src/main.rs` is the runnable binary; this file re-exports the +//! same modules as a library crate so the snippets in the walkthrough +//! (`use agent_memory_demo::session_store::AgentSession;`, and so on) +//! resolve against this package without needing a separate workspace. + +pub mod embeddings; +pub mod event_log; +pub mod long_term_memory; +pub mod seed_memory; +pub mod session_store; diff --git a/content/develop/use-cases/agent-memory/rust/src/long_term_memory.rs b/content/develop/use-cases/agent-memory/rust/src/long_term_memory.rs new file mode 100644 index 0000000000..3681ab8698 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/long_term_memory.rs @@ -0,0 +1,658 @@ +//! Long-term memory store for an agent, backed by Redis JSON and +//! Search. +//! +//! Each memory lives as one JSON document at `agent:mem:`. The +//! document holds the memory text, its embedding vector, and a small +//! metadata block — user, namespace, kind, source thread, timestamps +//! — that lets the recall query scope results without falling back +//! to application-side filtering. +//! +//! A single Redis Search index covers the embedding plus every +//! metadata field, so one `FT.SEARCH` call performs approximate- +//! nearest-neighbour over the in-scope subset and returns the top-k +//! memories ranked by cosine distance. The same KNN check runs at +//! *write* time to deduplicate near-identical memories before they +//! enter the store, which keeps the index from filling with +//! paraphrases of the same fact as the agent reasons over similar +//! topics across sessions. +//! +//! Memories carry one of two kinds: +//! +//! * `episodic` — "what happened" snapshots from a specific thread, +//! written with a medium TTL so old session detail decays +//! naturally. +//! * `semantic` — distilled facts and preferences the agent should +//! carry forward indefinitely. Written with no TTL by default. +//! +//! The split is enforced as a TAG on the index, so the recall query +//! can ask for one kind or both with a filter — no separate +//! keyspaces. + +use std::collections::HashMap; +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +use redis::{Client, Connection, FromRedisValue, RedisError, Value}; +use serde::Serialize; +use serde_json::json; + +use crate::embeddings::floats_to_bytes; + +pub const VECTOR_DIM_DEFAULT: usize = 384; + +/// How close (cosine distance) a candidate must be to an existing +/// memory to count as a duplicate at write time. Smaller = stricter. +/// 0.20 is calibrated to the `sentence-transformers/all-MiniLM-L6-v2` +/// embedding model used in the demo, where a paraphrase of an +/// existing memory lands in the 0.10 – 0.20 range and a distinct +/// memory lands above 0.50. +pub const DEFAULT_DEDUP_THRESHOLD: f64 = 0.20; + +/// How close (cosine distance) a candidate must be to count as a +/// relevant recall result. Larger than the dedup threshold so the +/// agent gets a wider net at read time than at write time. +pub const DEFAULT_RECALL_THRESHOLD: f64 = 0.55; + +/// TTL tiers, in seconds. `None` means "no TTL" — the memory +/// persists until explicitly deleted or evicted under memory +/// pressure. +pub fn default_ttl_for_kind(kind: &str) -> Option { + match kind { + "episodic" => Some(7 * 24 * 3600), + "semantic" => None, + _ => None, + } +} + +#[derive(Debug)] +pub enum MemoryError { + Redis(RedisError), + ShapeMismatch { expected: usize, got: usize }, + Parse(String), +} + +impl std::fmt::Display for MemoryError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MemoryError::Redis(e) => write!(f, "redis: {}", e), + MemoryError::ShapeMismatch { expected, got } => write!( + f, + "embedding has dimension {}; index expects {}", + got, expected + ), + MemoryError::Parse(msg) => write!(f, "parse: {}", msg), + } + } +} + +impl std::error::Error for MemoryError {} + +impl From for MemoryError { + fn from(e: RedisError) -> Self { + MemoryError::Redis(e) + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct MemoryRecord { + pub id: String, + pub user: String, + pub namespace: String, + pub kind: String, + pub source_thread: String, + pub text: String, + pub created_ts: f64, + pub hit_count: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub distance: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl_seconds: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct WriteResult { + pub id: String, + pub deduped: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub existing_distance: Option, +} + +#[derive(Debug, Clone, Default, Serialize)] +pub struct IndexSnapshot { + pub num_docs: i64, + pub indexing_failures: i64, +} + +pub struct LongTermMemory { + conn: Mutex, + pub index_name: String, + pub key_prefix: String, + pub vector_dim: usize, + pub dedup_threshold: f64, + pub recall_threshold: f64, +} + +impl LongTermMemory { + pub fn new( + client: &Client, + index_name: impl Into, + key_prefix: impl Into, + vector_dim: usize, + dedup_threshold: f64, + recall_threshold: f64, + ) -> Result { + let conn = client.get_connection()?; + Ok(Self { + conn: Mutex::new(conn), + index_name: index_name.into(), + key_prefix: key_prefix.into(), + vector_dim, + dedup_threshold, + recall_threshold, + }) + } + + pub fn memory_key(&self, memory_id: &str) -> String { + format!("{}{}", self.key_prefix, memory_id) + } + + /// Create the Redis Search index if it doesn't already exist. + /// + /// The index is declared on the JSON document type with alias + /// names on each path; the same `FT.SEARCH` filter clause works + /// here as on a HASH-backed index, and the field paths + /// (`$.user`, `$.embedding`, ...) only show up in `FT.CREATE`. + pub fn create_index(&self) -> Result<(), MemoryError> { + let result: Result = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("FT.CREATE") + .arg(&self.index_name) + .arg("ON").arg("JSON") + .arg("PREFIX").arg(1).arg(&self.key_prefix) + .arg("SCHEMA") + .arg("$.text").arg("AS").arg("text").arg("TEXT") + .arg("$.user").arg("AS").arg("user").arg("TAG") + .arg("$.namespace").arg("AS").arg("namespace").arg("TAG") + .arg("$.kind").arg("AS").arg("kind").arg("TAG") + .arg("$.source_thread").arg("AS").arg("source_thread").arg("TAG") + .arg("$.created_ts").arg("AS").arg("created_ts") + .arg("NUMERIC").arg("SORTABLE") + .arg("$.hit_count").arg("AS").arg("hit_count") + .arg("NUMERIC").arg("SORTABLE") + .arg("$.embedding").arg("AS").arg("embedding") + .arg("VECTOR").arg("HNSW").arg(6) + .arg("TYPE").arg("FLOAT32") + .arg("DIM").arg(self.vector_dim as i64) + .arg("DISTANCE_METRIC").arg("COSINE") + .query(&mut *con) + }; + match result { + Ok(_) => Ok(()), + Err(e) => { + let msg = e.to_string().to_lowercase(); + if msg.contains("exists") && msg.contains("index") { + Ok(()) + } else { + Err(MemoryError::Redis(e)) + } + } + } + } + + pub fn drop_index(&self, delete_documents: bool) -> Result<(), MemoryError> { + let mut con = self.conn.lock().unwrap(); + let mut cmd = redis::cmd("FT.DROPINDEX"); + cmd.arg(&self.index_name); + if delete_documents { + cmd.arg("DD"); + } + match cmd.query::(&mut *con) { + Ok(_) => Ok(()), + Err(e) => { + let msg = e.to_string().to_lowercase(); + if msg.contains("no such index") || msg.contains("unknown index name") { + Ok(()) + } else { + Err(MemoryError::Redis(e)) + } + } + } + } + + /// Write a new memory, deduplicating against existing entries. + /// + /// Runs one in-scope `KNN(1)` against the index first. If the + /// nearest existing memory is within `dedup_threshold`, the new + /// memory is skipped (its content is already represented) and + /// the existing memory's `hit_count` is bumped. Otherwise a + /// fresh JSON document is written under a new id with a TTL + /// derived from the memory's `kind`. + /// + /// The KNN-then-write sequence is not atomic; two workers that + /// remember the same fact at the same time can both miss each + /// other's in-flight write and insert duplicate memories. See + /// the walkthrough's "Concurrency caveats" section for the + /// production fix. + pub fn remember( + &self, + text: &str, + embedding: &[f32], + user: &str, + namespace: &str, + kind: &str, + source_thread: &str, + ttl_seconds: Option, + ) -> Result { + if embedding.len() != self.vector_dim { + return Err(MemoryError::ShapeMismatch { + expected: self.vector_dim, + got: embedding.len(), + }); + } + + let nearest = self.nearest(embedding, Some(user), Some(namespace), Some(kind), 1)?; + let existing_distance = nearest.first().and_then(|r| r.distance); + if let Some(first) = nearest.first() { + if let Some(d) = first.distance { + if d <= self.dedup_threshold { + self.bump_hit_count(&first.id)?; + return Ok(WriteResult { + id: first.id.clone(), + deduped: true, + existing_distance, + }); + } + } + } + + let id = new_id_12(); + let key = self.memory_key(&id); + let now = unix_secs(); + let doc = json!({ + "id": id, + "user": user, + "namespace": namespace, + "kind": kind, + "source_thread": source_thread, + "text": text, + "embedding": embedding, + "created_ts": now, + "hit_count": 0, + }); + let ttl = ttl_seconds.or_else(|| default_ttl_for_kind(kind)); + + { + let mut con = self.conn.lock().unwrap(); + redis::cmd("JSON.SET") + .arg(&key) + .arg("$") + .arg(doc.to_string()) + .query::(&mut *con)?; + if let Some(t) = ttl { + redis::cmd("EXPIRE").arg(&key).arg(t).query::(&mut *con)?; + } + } + Ok(WriteResult { + id, + deduped: false, + existing_distance, + }) + } + + /// Return the top-k in-scope memories ranked by similarity. + /// Memories beyond `distance_threshold` (or the instance default) + /// are dropped — the index always returns *something* for KNN, + /// so a recall result on an unrelated query would otherwise be a + /// confidently-wrong false positive. + pub fn recall( + &self, + query_embedding: &[f32], + user: &str, + namespace: Option<&str>, + kind: Option<&str>, + k: usize, + distance_threshold: Option, + ) -> Result, MemoryError> { + let threshold = distance_threshold.unwrap_or(self.recall_threshold); + let candidates = self.nearest(query_embedding, Some(user), namespace, kind, k)?; + Ok(candidates + .into_iter() + .filter(|c| c.distance.is_some_and(|d| d <= threshold)) + .collect()) + } + + fn nearest( + &self, + embedding: &[f32], + user: Option<&str>, + namespace: Option<&str>, + kind: Option<&str>, + k: usize, + ) -> Result, MemoryError> { + if embedding.len() != self.vector_dim { + return Err(MemoryError::ShapeMismatch { + expected: self.vector_dim, + got: embedding.len(), + }); + } + let filter_clause = build_filter_clause(user, namespace, kind); + let query_str = format!("{}=>[KNN {} @embedding $vec AS distance]", filter_clause, k); + let vec_bytes = floats_to_bytes(embedding); + + let value: Value = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("FT.SEARCH") + .arg(&self.index_name) + .arg(&query_str) + .arg("PARAMS").arg(2).arg("vec").arg(&vec_bytes[..]) + .arg("SORTBY").arg("distance").arg("ASC") + .arg("RETURN").arg(8) + .arg("user").arg("namespace").arg("kind").arg("source_thread") + .arg("text").arg("created_ts").arg("hit_count").arg("distance") + .arg("LIMIT").arg(0).arg(k as i64) + .arg("DIALECT").arg(2) + .query(&mut *con)? + }; + let docs = parse_ft_search(&value)?; + let mut out = Vec::with_capacity(docs.len()); + for doc in docs { + // `doc.id` is the full Redis key (e.g. + // `agent:mem:abc123`). Strip the prefix so the returned + // record exposes only the opaque id the UI and + // `delete_memory` work with. + let memory_id = doc + .id + .strip_prefix(&self.key_prefix) + .unwrap_or(&doc.id) + .to_string(); + let ttl: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("TTL").arg(self.memory_key(&memory_id)).query(&mut *con).unwrap_or(-2) + }; + let ttl_seconds = if ttl > 0 { Some(ttl) } else { None }; + let distance = doc.field("distance").and_then(|s| s.parse::().ok()); + out.push(MemoryRecord { + id: memory_id, + user: doc.field("user").unwrap_or_default().to_string(), + namespace: doc.field("namespace").unwrap_or_default().to_string(), + kind: doc.field("kind").unwrap_or_default().to_string(), + source_thread: doc.field("source_thread").unwrap_or_default().to_string(), + text: doc.field("text").unwrap_or_default().to_string(), + created_ts: doc + .field("created_ts") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0), + hit_count: doc + .field("hit_count") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + distance, + ttl_seconds, + }); + } + Ok(out) + } + + fn bump_hit_count(&self, memory_id: &str) -> Result<(), MemoryError> { + let mut con = self.conn.lock().unwrap(); + // The doc may have expired between recall and bump — fine, + // we just lose the hit count update. Discarding the error + // keeps the demo from blowing up on that race. + let _ = redis::cmd("JSON.NUMINCRBY") + .arg(self.memory_key(memory_id)) + .arg("$.hit_count") + .arg(1) + .query::(&mut *con); + Ok(()) + } + + pub fn index_info(&self) -> IndexSnapshot { + let mut con = match self.conn.lock() { + Ok(c) => c, + Err(_) => return IndexSnapshot::default(), + }; + let value: Value = match redis::cmd("FT.INFO") + .arg(&self.index_name) + .query(&mut *con) + { + Ok(v) => v, + Err(_) => return IndexSnapshot::default(), + }; + parse_ft_info(&value) + } + + pub fn list_memories( + &self, + user: Option<&str>, + namespace: Option<&str>, + kind: Option<&str>, + limit: usize, + ) -> Result, MemoryError> { + let filter_clause = build_filter_clause(user, namespace, kind); + let value: Value = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("FT.SEARCH") + .arg(&self.index_name) + .arg(&filter_clause) + .arg("RETURN").arg(7) + .arg("user").arg("namespace").arg("kind").arg("source_thread") + .arg("text").arg("created_ts").arg("hit_count") + .arg("SORTBY").arg("created_ts").arg("DESC") + .arg("LIMIT").arg(0).arg(limit as i64) + .arg("DIALECT").arg(2) + .query(&mut *con)? + }; + let docs = parse_ft_search(&value)?; + let mut out = Vec::with_capacity(docs.len()); + for doc in docs { + let memory_id = doc + .id + .strip_prefix(&self.key_prefix) + .unwrap_or(&doc.id) + .to_string(); + let ttl: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("TTL").arg(self.memory_key(&memory_id)).query(&mut *con).unwrap_or(-2) + }; + let ttl_seconds = if ttl > 0 { Some(ttl) } else { None }; + out.push(MemoryRecord { + id: memory_id, + user: doc.field("user").unwrap_or_default().to_string(), + namespace: doc.field("namespace").unwrap_or_default().to_string(), + kind: doc.field("kind").unwrap_or_default().to_string(), + source_thread: doc.field("source_thread").unwrap_or_default().to_string(), + text: doc.field("text").unwrap_or_default().to_string(), + created_ts: doc + .field("created_ts") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0), + hit_count: doc + .field("hit_count") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + distance: None, + ttl_seconds, + }); + } + Ok(out) + } + + pub fn delete_memory(&self, memory_id: &str) -> Result { + let n: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("DEL").arg(self.memory_key(memory_id)).query(&mut *con)? + }; + Ok(n > 0) + } + + /// Drop the index and every memory document, then recreate the + /// index. Returns the count of documents that were removed. + pub fn clear(&self) -> Result { + let before = self.index_info().num_docs; + self.drop_index(true)?; + self.create_index()?; + Ok(before) + } +} + +// ---- FT.SEARCH / FT.INFO response parsing ---------------------------- + +#[derive(Debug)] +pub struct SearchDoc { + pub id: String, + pub fields: HashMap, +} + +impl SearchDoc { + fn field(&self, name: &str) -> Option<&str> { + self.fields.get(name).map(|s| s.as_str()) + } +} + +fn parse_ft_search(value: &Value) -> Result, MemoryError> { + let items = match value { + Value::Array(items) => items, + _ => return Err(MemoryError::Parse("FT.SEARCH did not return an array".into())), + }; + if items.is_empty() { + return Ok(vec![]); + } + let mut out = Vec::new(); + let mut iter = items.iter().skip(1); + while let Some(key_value) = iter.next() { + let key = redis_value_to_string(key_value) + .ok_or_else(|| MemoryError::Parse("key is not a string".into()))?; + let fields_value = match iter.next() { + Some(v) => v, + None => { + out.push(SearchDoc { id: key, fields: HashMap::new() }); + continue; + } + }; + let field_items: Vec<&Value> = match fields_value { + Value::Array(v) => v.iter().collect(), + _ => { + out.push(SearchDoc { id: key, fields: HashMap::new() }); + continue; + } + }; + let mut fields = HashMap::new(); + let mut f_iter = field_items.into_iter(); + while let Some(name_val) = f_iter.next() { + let name = match redis_value_to_string(name_val) { + Some(s) => s, + None => continue, + }; + let value = f_iter.next().and_then(redis_value_to_string).unwrap_or_default(); + fields.insert(name, value); + } + out.push(SearchDoc { id: key, fields }); + } + Ok(out) +} + +fn parse_ft_info(value: &Value) -> IndexSnapshot { + let items = match value { + Value::Array(items) => items, + _ => return IndexSnapshot::default(), + }; + let mut info = IndexSnapshot::default(); + let mut iter = items.iter(); + while let Some(k) = iter.next() { + let key = redis_value_to_string(k).unwrap_or_default(); + let v = match iter.next() { + Some(v) => v, + None => break, + }; + match key.as_str() { + "num_docs" => { + info.num_docs = match v { + Value::Int(n) => *n, + _ => redis_value_to_string(v) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + }; + } + "hash_indexing_failures" => { + info.indexing_failures = match v { + Value::Int(n) => *n, + _ => redis_value_to_string(v) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + }; + } + _ => {} + } + } + info +} + +fn redis_value_to_string(v: &Value) -> Option { + match v { + Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()), + Value::SimpleString(s) => Some(s.clone()), + Value::VerbatimString { format: _, text } => Some(text.clone()), + Value::Int(n) => Some(n.to_string()), + Value::Double(d) => Some(d.to_string()), + Value::Boolean(b) => Some(b.to_string()), + Value::Nil => None, + _ => String::from_redis_value(v).ok(), + } +} + +// ---- Filter clause helpers ----------------------------------------- + +/// Characters Redis Search treats as syntax inside a TAG value; any +/// of them in a user-supplied filter must be backslash-escaped or +/// the surrounding `{...}` block won't parse correctly. +const TAG_SPECIAL: &str = "\\,.<>{}[]\"':;!@#$%^&*()-+=~| "; + +pub fn escape_tag_value(v: &str) -> String { + let mut out = String::with_capacity(v.len()); + for ch in v.chars() { + if TAG_SPECIAL.contains(ch) { + out.push('\\'); + } + out.push(ch); + } + out +} + +pub fn build_filter_clause( + user: Option<&str>, + namespace: Option<&str>, + kind: Option<&str>, +) -> String { + let mut clauses = Vec::new(); + if let Some(u) = user.filter(|s| !s.is_empty()) { + clauses.push(format!("@user:{{{}}}", escape_tag_value(u))); + } + if let Some(n) = namespace.filter(|s| !s.is_empty()) { + clauses.push(format!("@namespace:{{{}}}", escape_tag_value(n))); + } + if let Some(k) = kind.filter(|s| !s.is_empty()) { + clauses.push(format!("@kind:{{{}}}", escape_tag_value(k))); + } + if clauses.is_empty() { + "(*)".to_string() + } else { + format!("({})", clauses.join(" ")) + } +} + +fn unix_secs() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0) +} + +fn new_id_12() -> String { + let mut buf = [0u8; 6]; + getrandom::getrandom(&mut buf).expect("getrandom never fails on supported platforms"); + let mut s = String::with_capacity(12); + for b in buf { + s.push_str(&format!("{:02x}", b)); + } + s +} diff --git a/content/develop/use-cases/agent-memory/rust/src/main.rs b/content/develop/use-cases/agent-memory/rust/src/main.rs new file mode 100644 index 0000000000..db3187d3c1 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/main.rs @@ -0,0 +1,649 @@ +//! Redis agent-memory demo server (Rust + redis-rs + Candle). +//! +//! Run this and visit `http://localhost:8094` to drive a small +//! agent-memory demo backed by Redis Hashes, JSON, Search, and +//! Streams. The UI lets you type a turn, watch working memory update, +//! see semantically similar long-term memories recalled, watch the +//! write-time deduplication skip near-duplicates, and inspect the +//! per-thread event log. +//! +//! The server holds a single [`LocalEmbedder`], one [`AgentSession`], +//! one [`LongTermMemory`], and one [`AgentEventLog`] for the lifetime +//! of the process. The first run downloads the embedding model into +//! the local Hugging Face cache; everything after is local. + +use std::env; +use std::fs; +use std::io::Read; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +use serde::Serialize; +use serde_json::{json, Value as JsonValue}; +use tiny_http::{Header, Method, Response, Server, StatusCode}; +use url::form_urlencoded; + +// Pull the helpers in through the library crate so the same module +// paths the walkthrough quotes (e.g. `agent_memory_demo::session_store`) +// also work inside the binary. +use agent_memory_demo::embeddings::LocalEmbedder; +use agent_memory_demo::event_log::AgentEventLog; +use agent_memory_demo::long_term_memory::{LongTermMemory, MemoryRecord, WriteResult}; +use agent_memory_demo::seed_memory::seed as seed_memories; +use agent_memory_demo::session_store::{AgentSession, SessionState}; + +const STACK_LABEL: &str = + "redis-rs + tiny_http + Candle (BERT / sentence-transformers MiniLM)"; + +// 1 MiB cap on POST bodies so a runaway client (or a `curl +// --data-binary @big-file` by mistake) can't accumulate unbounded +// memory before the handler runs. +const MAX_BODY_BYTES: usize = 1 * 1024 * 1024; + +fn main() { + let args = match Args::parse(env::args().collect::>()) { + Ok(a) => a, + Err(e) => { + eprintln!("Error: {}", e); + print_help(); + std::process::exit(2); + } + }; + + let client = match redis::Client::open( + format!("redis://{}:{}/", args.redis_host, args.redis_port).as_str(), + ) { + Ok(c) => c, + Err(e) => { + eprintln!( + "Error: cannot reach Redis at {}:{}", + args.redis_host, args.redis_port + ); + eprintln!(" ({})", e); + std::process::exit(1); + } + }; + + // Ping eagerly so a "Redis not running" error appears before the + // (slow) model download starts. + match client.get_connection().and_then(|mut c| redis::cmd("PING").query::(&mut c)) { + Ok(_) => {} + Err(e) => { + eprintln!( + "Error: cannot reach Redis at {}:{}", + args.redis_host, args.redis_port + ); + eprintln!(" ({})", e); + std::process::exit(1); + } + } + + let session = match AgentSession::new( + &client, + args.session_key_prefix.clone(), + args.session_ttl_seconds, + 20, + ) { + Ok(s) => s, + Err(e) => { + eprintln!("Error creating session store: {}", e); + std::process::exit(1); + } + }; + let memory = match LongTermMemory::new( + &client, + args.mem_index_name.clone(), + args.mem_key_prefix.clone(), + 384, + args.dedup_threshold, + args.recall_threshold, + ) { + Ok(m) => m, + Err(e) => { + eprintln!("Error creating memory store: {}", e); + std::process::exit(1); + } + }; + if let Err(e) = memory.create_index() { + eprintln!("Error creating index: {}", e); + std::process::exit(1); + } + let events = match AgentEventLog::new(&client, args.event_key_prefix.clone(), 1000) { + Ok(e) => e, + Err(e) => { + eprintln!("Error creating event log: {}", e); + std::process::exit(1); + } + }; + + println!( + "Loading embedding model (first run downloads weights from the Hugging Face Hub)..." + ); + let embedder = match LocalEmbedder::new(None) { + Ok(e) => e, + Err(e) => { + eprintln!("Error loading embedder: {}", e); + std::process::exit(1); + } + }; + + let demo = Arc::new(AgentMemoryDemo::new(session, memory, events, embedder)); + + if args.reset_on_start { + println!( + "Dropping any existing memories under '{}*' and re-seeding from the \ + sample memory list (pass --no-reset to keep).", + args.mem_key_prefix + ); + match demo.seed_all("default", "default") { + Ok(n) => println!("Seeded {} memories.", n), + Err(e) => { + eprintln!("Error during initial seed: {}", e); + std::process::exit(1); + } + } + } + + // Load index.html once and substitute the template tokens so the + // docs panel shows the actual values in use. + let html_path = locate_index_html(); + let raw_html = match fs::read_to_string(&html_path) { + Ok(s) => s, + Err(e) => { + eprintln!("Could not read index.html at {}: {}", html_path.display(), e); + std::process::exit(1); + } + }; + let html_page = raw_html + .replace("__SESSION_PREFIX__", &args.session_key_prefix) + .replace("__MEM_PREFIX__", &args.mem_key_prefix) + .replace("__MEM_INDEX__", &args.mem_index_name) + .replace("__EVENT_PREFIX__", &args.event_key_prefix); + let html_page = Arc::new(html_page); + + let addr = format!("{}:{}", args.host, args.port); + let server = match Server::http(&addr) { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to bind {}: {}", addr, e); + std::process::exit(1); + } + }; + println!("Redis agent memory demo listening on http://{}", addr); + println!( + "Using Redis at {}:{} with memory index '{}'", + args.redis_host, args.redis_port, args.mem_index_name + ); + + for request in server.incoming_requests() { + let demo = Arc::clone(&demo); + let html = Arc::clone(&html_page); + std::thread::spawn(move || { + if let Err(e) = handle_request(request, demo.as_ref(), html.as_str()) { + eprintln!("[demo] handler error: {}", e); + } + }); + } +} + +fn locate_index_html() -> PathBuf { + // Look beside the binary first (when ``cargo run`` puts us in + // target/{debug,release}/), then in the project root. The + // project-root fallback is what makes ``cargo run`` work + // straight from the example directory without copying files. + let exe = env::current_exe().ok(); + if let Some(exe_path) = exe { + if let Some(dir) = exe_path.parent() { + let candidate = dir.join("index.html"); + if candidate.exists() { + return candidate; + } + } + } + if let Ok(cwd) = env::current_dir() { + let candidate = cwd.join("index.html"); + if candidate.exists() { + return candidate; + } + } + PathBuf::from("index.html") +} + +// ---- Request dispatch ----------------------------------------------- + +fn handle_request( + mut request: tiny_http::Request, + demo: &AgentMemoryDemo, + html_page: &str, +) -> std::io::Result<()> { + let url = request.url().to_string(); + let method = request.method().clone(); + let (path, query) = split_path_query(&url); + + let response = match (&method, path.as_str()) { + (Method::Get, "/") | (Method::Get, "/index.html") => { + Response::from_string(html_page.to_string()) + .with_status_code(StatusCode(200)) + .with_header(html_header()) + .boxed() + } + (Method::Get, "/state") => { + let qs = parse_form(&query); + let user = qs.get("user").cloned().unwrap_or_else(|| "default".to_string()); + let namespace = qs.get("namespace").cloned().unwrap_or_else(|| "default".to_string()); + json_response(200, &demo.build_state(&user, &namespace)) + } + (Method::Post, "/turn") => { + let body = read_body(&mut request)?; + let form = parse_form(&body); + let text = form.get("text").cloned().unwrap_or_default(); + let text = text.trim(); + if text.is_empty() { + json_response(400, &json!({ "error": "text is required" })) + } else { + let threshold = clamp_threshold(form.get("threshold"), demo.memory.recall_threshold); + let payload = demo.handle_turn( + text, + form.get("user").map(String::as_str).unwrap_or("default"), + form.get("namespace").map(String::as_str).unwrap_or("default"), + form.get("kind").map(String::as_str).unwrap_or("episodic"), + form.get("role").map(String::as_str).unwrap_or("user"), + threshold, + form.get("action").map(String::as_str).unwrap_or("turn"), + ); + match payload { + Ok(p) => json_response(200, &p), + Err(e) => json_response(500, &json!({ "error": e.to_string() })), + } + } + } + (Method::Post, "/new_thread") => { + let body = read_body(&mut request)?; + let form = parse_form(&body); + let user = form.get("user").cloned().unwrap_or_else(|| "default".to_string()); + let namespace = form.get("namespace").cloned().unwrap_or_else(|| "default".to_string()); + match demo.new_thread(&user, &namespace) { + Ok(tid) => json_response(200, &json!({ "thread_id": tid })), + Err(e) => json_response(500, &json!({ "error": e })), + } + } + (Method::Post, "/reset") => { + let body = read_body(&mut request)?; + let form = parse_form(&body); + let user = form.get("user").cloned().unwrap_or_else(|| "default".to_string()); + let namespace = form.get("namespace").cloned().unwrap_or_else(|| "default".to_string()); + match demo.seed_all(&user, &namespace) { + Ok(n) => json_response(200, &json!({ "seeded": n })), + Err(e) => json_response(500, &json!({ "error": e })), + } + } + (Method::Post, "/drop_memory") => { + let body = read_body(&mut request)?; + let form = parse_form(&body); + let memory_id = form + .get("memory_id") + .cloned() + .unwrap_or_default() + .trim() + .to_string(); + if memory_id.is_empty() { + json_response(400, &json!({ "error": "memory_id is required" })) + } else { + match demo.memory.delete_memory(&memory_id) { + Ok(deleted) => json_response(200, &json!({ + "deleted": deleted, + "memory_id": memory_id, + })), + Err(e) => json_response(500, &json!({ "error": e.to_string() })), + } + } + } + _ => json_response(404, &json!({ "error": "not found" })), + }; + + request.respond(response) +} + +fn read_body(request: &mut tiny_http::Request) -> std::io::Result { + let mut buf = Vec::new(); + request.as_reader().take(MAX_BODY_BYTES as u64 + 1).read_to_end(&mut buf)?; + if buf.len() > MAX_BODY_BYTES { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("request body exceeds {} bytes", MAX_BODY_BYTES), + )); + } + Ok(String::from_utf8_lossy(&buf).into_owned()) +} + +fn parse_form(s: &str) -> std::collections::HashMap { + form_urlencoded::parse(s.as_bytes()) + .into_owned() + .collect() +} + +fn split_path_query(url: &str) -> (String, String) { + if let Some(idx) = url.find('?') { + (url[..idx].to_string(), url[idx + 1..].to_string()) + } else { + (url.to_string(), String::new()) + } +} + +fn json_response(status: u16, value: &T) -> tiny_http::ResponseBox { + let body = serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()); + Response::from_string(body) + .with_status_code(StatusCode(status)) + .with_header(json_header()) + .boxed() +} + +fn html_header() -> Header { + Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]).unwrap() +} + +fn json_header() -> Header { + Header::from_bytes(&b"Content-Type"[..], &b"application/json"[..]).unwrap() +} + +fn clamp_threshold(raw: Option<&String>, fallback: f64) -> f64 { + let parsed = raw + .and_then(|s| s.parse::().ok()) + .filter(|d| d.is_finite()); + let v = parsed.unwrap_or(fallback); + v.clamp(0.0, 2.0) +} + +// ---- Demo orchestrator ---------------------------------------------- + +/// Demo state: working memory, long-term memory, event log. +/// +/// `current_thread_id` is wrapped in a `Mutex`, but the lock +/// is released after each rotation or read — a turn racing with +/// `/new_thread` or `/reset` can therefore capture the old id and +/// apply to the previous thread. The demo is single-user in +/// practice, so the race never triggers; a multi-user agent would +/// carry the thread id on each request instead of holding it as +/// shared server state. See the walkthrough's "Concurrency caveats" +/// section. +pub struct AgentMemoryDemo { + pub session: AgentSession, + pub memory: LongTermMemory, + pub events: AgentEventLog, + pub embedder: LocalEmbedder, + pub default_user: String, + pub default_namespace: String, + pub current_thread_id: Mutex, +} + +impl AgentMemoryDemo { + pub fn new( + session: AgentSession, + memory: LongTermMemory, + events: AgentEventLog, + embedder: LocalEmbedder, + ) -> Self { + let tid = session.new_thread_id(); + Self { + session, + memory, + events, + embedder, + default_user: "default".to_string(), + default_namespace: "default".to_string(), + current_thread_id: Mutex::new(tid), + } + } + + pub fn seed_all(&self, user: &str, namespace: &str) -> Result { + self.memory.clear().map_err(|e| e.to_string())?; + let tid = self.current_thread_id.lock().unwrap().clone(); + self.session.delete(&tid).map_err(|e| e.to_string())?; + self.events.clear(&tid).map_err(|e| e.to_string())?; + let written = seed_memories(&self.memory, &self.embedder, user, namespace, "seed") + .map_err(|e| e.to_string())?; + *self.current_thread_id.lock().unwrap() = self.session.new_thread_id(); + Ok(written) + } + + pub fn new_thread(&self, user: &str, namespace: &str) -> Result { + let tid = self.current_thread_id.lock().unwrap().clone(); + self.events.clear(&tid).map_err(|e| e.to_string())?; + let new_id = self.session.new_thread_id(); + self.session + .start(&new_id, user, "demo-agent", "", None) + .map_err(|e| e.to_string())?; + self.events + .record( + &new_id, + "thread_started", + &format!("user={} namespace={}", user, namespace), + ) + .map_err(|e| e.to_string())?; + *self.current_thread_id.lock().unwrap() = new_id.clone(); + Ok(new_id) + } + + /// One pass through the agent loop: append, recall, remember, + /// log. + /// + /// The order matters. We embed once and reuse the vector for + /// both the recall and (if asked) the remember step — no point + /// encoding the same text twice. Recall runs *before* the + /// remember write so the agent doesn't see its own just-written + /// turn as a recalled memory. + pub fn handle_turn( + &self, + text: &str, + user: &str, + namespace: &str, + kind: &str, + role: &str, + threshold: f64, + action: &str, + ) -> Result> { + let thread_id = self.current_thread_id.lock().unwrap().clone(); + + let t0 = Instant::now(); + let vec = self.embedder.encode_one(text)?; + let embed_ms = elapsed_ms(t0); + + // `set_goal` only touches the goal field so existing turns + // aren't wiped; `append_turn` carries the request `user` + // through to the auto-create path so a first turn for a new + // thread doesn't land under the default user. + let session_action = if action == "goal" { + self.session + .set_goal(&thread_id, text, Some(user), Some("demo-agent"), None)?; + "goal_set".to_string() + } else { + self.session.append_turn( + &thread_id, + role, + text, + Some(user), + Some("demo-agent"), + None, + )?; + format!("turn_appended:{}", role) + }; + + let t1 = Instant::now(); + let recalled = self.memory.recall( + &vec, + user, + Some(namespace), + None, + 5, + Some(threshold), + )?; + let recall_ms = elapsed_ms(t1); + + let write_skipped = kind == "skip" || action == "goal"; + let mut write_result: Option = None; + let mut write_ms = 0.0_f64; + if !write_skipped { + let t2 = Instant::now(); + write_result = Some(self.memory.remember( + text, + &vec, + user, + namespace, + kind, + &thread_id, + None, + )?); + write_ms = elapsed_ms(t2); + } + + let detail = match &write_result { + Some(w) if w.deduped => format!("deduped onto {}", w.id), + Some(w) => format!("wrote {} as {}", w.id, kind), + None => String::new(), + }; + self.events.record(&thread_id, &session_action, &detail)?; + + Ok(json!({ + "thread_id": thread_id, + "write_skipped": write_skipped, + "memory_id": write_result.as_ref().map(|w| w.id.clone()), + "deduped": write_result.as_ref().map(|w| w.deduped).unwrap_or(false), + "existing_distance": write_result.as_ref().and_then(|w| w.existing_distance), + "kind": if write_skipped { JsonValue::Null } else { JsonValue::String(kind.to_string()) }, + "recalled": recalled, + "embed_ms": embed_ms, + "recall_ms": recall_ms, + "write_ms": write_ms, + })) + } + + pub fn build_state(&self, user: &str, namespace: &str) -> JsonValue { + let info = self.memory.index_info(); + let thread_id = self.current_thread_id.lock().unwrap().clone(); + let session: Option = self.session.load(&thread_id).ok().flatten(); + let memories: Vec = self + .memory + .list_memories(Some(user), Some(namespace), None, 200) + .unwrap_or_default(); + let events = self.events.recent(&thread_id, 20).unwrap_or_default(); + + json!({ + "index": { + "num_docs": info.num_docs, + "indexing_failures": info.indexing_failures, + "index_name": self.memory.index_name, + "model": self.embedder.model_name, + "session_ttl_seconds": self.session.default_ttl_seconds, + "dedup_threshold": self.memory.dedup_threshold, + "default_recall_threshold": self.memory.recall_threshold, + "stack_label": STACK_LABEL, + }, + "thread_id": thread_id, + "session": session, + "memories": memories, + "events": events, + // `recalled` is populated by /turn; on plain /state reads + // the UI keeps showing the last turn's result. + "recalled": [], + }) + } +} + +fn elapsed_ms(start: Instant) -> f64 { + start.elapsed().as_secs_f64() * 1000.0 +} + +// ---- Arg parsing ---------------------------------------------------- + +#[derive(Debug, Clone)] +struct Args { + host: String, + port: u16, + redis_host: String, + redis_port: u16, + mem_index_name: String, + mem_key_prefix: String, + session_key_prefix: String, + event_key_prefix: String, + session_ttl_seconds: i64, + dedup_threshold: f64, + recall_threshold: f64, + reset_on_start: bool, +} + +impl Args { + fn parse(argv: Vec) -> Result { + let mut host = "127.0.0.1".to_string(); + let mut port: u16 = 8094; + let mut redis_host = "localhost".to_string(); + let mut redis_port: u16 = 6379; + let mut mem_index = "agentmem:idx".to_string(); + let mut mem_prefix = "agent:mem:".to_string(); + let mut session_prefix = "agent:session:".to_string(); + let mut event_prefix = "agent:events:".to_string(); + let mut session_ttl: i64 = 3600; + let mut dedup: f64 = 0.20; + let mut recall: f64 = 0.55; + let mut reset = true; + + let mut iter = argv.into_iter().skip(1); + while let Some(a) = iter.next() { + match a.as_str() { + "--host" => host = iter.next().ok_or("missing value for --host")?, + "--port" => port = iter.next().ok_or("missing value for --port")? + .parse().map_err(|e: std::num::ParseIntError| e.to_string())?, + "--redis-host" => redis_host = iter.next().ok_or("missing value for --redis-host")?, + "--redis-port" => redis_port = iter.next().ok_or("missing value for --redis-port")? + .parse().map_err(|e: std::num::ParseIntError| e.to_string())?, + "--mem-index-name" => mem_index = iter.next().ok_or("missing value for --mem-index-name")?, + "--mem-key-prefix" => mem_prefix = iter.next().ok_or("missing value for --mem-key-prefix")?, + "--session-key-prefix" => session_prefix = iter.next().ok_or("missing value for --session-key-prefix")?, + "--event-key-prefix" => event_prefix = iter.next().ok_or("missing value for --event-key-prefix")?, + "--session-ttl-seconds" => session_ttl = iter.next().ok_or("missing value for --session-ttl-seconds")? + .parse().map_err(|e: std::num::ParseIntError| e.to_string())?, + "--dedup-threshold" => dedup = iter.next().ok_or("missing value for --dedup-threshold")? + .parse().map_err(|e: std::num::ParseFloatError| e.to_string())?, + "--recall-threshold" => recall = iter.next().ok_or("missing value for --recall-threshold")? + .parse().map_err(|e: std::num::ParseFloatError| e.to_string())?, + "--no-reset" => reset = false, + "--help" | "-h" => return Err("help requested".to_string()), + other => return Err(format!("unknown flag: {}", other)), + } + } + + Ok(Self { + host, + port, + redis_host, + redis_port, + mem_index_name: mem_index, + mem_key_prefix: mem_prefix, + session_key_prefix: session_prefix, + event_key_prefix: event_prefix, + session_ttl_seconds: session_ttl, + dedup_threshold: dedup, + recall_threshold: recall, + reset_on_start: reset, + }) + } +} + +fn print_help() { + eprintln!( + "Usage: agent-memory-demo [flags]\n\ + \n\ + --host HTTP bind host (default 127.0.0.1)\n\ + --port HTTP bind port (default 8094)\n\ + --redis-host Redis host (default localhost)\n\ + --redis-port Redis port (default 6379)\n\ + --mem-index-name Memory index name (default agentmem:idx)\n\ + --mem-key-prefix JSON memory key prefix (default agent:mem:)\n\ + --session-key-prefix Session hash key prefix (default agent:session:)\n\ + --event-key-prefix Event stream key prefix (default agent:events:)\n\ + --session-ttl-seconds Working memory TTL (default 3600)\n\ + --dedup-threshold Cosine distance for dedup (default 0.20)\n\ + --recall-threshold Cosine distance for recall (default 0.55)\n\ + --no-reset Skip clearing and re-seeding on startup" + ); +} diff --git a/content/develop/use-cases/agent-memory/rust/src/seed_memory.rs b/content/develop/use-cases/agent-memory/rust/src/seed_memory.rs new file mode 100644 index 0000000000..138307ed4e --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/seed_memory.rs @@ -0,0 +1,97 @@ +//! Pre-seed the long-term memory store with sample memories. +//! +//! In a real deployment the memory store fills up organically as the +//! agent reasons over user turns: each turn produces zero or more +//! memories (preferences, facts, episodic summaries) that flow into +//! the store with deduplication. To make the demo immediately useful +//! — so the first recall query lands on relevant results instead of +//! an empty list — we seed a small set of canonical memories for a +//! default user at startup. +//! +//! The seed list mixes `semantic` memories (long-lived preferences +//! and facts) with `episodic` memories (snapshots of past sessions), +//! matching what the Python, Node, and .NET demos seed so all four +//! implementations behave identically. + +use crate::embeddings::LocalEmbedder; +use crate::long_term_memory::{LongTermMemory, MemoryError}; + +pub struct SeedEntry { + pub text: &'static str, + pub kind: &'static str, +} + +pub const SEED_MEMORIES: &[SeedEntry] = &[ + SeedEntry { + text: "The user prefers concise answers without filler phrases.", + kind: "semantic", + }, + SeedEntry { + text: "The user is a Python developer working on a logistics platform.", + kind: "semantic", + }, + SeedEntry { + text: "The user lives in Berlin and works in the Europe/Berlin time zone.", + kind: "semantic", + }, + SeedEntry { + text: "The user dislikes dark mode and prefers a high-contrast light \ + theme in editors and dashboards.", + kind: "semantic", + }, + SeedEntry { + text: "The user is allergic to peanuts; any restaurant suggestion must \ + avoid dishes that commonly contain them.", + kind: "semantic", + }, + SeedEntry { + text: "Last Tuesday the user asked the agent to draft a postmortem for \ + the order-routing outage. The agent produced a five-section \ + draft and the user approved sections 1, 2, and 4 with minor \ + edits.", + kind: "episodic", + }, + SeedEntry { + text: "In a previous session the user asked for help debugging a flaky \ + test in the inventory service. The fix turned out to be a race \ + condition in the warehouse webhook handler.", + kind: "episodic", + }, + SeedEntry { + text: "Two weeks ago the user mentioned they were planning to migrate \ + the analytics warehouse from Snowflake to BigQuery in Q3.", + kind: "episodic", + }, +]; + +/// Embed and write the seed memories. Returns the count actually +/// written (entries that dedup against existing memories don't +/// count). +pub fn seed( + memory: &LongTermMemory, + embedder: &LocalEmbedder, + user: &str, + namespace: &str, + source_thread: &str, +) -> Result { + let texts: Vec<&str> = SEED_MEMORIES.iter().map(|e| e.text).collect(); + let vectors = embedder + .encode_many(&texts) + .map_err(|e| MemoryError::Parse(e.to_string()))?; + let mut written = 0usize; + for (entry, vec) in SEED_MEMORIES.iter().zip(vectors.iter()) { + let result = memory.remember( + entry.text, + vec, + user, + namespace, + entry.kind, + source_thread, + None, + )?; + if !result.deduped { + written += 1; + } + } + Ok(written) +} diff --git a/content/develop/use-cases/agent-memory/rust/src/session_store.rs b/content/develop/use-cases/agent-memory/rust/src/session_store.rs new file mode 100644 index 0000000000..80b245a9a6 --- /dev/null +++ b/content/develop/use-cases/agent-memory/rust/src/session_store.rs @@ -0,0 +1,350 @@ +//! Working-memory store for an agent session, backed by a Redis Hash. +//! +//! Each session is one Hash document at `agent:session:{thread_id}`. +//! The hash holds the running scratchpad, the current goal, a rolling +//! window of recent turns (serialised as a JSON list to fit in one +//! field), and a few audit fields. One `HGETALL` returns the whole +//! session in a single round trip on every step of the agent loop. +//! +//! Every write refreshes the key's TTL with `EXPIRE`, so idle sessions +//! fall off without a separate cleanup job and active sessions stay +//! alive as long as the agent keeps touching them. A separate +//! `LongTermMemory` is what survives beyond a session's TTL. +//! +//! The turn window is bounded to `max_turns` in application code; the +//! hash itself doesn't grow, so the working set per thread stays +//! constant regardless of how long the agent has been running. + +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +use redis::{Client, Connection, FromRedisValue, RedisError, Value}; +use serde::{Deserialize, Serialize}; + +/// How many recent turns to keep inline on the session hash. Older +/// turns flow through the event log (see [`crate::event_log`]) and +/// the long-term memory store (see [`crate::long_term_memory`]). +pub const DEFAULT_MAX_TURNS: usize = 20; + +#[derive(Debug)] +pub enum SessionError { + Redis(RedisError), + Parse(String), +} + +impl std::fmt::Display for SessionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SessionError::Redis(e) => write!(f, "redis: {}", e), + SessionError::Parse(msg) => write!(f, "parse: {}", msg), + } + } +} + +impl std::error::Error for SessionError {} + +impl From for SessionError { + fn from(e: RedisError) -> Self { + SessionError::Redis(e) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionTurn { + pub role: String, + pub content: String, + pub ts: f64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct SessionState { + pub thread_id: String, + pub user: String, + pub agent: String, + pub goal: String, + pub scratchpad: String, + pub turn_count: i64, + pub created_ts: f64, + pub last_active_ts: f64, + pub recent_turns: Vec, + pub ttl_seconds: i64, +} + +pub struct AgentSession { + conn: Mutex, + pub key_prefix: String, + pub default_ttl_seconds: i64, + pub max_turns: usize, +} + +impl AgentSession { + pub fn new( + client: &Client, + key_prefix: impl Into, + default_ttl_seconds: i64, + max_turns: usize, + ) -> Result { + let conn = client.get_connection()?; + Ok(Self { + conn: Mutex::new(conn), + key_prefix: key_prefix.into(), + default_ttl_seconds, + max_turns, + }) + } + + pub fn session_key(&self, thread_id: &str) -> String { + format!("{}{}", self.key_prefix, thread_id) + } + + pub fn new_thread_id(&self) -> String { + new_id_12() + } + + /// Create a fresh working memory for a thread. Overwrites any + /// existing session at the same key. The agent normally calls + /// this once per thread at the first turn and relies on + /// [`load`](Self::load) / [`append_turn`](Self::append_turn) for + /// subsequent steps. + pub fn start( + &self, + thread_id: &str, + user: &str, + agent: &str, + goal: &str, + ttl_seconds: Option, + ) -> Result { + let ttl = ttl_seconds.unwrap_or(self.default_ttl_seconds); + let now = unix_secs(); + let state = SessionState { + thread_id: thread_id.to_string(), + user: user.to_string(), + agent: agent.to_string(), + goal: goal.to_string(), + scratchpad: String::new(), + turn_count: 0, + created_ts: now, + last_active_ts: now, + recent_turns: Vec::new(), + ttl_seconds: ttl, + }; + self.write(&state, ttl)?; + Ok(state) + } + + /// Return the session state, or `None` if it has expired. + pub fn load(&self, thread_id: &str) -> Result, SessionError> { + let key = self.session_key(thread_id); + let raw: Vec<(String, Value)> = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("HGETALL").arg(&key).query(&mut *con)? + }; + if raw.is_empty() { + return Ok(None); + } + let mut fields = std::collections::HashMap::new(); + for (k, v) in raw { + if let Some(s) = redis_value_to_string(&v) { + fields.insert(k, s); + } + } + let ttl: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("TTL").arg(&key).query(&mut *con).unwrap_or(0) + }; + let recent_turns = match fields.get("recent_turns") { + Some(blob) => serde_json::from_str::>(blob).unwrap_or_default(), + None => Vec::new(), + }; + Ok(Some(SessionState { + thread_id: thread_id.to_string(), + user: fields.remove("user").unwrap_or_else(|| "default".to_string()), + agent: fields.remove("agent").unwrap_or_else(|| "default".to_string()), + goal: fields.remove("goal").unwrap_or_default(), + scratchpad: fields.remove("scratchpad").unwrap_or_default(), + turn_count: fields + .get("turn_count") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0), + created_ts: fields + .get("created_ts") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0), + last_active_ts: fields + .get("last_active_ts") + .and_then(|s| s.parse::().ok()) + .unwrap_or(0.0), + recent_turns, + ttl_seconds: if ttl > 0 { ttl } else { 0 }, + })) + } + + /// Append a turn, bound the rolling window, refresh the TTL. + /// + /// `user` and `agent` are only consulted when the session does + /// not yet exist — they seed the auto-created session so the + /// working-memory hash matches the user the caller is operating + /// against. On an existing session they're ignored; the original + /// `start` values stand. + /// + /// Read-modify-write here is last-writer-wins on the turn list + /// if two concurrent turns reach the same thread; the demo never + /// triggers that race in practice (one browser, one turn at a + /// time) but a multi-worker agent that shares a thread id would + /// wrap this in `WATCH` / `MULTI` / `EXEC` or a Lua script that + /// does the append atomically server-side. + pub fn append_turn( + &self, + thread_id: &str, + role: &str, + content: &str, + user: Option<&str>, + agent: Option<&str>, + ttl_seconds: Option, + ) -> Result { + let mut state = match self.load(thread_id)? { + Some(s) => s, + None => self.start( + thread_id, + user.unwrap_or("default"), + agent.unwrap_or("default"), + "", + ttl_seconds, + )?, + }; + state.recent_turns.push(SessionTurn { + role: role.to_string(), + content: content.to_string(), + ts: unix_secs(), + }); + if state.recent_turns.len() > self.max_turns { + let drop_count = state.recent_turns.len() - self.max_turns; + state.recent_turns.drain(0..drop_count); + } + state.turn_count += 1; + state.last_active_ts = unix_secs(); + let ttl = ttl_seconds.unwrap_or(self.default_ttl_seconds); + state.ttl_seconds = ttl; + self.write(&state, ttl)?; + Ok(state) + } + + /// Update the agent's running scratchpad and refresh the TTL. + /// Returns `None` when the session does not exist. + #[allow(dead_code)] + pub fn set_scratchpad( + &self, + thread_id: &str, + text: &str, + ttl_seconds: Option, + ) -> Result, SessionError> { + let mut state = match self.load(thread_id)? { + Some(s) => s, + None => return Ok(None), + }; + state.scratchpad = text.to_string(); + state.last_active_ts = unix_secs(); + let ttl = ttl_seconds.unwrap_or(self.default_ttl_seconds); + state.ttl_seconds = ttl; + self.write(&state, ttl)?; + Ok(Some(state)) + } + + /// Update the goal field without touching turns or the + /// scratchpad. Creates the session if it doesn't exist yet — + /// setting a goal on a fresh thread is a sensible first step in + /// the agent loop, so this method covers both the "rename the + /// goal mid-session" and the "start a thread with this goal" + /// cases. + pub fn set_goal( + &self, + thread_id: &str, + text: &str, + user: Option<&str>, + agent: Option<&str>, + ttl_seconds: Option, + ) -> Result { + let mut state = match self.load(thread_id)? { + Some(s) => s, + None => { + return self.start( + thread_id, + user.unwrap_or("default"), + agent.unwrap_or("default"), + text, + ttl_seconds, + ); + } + }; + state.goal = text.to_string(); + state.last_active_ts = unix_secs(); + let ttl = ttl_seconds.unwrap_or(self.default_ttl_seconds); + state.ttl_seconds = ttl; + self.write(&state, ttl)?; + Ok(state) + } + + pub fn delete(&self, thread_id: &str) -> Result { + let n: i64 = { + let mut con = self.conn.lock().unwrap(); + redis::cmd("DEL").arg(self.session_key(thread_id)).query(&mut *con)? + }; + Ok(n > 0) + } + + fn write(&self, state: &SessionState, ttl: i64) -> Result<(), SessionError> { + let key = self.session_key(&state.thread_id); + let turns_blob = serde_json::to_string(&state.recent_turns) + .map_err(|e| SessionError::Parse(e.to_string()))?; + // MULTI/EXEC so HSET and EXPIRE either both apply or neither + // does. A connection drop between the two writes would + // otherwise leave the session without a TTL. + let mut con = self.conn.lock().unwrap(); + redis::pipe() + .atomic() + .cmd("HSET") + .arg(&key) + .arg("thread_id").arg(&state.thread_id) + .arg("user").arg(&state.user) + .arg("agent").arg(&state.agent) + .arg("goal").arg(&state.goal) + .arg("scratchpad").arg(&state.scratchpad) + .arg("turn_count").arg(state.turn_count.to_string()) + .arg("created_ts").arg(format!("{}", state.created_ts)) + .arg("last_active_ts").arg(format!("{}", state.last_active_ts)) + .arg("recent_turns").arg(turns_blob) + .cmd("EXPIRE").arg(&key).arg(ttl) + .query::(&mut *con)?; + Ok(()) + } +} + +fn unix_secs() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0) +} + +fn new_id_12() -> String { + let mut buf = [0u8; 6]; + getrandom::getrandom(&mut buf).expect("getrandom never fails on supported platforms"); + let mut s = String::with_capacity(12); + for b in buf { + s.push_str(&format!("{:02x}", b)); + } + s +} + +fn redis_value_to_string(v: &Value) -> Option { + match v { + Value::BulkString(bytes) => Some(String::from_utf8_lossy(bytes).into_owned()), + Value::SimpleString(s) => Some(s.clone()), + Value::VerbatimString { format: _, text } => Some(text.clone()), + Value::Int(n) => Some(n.to_string()), + Value::Double(d) => Some(d.to_string()), + Value::Boolean(b) => Some(b.to_string()), + Value::Nil => None, + _ => String::from_redis_value(v).ok(), + } +}