diff --git a/README.md b/README.md index 2bf607c..6dc07d8 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ sqlite-memory bridges these concepts, allowing any SQLite-powered application to - **Hybrid Search**: Combines vector similarity (cosine distance) with FTS5 full-text search for superior retrieval - **Smart Chunking**: Markdown-aware parsing preserves semantic boundaries - **Intelligent Sync**: Content-hash change detection skips unchanged files, atomically replaces modified ones, and cleans up deleted ones -- **Transactional Safety**: Every sync operation runs inside a SAVEPOINT transaction - either fully succeeds or fully rolls back, no partially-indexed content +- **Transactional Safety**: Text/file ingests run inside SAVEPOINT transactions, and directory sync uses transactional cleanup plus per-file transactional updates so failed files do not leave partial rows behind - **Efficient Storage**: Binary embeddings with configurable dimensions - **Embedding Cache**: Automatically caches computed embeddings, so re-indexing the same text skips redundant API calls and computation - **Flexible Embedding**: Use local models (llama.cpp) or [vectors.space](https://vectors.space) remote API @@ -61,6 +61,9 @@ sqlite-memory bridges these concepts, allowing any SQLite-powered application to ## Getting Started +> [!IMPORTANT] +> Databases created with sqlite-memory versions earlier than `1.0.0` must be rebuilt before use with `1.0.0+`, because the internal schema changed. + ### Prerequisites - SQLite @@ -74,7 +77,7 @@ sqlite-memory bridges these concepts, allowing any SQLite-powered application to ```sql -- Load extensions (sync is optional) .load ./vector -.load ./sync +.load ./cloudsync .load ./memory -- Configure embedding model (choose one): @@ -84,8 +87,8 @@ SELECT memory_set_model('local', '/path/to/nomic-embed-text-v1.5.Q8_0.gguf'); -- Option 2: Remote embedding via vectors.space (requires free API key from https://vectors.space) -- The provider name 'openai' selects the vectors.space OpenAI-compatible endpoint. --- SELECT memory_set_model('openai', 'text-embedding-3-small'); -- SELECT memory_set_apikey('your-vectorspace-api-key'); +-- SELECT memory_set_model('openai', 'text-embedding-3-small'); -- Add some knowledge SELECT memory_add_text('SQLite is a C-language library that implements a small, fast, @@ -160,7 +163,7 @@ All `memory_add_*` functions use content-hash change detection to avoid redundan 1. **Cleanup**: Removes database entries for files that no longer exist on disk 2. **Scan**: Recursively processes all matching files - adding new ones, replacing modified ones, and skipping unchanged ones -Every sync operation is wrapped in a SQLite SAVEPOINT transaction. If anything fails mid-sync (embedding error, disk issue, etc.), the entire operation rolls back cleanly. There is no risk of partially-indexed files or orphaned entries. +`memory_add_text()` and `memory_add_file()` each run inside a SQLite SAVEPOINT transaction. `memory_add_directory()` performs its cleanup pass transactionally and then processes each file in its own transaction. If one file fails, that file rolls back cleanly and previously-committed files remain valid; there are no partially-indexed rows or orphaned chunk/FTS entries for the failed file. This makes all sync functions safe to call repeatedly - for example, on a cron schedule or at agent startup - with minimal overhead. @@ -258,8 +261,8 @@ FROM dbmem_content; -- Delete by context SELECT memory_delete_context('old-project'); --- Delete specific memory -SELECT memory_delete(1234567890); +-- Delete specific memory by hash +SELECT memory_delete('9e3779b97f4a7c15'); -- Clear all memories SELECT memory_clear(); @@ -279,8 +282,11 @@ cd sqlite-memory # Build (full build with local + remote engines) make -# Run tests +# Run parser/core unit tests + extension loading smoke test make test + +# Run the full SQL extension unit suite +make test DEFINES="-DTEST_SQLITE_EXTENSION" ``` ### Build Configurations diff --git a/src/dbmem-embed.h b/src/dbmem-embed.h index 985950f..0ddc66d 100644 --- a/src/dbmem-embed.h +++ b/src/dbmem-embed.h @@ -29,6 +29,7 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine); dbmem_remote_engine_t *dbmem_remote_engine_init (void *ctx, const char *provider, const char *model, char err_msg[DBMEM_ERRBUF_SIZE]); int dbmem_remote_compute_embedding (dbmem_remote_engine_t *engine, const char *text, int text_len, embedding_result_t *result); +int dbmem_remote_engine_set_apikey (dbmem_remote_engine_t *engine, const char *api_key, char err_msg[DBMEM_ERRBUF_SIZE]); void dbmem_remote_engine_free (dbmem_remote_engine_t *engine); // Custom provider (always available, defined in sqlite-memory.c) diff --git a/src/dbmem-lembed.c b/src/dbmem-lembed.c index ce5685d..ce0f0e1 100644 --- a/src/dbmem-lembed.c +++ b/src/dbmem-lembed.c @@ -100,9 +100,15 @@ void dbmem_logger (enum ggml_log_level level, const char *text, void *user_data) // MARK: - +static void dbmem_local_set_error(dbmem_local_engine_t *engine, const char *message) { + if (!engine || !engine->context) return; + dbmem_context_set_error(engine->context, message); +} + dbmem_local_engine_t *dbmem_local_engine_init (void *ctx, const char *model_path, char err_msg[DBMEM_ERRBUF_SIZE]) { dbmem_local_engine_t *engine = (dbmem_local_engine_t *)dbmemory_zeroalloc(sizeof(dbmem_local_engine_t)); if (!engine) return NULL; + engine->context = (dbmem_context *)ctx; // set logger llama_log_set(dbmem_logger, engine); @@ -212,7 +218,7 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex // Tokenize int n_tokens = llama_tokenize(engine->vocab, text, text_len, engine->tokens, engine->tokens_capacity, true, true); if (n_tokens < 0) { - dbmem_context_set_error(engine->context, "Tokenization failed (text too long?)"); + dbmem_local_set_error(engine, "Tokenization failed (text too long?)"); return -1; } @@ -242,7 +248,7 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex // Encode int ret = llama_encode(engine->ctx, batch); if (ret != 0) { - dbmem_context_set_error(engine->context, "Llama_encode failed"); + dbmem_local_set_error(engine, "Llama_encode failed"); return -1; } @@ -255,7 +261,7 @@ int dbmem_local_compute_embedding (dbmem_local_engine_t *engine, const char *tex } if (!emb_ptr) { - dbmem_context_set_error(engine->context, "Failed to get embeddings"); + dbmem_local_set_error(engine, "Failed to get embeddings"); return -1; } @@ -301,5 +307,5 @@ void dbmem_local_engine_free (dbmem_local_engine_t *engine) { } llama_backend_free(); + dbmemory_free(engine); } - diff --git a/src/dbmem-parser.c b/src/dbmem-parser.c index dc5042d..2e69c04 100644 --- a/src/dbmem-parser.c +++ b/src/dbmem-parser.c @@ -28,6 +28,7 @@ typedef struct { size_t start; // Byte offset in source buffer size_t end; // Byte end in source buffer + int is_heading; // True if this section starts with a heading block char *text; // Stripped plain text (allocated) size_t text_len; // Length of stripped text } section_t; @@ -113,8 +114,6 @@ static size_t find_split (const char *text, size_t len, size_t max_chars) { // Push a section to dynamic array static int section_push (parse_ctx_t *ctx, size_t start, size_t end, int is_heading) { - UNUSED_PARAM(is_heading); - if (ctx->sec_count >= ctx->sec_cap) { size_t new_cap = ctx->sec_cap ? ctx->sec_cap * 2 : 16; section_t *tmp = (section_t *)dbmemory_realloc(ctx->sections, new_cap * sizeof(section_t)); @@ -126,6 +125,7 @@ static int section_push (parse_ctx_t *ctx, size_t start, size_t end, int is_head section_t *s = &ctx->sections[ctx->sec_count++]; s->start = start; s->end = end; + s->is_heading = is_heading; s->text = NULL; s->text_len = 0; @@ -607,7 +607,7 @@ static int parse_sections (const char *buffer, size_t buffer_size, bool skip_sem for (size_t i = 0; i < ctx->sec_count; i++) { section_t *s = &ctx->sections[i]; // First section or heading starts new section - if (write_idx == 0) { + if (write_idx == 0 || s->is_heading) { ctx->sections[write_idx++] = *s; } else { // Extend previous section to include this one diff --git a/src/dbmem-rembed.c b/src/dbmem-rembed.c index 9ad7c22..0613ed4 100644 --- a/src/dbmem-rembed.c +++ b/src/dbmem-rembed.c @@ -26,6 +26,7 @@ static size_t cacert_len = sizeof(cacert_pem) - 1; #ifndef DBMEM_OMIT_CURL static size_t dbmem_remote_receive_data(void *contents, size_t size, size_t nmemb, void *xdata); +static struct curl_slist *dbmem_remote_build_headers (const char *api_key); #endif struct dbmem_remote_engine_t { @@ -67,6 +68,27 @@ struct dbmem_remote_engine_t { #include #include +#ifndef DBMEM_OMIT_CURL +static struct curl_slist *dbmem_remote_build_headers (const char *api_key) { + char auth_header[512]; + struct curl_slist *headers = NULL; + struct curl_slist *next = NULL; + + snprintf(auth_header, sizeof(auth_header), "Authorization: Bearer %s", api_key); + headers = curl_slist_append(headers, auth_header); + if (!headers) return NULL; + + next = curl_slist_append(headers, "Content-Type: application/json"); + if (!next) { + curl_slist_free_all(headers); + return NULL; + } + headers = next; + + return headers; +} +#endif + static bool text_needs_json_escape (const char *text, size_t *len) { size_t original_len = *len; size_t required_len = 0; @@ -263,11 +285,7 @@ dbmem_remote_engine_t *dbmem_remote_engine_init (void *ctx, const char *provider #endif // set up headers - char auth_header[512]; - snprintf(auth_header, sizeof(auth_header), "Authorization: Bearer %s", api_key); - struct curl_slist *headers = NULL; - headers = curl_slist_append(headers, auth_header); - if (headers) headers = curl_slist_append(headers, "Content-Type: application/json"); + struct curl_slist *headers = dbmem_remote_build_headers(api_key); if (!headers) { snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate HTTP headers"); curl_easy_cleanup(curl); @@ -522,6 +540,36 @@ int dbmem_remote_compute_embedding (dbmem_remote_engine_t *engine, const char *t return 0; } +int dbmem_remote_engine_set_apikey (dbmem_remote_engine_t *engine, const char *api_key, char err_msg[DBMEM_ERRBUF_SIZE]) { + if (!engine || !api_key) { + if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Invalid remote engine or API key"); + return SQLITE_MISUSE; + } + +#ifndef DBMEM_OMIT_CURL + struct curl_slist *headers = dbmem_remote_build_headers(api_key); + if (!headers) { + if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Failed to allocate HTTP headers"); + return SQLITE_NOMEM; + } + + curl_easy_setopt(engine->curl, CURLOPT_HTTPHEADER, headers); + if (engine->headers) curl_slist_free_all(engine->headers); + engine->headers = headers; +#else + char *copy = dbmem_strdup(api_key); + if (!copy) { + if (err_msg) snprintf(err_msg, DBMEM_ERRBUF_SIZE, "Unable to duplicate API key (insufficient memory)"); + return SQLITE_NOMEM; + } + + if (engine->api_key) dbmemory_free(engine->api_key); + engine->api_key = copy; +#endif + + return SQLITE_OK; +} + void dbmem_remote_engine_free (dbmem_remote_engine_t *engine) { if (!engine) return; diff --git a/src/dbmem-search.c b/src/dbmem-search.c index 9b4f6e4..7dec56d 100644 --- a/src/dbmem-search.c +++ b/src/dbmem-search.c @@ -19,6 +19,7 @@ #include #include #include +#include #include #ifndef SQLITE_CORE @@ -52,14 +53,14 @@ typedef struct { struct { int count; double *rank; - sqlite3_int64 *hash; + uint64_t *hash; sqlite3_int64 *seq; } fts; struct { int count; double *rank; - sqlite3_int64 *hash; + uint64_t *hash; sqlite3_int64 *seq; } semantic; @@ -67,21 +68,45 @@ typedef struct { int count; double *vectorScore; double *textScore; - sqlite3_int64 *hash; + uint64_t *hash; sqlite3_int64 *seq; int *hasVector; int *hasFts; } merge; double *rank; - sqlite3_int64 *hash; + uint64_t *hash; sqlite3_int64 *seq; } vMemorySearchCursor; +static int dbmem_search_bind_hash (sqlite3_stmt *vm, int index, uint64_t hash) { + char hash_text[DBMEM_HASH_STR_MAXLEN]; + dbmem_hash_to_hex(hash, hash_text); + return sqlite3_bind_text(vm, index, hash_text, -1, SQLITE_TRANSIENT); +} + +static bool dbmem_search_column_hash (sqlite3_stmt *vm, int column, uint64_t *hash) { + const char *hash_text = (const char *)sqlite3_column_text(vm, column); + return dbmem_hash_from_hex(hash_text, hash); +} + // MARK: - UTILS - +static void vMemorySearchCursorReset (vMemorySearchCursor *c) { + if (c->buffer) dbmemory_free(c->buffer); + memset((char *)c + offsetof(vMemorySearchCursor, max_results), 0, + sizeof(*c) - offsetof(vMemorySearchCursor, max_results)); +} + int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perform_fts) { + if (entries <= 0) { + vMemorySearchCursorReset(c); + c->max_results = entries; + c->perform_fts = perform_fts; + return SQLITE_OK; + } + // one buffer to rule them all // fts (if enabled): rank, hash, seq = 3 arrays * entries // semantic: rank, hash, seq = 3 arrays * entries @@ -94,26 +119,26 @@ int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perfo // fts arrays if (perform_fts) { size += sizeof(double) * entries; // fts.rank - size += sizeof(sqlite3_int64) * entries; // fts.hash + size += sizeof(uint64_t) * entries; // fts.hash size += sizeof(sqlite3_int64) * entries; // fts.seq } // semantic arrays size += sizeof(double) * entries; // semantic.rank - size += sizeof(sqlite3_int64) * entries; // semantic.hash + size += sizeof(uint64_t) * entries; // semantic.hash size += sizeof(sqlite3_int64) * entries; // semantic.seq // merge arrays (2x entries for union of both sources) size += sizeof(double) * merge_entries; // merge.vectorScore size += sizeof(double) * merge_entries; // merge.textScore - size += sizeof(sqlite3_int64) * merge_entries; // merge.hash + size += sizeof(uint64_t) * merge_entries; // merge.hash size += sizeof(sqlite3_int64) * merge_entries; // merge.seq size += sizeof(int) * merge_entries; // merge.hasVector size += sizeof(int) * merge_entries; // merge.hasFts // final arrays size += sizeof(double) * entries; // rank - size += sizeof(sqlite3_int64) * entries; // hash + size += sizeof(uint64_t) * entries; // hash size += sizeof(sqlite3_int64) * entries; // seq char *buffer = (char *)dbmemory_zeroalloc(size); @@ -127,8 +152,8 @@ int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perfo if (perform_fts) { c->fts.rank = (double *)buffer; buffer += sizeof(double) * entries; - c->fts.hash = (sqlite3_int64 *)buffer; - buffer += sizeof(sqlite3_int64) * entries; + c->fts.hash = (uint64_t *)buffer; + buffer += sizeof(uint64_t) * entries; c->fts.seq = (sqlite3_int64 *)buffer; buffer += sizeof(sqlite3_int64) * entries; } @@ -136,8 +161,8 @@ int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perfo // semantic c->semantic.rank = (double *)buffer; buffer += sizeof(double) * entries; - c->semantic.hash = (sqlite3_int64 *)buffer; - buffer += sizeof(sqlite3_int64) * entries; + c->semantic.hash = (uint64_t *)buffer; + buffer += sizeof(uint64_t) * entries; c->semantic.seq = (sqlite3_int64 *)buffer; buffer += sizeof(sqlite3_int64) * entries; @@ -146,8 +171,8 @@ int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perfo buffer += sizeof(double) * merge_entries; c->merge.textScore = (double *)buffer; buffer += sizeof(double) * merge_entries; - c->merge.hash = (sqlite3_int64 *)buffer; - buffer += sizeof(sqlite3_int64) * merge_entries; + c->merge.hash = (uint64_t *)buffer; + buffer += sizeof(uint64_t) * merge_entries; c->merge.seq = (sqlite3_int64 *)buffer; buffer += sizeof(sqlite3_int64) * merge_entries; c->merge.hasVector = (int *)buffer; @@ -158,8 +183,8 @@ int vMemorySearchCursorAllocate (vMemorySearchCursor *c, int entries, bool perfo // final rowset c->rank = (double *)buffer; buffer += sizeof(double) * entries; - c->hash = (sqlite3_int64 *)buffer; - buffer += sizeof(sqlite3_int64) * entries; + c->hash = (uint64_t *)buffer; + buffer += sizeof(uint64_t) * entries; c->seq = (sqlite3_int64 *)buffer; return SQLITE_OK; @@ -183,7 +208,7 @@ int vMemorySearchCursorMerge(vMemorySearchCursor *c, double vectorWeight, double // add/merge FTS results (already normalized to 0..1) for (int i = 0; i < c->fts.count; i++) { - sqlite3_int64 hash = c->fts.hash[i]; + uint64_t hash = c->fts.hash[i]; sqlite3_int64 seq = c->fts.seq[i]; // check if already in merge list @@ -230,7 +255,7 @@ int vMemorySearchCursorMerge(vMemorySearchCursor *c, double vectorWeight, double // swap all parallel arrays together for (int i = 1; i < c->merge.count; i++) { double tempScore = c->merge.textScore[i]; - sqlite3_int64 tempHash = c->merge.hash[i]; + uint64_t tempHash = c->merge.hash[i]; sqlite3_int64 tempSeq = c->merge.seq[i]; int j = i - 1; @@ -269,7 +294,7 @@ static void vMemorySearchUpdateAccess(sqlite3 *db, vMemorySearchCursor *c) { for (int i = 0; i < c->count; i++) { sqlite3_bind_int64(vm, 1, now); - sqlite3_bind_int64(vm, 2, c->hash[i]); + dbmem_search_bind_hash(vm, 2, c->hash[i]); sqlite3_step(vm); sqlite3_reset(vm); } @@ -308,8 +333,7 @@ static int dbmem_fts_search (sqlite3 *db, vMemorySearchCursor *c, const char *in "SELECT rank, hash, seq FROM dbmem_vault_fts WHERE content MATCH ?1 ORDER BY rank LIMIT ?2"; static const char *sql_with_context = "SELECT fts.rank, fts.hash, fts.seq FROM dbmem_vault_fts AS fts " - "JOIN dbmem_vault AS v ON fts.hash = v.hash AND fts.seq = v.seq " - "WHERE fts.content MATCH ?1 AND INSTR(',' || ?3 || ',', ',' || v.context || ',') > 0 " + "WHERE fts.content MATCH ?1 AND INSTR(',' || ?3 || ',', ',' || fts.context || ',') > 0 " "ORDER BY fts.rank LIMIT ?2"; const char *sql = (context) ? sql_with_context : sql_no_context; @@ -347,7 +371,10 @@ static int dbmem_fts_search (sqlite3 *db, vMemorySearchCursor *c, const char *in if (rank > rank_max) rank_max = rank; c->fts.rank[count] = rank; - c->fts.hash[count] = sqlite3_column_int64(vm, 1); + if (!dbmem_search_column_hash(vm, 1, &c->fts.hash[count])) { + rc = SQLITE_MISMATCH; + break; + } c->fts.seq[count] = sqlite3_column_int64(vm, 2); c->fts.count++; @@ -409,7 +436,10 @@ static int dbmem_semantic_search (sqlite3 *db, vMemorySearchCursor *c, float *em // SQLITE_ROW c->semantic.rank[count] = sqlite3_column_double(vm, 0); - c->semantic.hash[count] = sqlite3_column_int64(vm, 1); + if (!dbmem_search_column_hash(vm, 1, &c->semantic.hash[count])) { + rc = SQLITE_MISMATCH; + break; + } c->semantic.seq[count] = sqlite3_column_int64(vm, 2); c->semantic.count++; @@ -504,7 +534,7 @@ static int vMemorySearchCursorOpen (sqlite3_vtab *pVtab, sqlite3_vtab_cursor **p static int vMemorySearchCursorClose (sqlite3_vtab_cursor *cur){ vMemorySearchCursor *c = (vMemorySearchCursor *)cur; - if (c->buffer) dbmemory_free(c->buffer); + vMemorySearchCursorReset(c); dbmemory_free(c); return SQLITE_OK; } @@ -530,7 +560,10 @@ static int vMemorySearchCursorColumn (sqlite3_vtab_cursor *cur, sqlite3_context switch (iCol) { case SEARCH_COLUMN_HASH: - sqlite3_result_int64(context, c->hash[c->index]); + { + char hash_text[DBMEM_HASH_STR_MAXLEN]; + sqlite3_result_text(context, dbmem_hash_to_hex(c->hash[c->index], hash_text), -1, SQLITE_TRANSIENT); + } break; case SEARCH_COLUMN_SEQ: @@ -546,7 +579,7 @@ static int vMemorySearchCursorColumn (sqlite3_vtab_cursor *cur, sqlite3_context const char *sql = (iCol == SEARCH_COLUMN_PATH) ? path_sql : snippet_sql; sqlite3_stmt *vm = NULL; if (sqlite3_prepare_v2(db, sql, -1, &vm, NULL) == SQLITE_OK) { - sqlite3_bind_int64(vm, 1, c->hash[c->index]); + dbmem_search_bind_hash(vm, 1, c->hash[c->index]); if (iCol == SEARCH_COLUMN_SNIPPET) sqlite3_bind_int64(vm, 2, c->seq[c->index]); if (sqlite3_step(vm) == SQLITE_ROW) sqlite3_result_value(context, sqlite3_column_value(vm, 0)); } @@ -612,6 +645,12 @@ static int vMemorySearchCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, cons fetch_count = max_results; } + vMemorySearchCursorReset(c); + if (fetch_count <= 0) { + c->count = 0; + return SQLITE_OK; + } + // allocate internal cursor buffer int rc = vMemorySearchCursorAllocate(c, fetch_count, perform_fts); if (rc != SQLITE_OK) return SQLITE_NOMEM; @@ -698,9 +737,9 @@ static int vMemorySearchCursorFilter (sqlite3_vtab_cursor *cur, int idxNum, cons printf("=================================\n"); for (int i = 0; i < c->count; i++) { double rank = c->rank[i]; - sqlite3_int64 hash = c->hash[i]; + uint64_t hash = c->hash[i]; sqlite3_int64 seq = c->seq[i]; - printf("%3d %.3f %20lld %2lld\n", i, rank, (long long)hash, (long long)seq); + printf("%3d %.3f %016llx %2lld\n", i, rank, (unsigned long long)hash, (long long)seq); } printf("=================================\n"); #endif diff --git a/src/dbmem-utils.c b/src/dbmem-utils.c index c82a83e..919daf2 100644 --- a/src/dbmem-utils.c +++ b/src/dbmem-utils.c @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -137,6 +138,29 @@ uint64_t dbmem_hash_compute (const void *data, size_t len) { return mix64(h); } +char *dbmem_hash_to_hex (uint64_t hash, char value[DBMEM_HASH_STR_MAXLEN]) { + snprintf(value, DBMEM_HASH_STR_MAXLEN, "%016llx", (unsigned long long)hash); + return value; +} + +bool dbmem_hash_from_hex (const char *text, uint64_t *value) { + if (!text || !value) return false; + if (strlen(text) != DBMEM_HASH_HEX_LEN) return false; + + uint64_t result = 0; + for (int i = 0; i < DBMEM_HASH_HEX_LEN; i++) { + char c = text[i]; + if (!isxdigit((unsigned char)c)) return false; + + result <<= 4; + if (c >= '0' && c <= '9') result |= (uint64_t)(c - '0'); + else result |= (uint64_t)(10 + (tolower((unsigned char)c) - 'a')); + } + + *value = result; + return true; +} + // MARK: - UUIDv7 - /* diff --git a/src/dbmem-utils.h b/src/dbmem-utils.h index 33623d8..f51299a 100644 --- a/src/dbmem-utils.h +++ b/src/dbmem-utils.h @@ -33,6 +33,8 @@ #define DBMEM_ERRBUF_SIZE 1024 #define DBMEM_UUID_STR_MAXLEN 37 +#define DBMEM_HASH_HEX_LEN 16 +#define DBMEM_HASH_STR_MAXLEN (DBMEM_HASH_HEX_LEN + 1) // MEMORY void *dbmemory_alloc (uint64_t size); @@ -52,6 +54,8 @@ bool dbmem_file_has_extension (const char *path, const char *extensions); // GENERAL uint64_t dbmem_hash_compute (const void *data, size_t len); +char *dbmem_hash_to_hex (uint64_t hash, char value[DBMEM_HASH_STR_MAXLEN]); +bool dbmem_hash_from_hex (const char *text, uint64_t *value); char *dbmem_uuid_v7 (char value[DBMEM_UUID_STR_MAXLEN]); #endif diff --git a/src/sqlite-memory.c b/src/sqlite-memory.c index 2c7a0c7..d82ca9e 100644 --- a/src/sqlite-memory.c +++ b/src/sqlite-memory.c @@ -134,6 +134,31 @@ struct dbmem_context { static bool fts5_is_available = true; +static int dbmem_bind_hash (sqlite3_stmt *vm, int index, uint64_t hash) { + char hash_text[DBMEM_HASH_STR_MAXLEN]; + dbmem_hash_to_hex(hash, hash_text); + return sqlite3_bind_text(vm, index, hash_text, -1, SQLITE_TRANSIENT); +} + +static bool dbmem_column_hash (sqlite3_stmt *vm, int column, uint64_t *hash) { + const char *hash_text = (const char *)sqlite3_column_text(vm, column); + return dbmem_hash_from_hex(hash_text, hash); +} + +static bool dbmem_value_hash (sqlite3_value *value, uint64_t *hash) { + if (!value || !hash) return false; + + switch (sqlite3_value_type(value)) { + case SQLITE_TEXT: + return dbmem_hash_from_hex((const char *)sqlite3_value_text(value), hash); + case SQLITE_INTEGER: + *hash = (uint64_t)sqlite3_value_int64(value); + return true; + default: + return false; + } +} + // MARK: - Settings - static int dbmem_settings_write (sqlite3 *db, const char *key, const char *text_value, sqlite3_int64 int_value, const sqlite3_value *sql_value, int bind_type) { @@ -233,25 +258,25 @@ static int dbmem_settings_sync (dbmem_context *ctx, const char *key, sqlite3_val if (strcasecmp(key, DBMEM_SETTINGS_KEY_MAX_RESULTS) == 0) { int n = sqlite3_value_int(value); - if (n > 0) ctx->max_results = n; + if (n >= 0) ctx->max_results = n; return 0; } if (strcasecmp(key, DBMEM_SETTINGS_KEY_VECTOR_WEIGHT) == 0) { double n = sqlite3_value_double(value); - if (n > 0) ctx->vector_weight = n; + if (n >= 0) ctx->vector_weight = n; return 0; } if (strcasecmp(key, DBMEM_SETTINGS_KEY_TEXT_WEIGHT) == 0) { double n = sqlite3_value_double(value); - if (n > 0) ctx->text_weight = n; + if (n >= 0) ctx->text_weight = n; return 0; } if (strcasecmp(key, DBMEM_SETTINGS_KEY_MIN_SCORE) == 0) { double n = sqlite3_value_double(value); - if (n > 0) ctx->min_score = n; + if (n >= 0) ctx->min_score = n; return 0; } @@ -338,15 +363,15 @@ static int dbmem_database_init (sqlite3 *db) { int rc = sqlite3_exec(db, sql, NULL, NULL, NULL); if (rc != SQLITE_OK) return rc; - sql = "CREATE TABLE IF NOT EXISTS dbmem_content (hash INTEGER PRIMARY KEY NOT NULL, path TEXT NOT NULL DEFAULT '' UNIQUE, value TEXT DEFAULT NULL, length INTEGER NOT NULL DEFAULT 0, context TEXT DEFAULT NULL, created_at INTEGER DEFAULT 0, last_accessed INTEGER DEFAULT 0);"; + sql = "CREATE TABLE IF NOT EXISTS dbmem_content (hash TEXT PRIMARY KEY NOT NULL, path TEXT NOT NULL DEFAULT '' UNIQUE, value TEXT DEFAULT NULL, length INTEGER NOT NULL DEFAULT 0, context TEXT DEFAULT NULL, created_at INTEGER DEFAULT 0, last_accessed INTEGER DEFAULT 0);"; rc = sqlite3_exec(db, sql, NULL, NULL, NULL); if (rc != SQLITE_OK) return rc; - sql = "CREATE TABLE IF NOT EXISTS dbmem_vault (hash INTEGER NOT NULL, seq INTEGER NOT NULL, embedding BLOB NOT NULL, offset INTEGER NOT NULL, length INTEGER NOT NULL, PRIMARY KEY (hash, seq));"; + sql = "CREATE TABLE IF NOT EXISTS dbmem_vault (hash TEXT NOT NULL, seq INTEGER NOT NULL, embedding BLOB NOT NULL, offset INTEGER NOT NULL, length INTEGER NOT NULL, PRIMARY KEY (hash, seq));"; rc = sqlite3_exec(db, sql, NULL, NULL, NULL); if (rc != SQLITE_OK) return rc; - sql = "CREATE TABLE IF NOT EXISTS dbmem_cache (text_hash INTEGER NOT NULL, provider TEXT NOT NULL, model TEXT NOT NULL, embedding BLOB NOT NULL, dimension INTEGER NOT NULL, PRIMARY KEY (text_hash, provider, model));"; + sql = "CREATE TABLE IF NOT EXISTS dbmem_cache (text_hash TEXT NOT NULL, provider TEXT NOT NULL, model TEXT NOT NULL, embedding BLOB NOT NULL, dimension INTEGER NOT NULL, PRIMARY KEY (text_hash, provider, model));"; rc = sqlite3_exec(db, sql, NULL, NULL, NULL); if (rc != SQLITE_OK) return rc; @@ -375,7 +400,7 @@ static bool dbmem_database_check_if_stored (sqlite3 *db, uint64_t hash, int64_t int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - rc = sqlite3_bind_int64(vm, 1, (sqlite3_int64)hash); + rc = dbmem_bind_hash(vm, 1, hash); if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_step(vm); @@ -391,21 +416,21 @@ static bool dbmem_database_check_if_stored (sqlite3 *db, uint64_t hash, int64_t return result; } -static void dbmem_database_delete_hash (sqlite3 *db, sqlite3_int64 hash) { +static void dbmem_database_delete_hash (sqlite3 *db, uint64_t hash) { sqlite3_stmt *vm = NULL; if (fts5_is_available) { sqlite3_prepare_v2(db, "DELETE FROM dbmem_vault_fts WHERE hash=?1;", -1, &vm, NULL); - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); sqlite3_step(vm); sqlite3_finalize(vm); } sqlite3_prepare_v2(db, "DELETE FROM dbmem_vault WHERE hash=?1;", -1, &vm, NULL); - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); sqlite3_step(vm); sqlite3_finalize(vm); sqlite3_prepare_v2(db, "DELETE FROM dbmem_content WHERE hash=?1;", -1, &vm, NULL); - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); sqlite3_step(vm); sqlite3_finalize(vm); } @@ -420,9 +445,10 @@ static void dbmem_database_delete_stale_path (sqlite3 *db, const char *path, uin sqlite3_bind_text(vm, 1, path, -1, SQLITE_STATIC); rc = sqlite3_step(vm); if (rc == SQLITE_ROW) { - sqlite3_int64 old_hash = sqlite3_column_int64(vm, 0); + uint64_t old_hash = 0; + bool has_old_hash = dbmem_column_hash(vm, 0, &old_hash); sqlite3_finalize(vm); - if ((uint64_t)old_hash != new_hash) { + if (has_old_hash && old_hash != new_hash) { dbmem_database_delete_hash(db, old_hash); } } else { @@ -437,7 +463,7 @@ static int dbmem_database_add_entry (dbmem_context *ctx, sqlite3 *db, uint64_t h int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - rc = sqlite3_bind_int64(vm, 1, (sqlite3_int64)hash); + rc = dbmem_bind_hash(vm, 1, hash); if (rc != SQLITE_OK) goto cleanup; const char *path = ctx->path; @@ -475,7 +501,7 @@ static int dbmem_database_add_chunk (dbmem_context *ctx, embedding_result_t *res int rc = sqlite3_prepare_v2(ctx->db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - rc = sqlite3_bind_int64(vm, 1, (sqlite3_int64)ctx->hash); + rc = dbmem_bind_hash(vm, 1, ctx->hash); if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_bind_int64(vm, 2, (sqlite3_int64)index); @@ -509,7 +535,7 @@ static int dbmem_database_add_fts5 (dbmem_context *ctx, const char *text, size_t rc = sqlite3_bind_text(vm, 1, text, (int)text_len, SQLITE_STATIC); if (rc != SQLITE_OK) goto cleanup; - rc = sqlite3_bind_int64(vm, 2, (sqlite3_int64)ctx->hash); + rc = dbmem_bind_hash(vm, 2, ctx->hash); if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_bind_int64(vm, 3, (sqlite3_int64)index); @@ -720,12 +746,11 @@ void dbmem_context_set_errorf (dbmem_context *ctx, const char *fmt, ...) { static void dbmem_delete (sqlite3_context *context, int argc, sqlite3_value **argv) { UNUSED_PARAM(argc); - if (sqlite3_value_type(argv[0]) != SQLITE_INTEGER) { - sqlite3_result_error(context, "The function memory_delete expects one argument of type INTEGER (hash)", SQLITE_ERROR); + uint64_t hash = 0; + if (!dbmem_value_hash(argv[0], &hash)) { + sqlite3_result_error(context, "The function memory_delete expects one argument of type TEXT (hash)", SQLITE_ERROR); return; } - - sqlite3_int64 hash = sqlite3_value_int64(argv[0]); sqlite3 *db = sqlite3_context_db_handle(context); int rc = dbmem_database_begin_transaction(db); @@ -739,7 +764,7 @@ static void dbmem_delete (sqlite3_context *context, int argc, sqlite3_value **ar sqlite3_stmt *vm = NULL; rc = sqlite3_prepare_v2(db, "DELETE FROM dbmem_vault_fts WHERE hash = ?1;", -1, &vm, NULL); if (rc == SQLITE_OK) { - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); sqlite3_step(vm); sqlite3_finalize(vm); } @@ -749,7 +774,7 @@ static void dbmem_delete (sqlite3_context *context, int argc, sqlite3_value **ar sqlite3_stmt *vm = NULL; rc = sqlite3_prepare_v2(db, "DELETE FROM dbmem_vault WHERE hash = ?1;", -1, &vm, NULL); if (rc != SQLITE_OK) goto rollback; - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); rc = sqlite3_step(vm); sqlite3_finalize(vm); if (rc != SQLITE_DONE) goto rollback; @@ -757,7 +782,7 @@ static void dbmem_delete (sqlite3_context *context, int argc, sqlite3_value **ar // Delete from content rc = sqlite3_prepare_v2(db, "DELETE FROM dbmem_content WHERE hash = ?1;", -1, &vm, NULL); if (rc != SQLITE_OK) goto rollback; - sqlite3_bind_int64(vm, 1, hash); + dbmem_bind_hash(vm, 1, hash); rc = sqlite3_step(vm); sqlite3_finalize(vm); if (rc != SQLITE_DONE) goto rollback; @@ -926,6 +951,7 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value * // retrieve context dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context); + sqlite3 *db = sqlite3_context_db_handle(context); // detect model change (only if a model was previously configured) bool model_changed = false; @@ -954,80 +980,177 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value * #endif } + char *new_provider = dbmem_strdup(provider); + char *new_model = dbmem_strdup(model); + if (!new_provider || !new_model) { + if (new_provider) dbmemory_free(new_provider); + if (new_model) dbmemory_free(new_model); + sqlite3_result_error_nomem(context); + return; + } + + char *old_provider = ctx->provider; + char *old_model = ctx->model; + bool old_is_local = ctx->is_local; + bool old_is_custom = ctx->is_custom; + + #ifndef DBMEM_OMIT_LOCAL_ENGINE + dbmem_local_engine_t *old_l_engine = ctx->l_engine; + dbmem_local_engine_t *new_l_engine = ctx->l_engine; + #endif + + #ifndef DBMEM_OMIT_REMOTE_ENGINE + dbmem_remote_engine_t *old_r_engine = ctx->r_engine; + dbmem_remote_engine_t *new_r_engine = ctx->r_engine; + #endif + + void *old_custom_engine = ctx->custom_engine; + void *new_custom_engine = ctx->custom_engine; + bool set_model_started = false; + int rc = SQLITE_OK; + // custom provider path if (is_custom_provider) { - // free previous custom engine if any - if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine, ctx->custom_provider.xdata); - ctx->custom_engine = NULL; - - ctx->custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->custom_provider.xdata, ctx->error_msg); - if (ctx->custom_engine == NULL) { + new_custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->custom_provider.xdata, ctx->error_msg); + if (new_custom_engine == NULL) { + dbmemory_free(new_provider); + dbmemory_free(new_model); sqlite3_result_error(context, ctx->error_msg, -1); return; } - ctx->is_custom = true; - ctx->is_local = false; } // if provider is local then make sure model file exists #ifndef DBMEM_OMIT_LOCAL_ENGINE if (!is_custom_provider && is_local_provider) { if (dbmem_file_exists(model) == false) { + dbmemory_free(new_provider); + dbmemory_free(new_model); sqlite3_result_error(context, "Local model not found in the specified path", SQLITE_ERROR); return; } - if (ctx->l_engine) dbmem_local_engine_free(ctx->l_engine); - ctx->l_engine = NULL; - - ctx->l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg); - if (ctx->l_engine == NULL) { + new_l_engine = dbmem_local_engine_init(ctx, model, ctx->error_msg); + if (new_l_engine == NULL) { + dbmemory_free(new_provider); + dbmemory_free(new_model); sqlite3_result_error(context, ctx->error_msg, -1); return; } if (ctx->engine_warmup) { - dbmem_local_engine_warmup(ctx->l_engine); + dbmem_local_engine_warmup(new_l_engine); } - - ctx->is_local = true; - ctx->is_custom = false; } #endif #ifndef DBMEM_OMIT_REMOTE_ENGINE if (!is_custom_provider && !is_local_provider) { - if (ctx->r_engine) dbmem_remote_engine_free(ctx->r_engine); - ctx->r_engine = NULL; - - ctx->r_engine = dbmem_remote_engine_init(ctx, provider, model, ctx->error_msg); - if (ctx->r_engine == NULL) { + new_r_engine = dbmem_remote_engine_init(ctx, provider, model, ctx->error_msg); + if (new_r_engine == NULL) { + dbmemory_free(new_provider); + dbmemory_free(new_model); sqlite3_result_error(context, ctx->error_msg, -1); return; } - - ctx->is_local = false; - ctx->is_custom = false; } #endif - + + ctx->provider = new_provider; + ctx->model = new_model; + ctx->is_local = is_custom_provider ? false : is_local_provider; + ctx->is_custom = is_custom_provider; + ctx->custom_engine = new_custom_engine; + #ifndef DBMEM_OMIT_LOCAL_ENGINE + ctx->l_engine = new_l_engine; + #endif + #ifndef DBMEM_OMIT_REMOTE_ENGINE + ctx->r_engine = new_r_engine; + #endif + + rc = sqlite3_exec(db, "SAVEPOINT dbmem_set_model;", NULL, NULL, NULL); + if (rc == SQLITE_OK) set_model_started = true; + // update settings - sqlite3 *db = sqlite3_context_db_handle(context); - int rc = dbmem_settings_write_text (db, DBMEM_SETTINGS_KEY_PROVIDER, provider); - if (rc == SQLITE_OK) rc = dbmem_settings_write_text (db, DBMEM_SETTINGS_KEY_MODEL, model); - - // sync settings - if (rc == SQLITE_OK) { - dbmem_settings_sync(ctx, DBMEM_SETTINGS_KEY_PROVIDER, argv[0]); - dbmem_settings_sync(ctx, DBMEM_SETTINGS_KEY_MODEL, argv[1]); - } + if (rc == SQLITE_OK) rc = dbmem_settings_write_text(db, DBMEM_SETTINGS_KEY_PROVIDER, provider); + if (rc == SQLITE_OK) rc = dbmem_settings_write_text(db, DBMEM_SETTINGS_KEY_MODEL, model); // reindex all content if the model changed if (model_changed && rc == SQLITE_OK) { rc = dbmem_reindex(ctx); } - (rc == SQLITE_OK) ? sqlite3_result_int(context, 1) : sqlite3_result_error(context, sqlite3_errmsg(db), -1); + if (rc == SQLITE_OK && set_model_started) { + rc = sqlite3_exec(db, "RELEASE dbmem_set_model;", NULL, NULL, NULL); + set_model_started = false; + } + + if (rc != SQLITE_OK) { + if (set_model_started) { + sqlite3_exec(db, "ROLLBACK TO dbmem_set_model; RELEASE dbmem_set_model;", NULL, NULL, NULL); + } + + ctx->provider = old_provider; + ctx->model = old_model; + ctx->is_local = old_is_local; + ctx->is_custom = old_is_custom; + ctx->custom_engine = old_custom_engine; + #ifndef DBMEM_OMIT_LOCAL_ENGINE + ctx->l_engine = old_l_engine; + if (!is_custom_provider && is_local_provider && new_l_engine != old_l_engine && new_l_engine) { + dbmem_local_engine_free(new_l_engine); + } + #endif + #ifndef DBMEM_OMIT_REMOTE_ENGINE + ctx->r_engine = old_r_engine; + if (!is_custom_provider && !is_local_provider && new_r_engine != old_r_engine && new_r_engine) { + dbmem_remote_engine_free(new_r_engine); + } + #endif + if (is_custom_provider && new_custom_engine != old_custom_engine && new_custom_engine && ctx->custom_provider.free) { + ctx->custom_provider.free(new_custom_engine, ctx->custom_provider.xdata); + } + dbmemory_free(new_provider); + dbmemory_free(new_model); + sqlite3_result_error(context, ctx->error_msg[0] ? ctx->error_msg : sqlite3_errmsg(db), -1); + return; + } + + if (old_provider) dbmemory_free(old_provider); + if (old_model) dbmemory_free(old_model); + #ifndef DBMEM_OMIT_LOCAL_ENGINE + if (!is_custom_provider && is_local_provider) { + if (old_l_engine && old_l_engine != new_l_engine) { + dbmem_local_engine_free(old_l_engine); + } + } else if (old_l_engine) { + // switching away from local provider: release the previous engine + dbmem_local_engine_free(old_l_engine); + ctx->l_engine = NULL; + } + #endif + #ifndef DBMEM_OMIT_REMOTE_ENGINE + if (!is_custom_provider && !is_local_provider) { + if (old_r_engine && old_r_engine != new_r_engine) { + dbmem_remote_engine_free(old_r_engine); + } + } else if (old_r_engine) { + // switching away from remote provider: release the previous engine + dbmem_remote_engine_free(old_r_engine); + ctx->r_engine = NULL; + } + #endif + if (is_custom_provider) { + if (old_custom_engine && old_custom_engine != new_custom_engine && ctx->custom_provider.free) { + ctx->custom_provider.free(old_custom_engine, ctx->custom_provider.xdata); + } + } else if (old_custom_engine && ctx->custom_provider.free) { + // switching away from custom provider: release the previous engine + ctx->custom_provider.free(old_custom_engine, ctx->custom_provider.xdata); + ctx->custom_engine = NULL; + } + + sqlite3_result_int(context, 1); } static void dbmem_set_apikey (sqlite3_context *context, int argc, sqlite3_value **argv) { @@ -1045,7 +1168,18 @@ static void dbmem_set_apikey (sqlite3_context *context, int argc, sqlite3_value // retrieve context dbmem_context *ctx = (dbmem_context *)sqlite3_user_data(context); - + + #ifndef DBMEM_OMIT_REMOTE_ENGINE + if (ctx->r_engine && !ctx->is_local && !ctx->is_custom) { + int rc = dbmem_remote_engine_set_apikey(ctx->r_engine, apikey, ctx->error_msg); + if (rc != SQLITE_OK) { + dbmemory_free(apikey); + sqlite3_result_error(context, ctx->error_msg[0] ? ctx->error_msg : "Unable to update remote API key", -1); + return; + } + } + #endif + if (ctx->api_key) dbmemory_free(ctx->api_key); ctx->api_key = apikey; @@ -1142,7 +1276,7 @@ static bool dbmem_cache_lookup (dbmem_context *ctx, uint64_t text_hash, embeddin int rc = sqlite3_prepare_v2(ctx->db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - sqlite3_bind_int64(vm, 1, (sqlite3_int64)text_hash); + dbmem_bind_hash(vm, 1, text_hash); sqlite3_bind_text(vm, 2, ctx->provider, -1, SQLITE_STATIC); sqlite3_bind_text(vm, 3, ctx->model, -1, SQLITE_STATIC); @@ -1211,7 +1345,7 @@ static void dbmem_cache_store (dbmem_context *ctx, uint64_t text_hash, const emb int rc = sqlite3_prepare_v2(ctx->db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - sqlite3_bind_int64(vm, 1, (sqlite3_int64)text_hash); + dbmem_bind_hash(vm, 1, text_hash); sqlite3_bind_text(vm, 2, ctx->provider, -1, SQLITE_STATIC); sqlite3_bind_text(vm, 3, ctx->model, -1, SQLITE_STATIC); sqlite3_bind_blob(vm, 4, result->embedding, result->n_embd * (int)sizeof(float), SQLITE_STATIC); @@ -1376,17 +1510,30 @@ static int dbmem_process_file (dbmem_context *ctx, const char *path) { static int dbmem_reindex (dbmem_context *ctx) { sqlite3 *db = ctx->db; int rc = SQLITE_OK; + sqlite3_stmt *vm = NULL; + int saved_dimension = ctx->dimension; + bool saved_dimension_saved = ctx->dimension_saved; + bool saved_vector_extension_available = ctx->vector_extension_available; + bool reindex_started = false; // copy all content to a temp table + sqlite3_exec(db, "DROP TABLE IF EXISTS dbmem_reindex;", NULL, NULL, NULL); rc = sqlite3_exec(db, "CREATE TEMP TABLE dbmem_reindex AS SELECT path, value, context FROM dbmem_content;", NULL, NULL, NULL); if (rc != SQLITE_OK) return rc; + rc = sqlite3_exec(db, "SAVEPOINT dbmem_reindex;", NULL, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; + reindex_started = true; + // clear all indexed data if (fts5_is_available) { - sqlite3_exec(db, "DELETE FROM dbmem_vault_fts;", NULL, NULL, NULL); + rc = sqlite3_exec(db, "DELETE FROM dbmem_vault_fts;", NULL, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; } - sqlite3_exec(db, "DELETE FROM dbmem_vault;", NULL, NULL, NULL); - sqlite3_exec(db, "DELETE FROM dbmem_content;", NULL, NULL, NULL); + rc = sqlite3_exec(db, "DELETE FROM dbmem_vault;", NULL, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; + rc = sqlite3_exec(db, "DELETE FROM dbmem_content;", NULL, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; // reset dimension so the new model's dimension is auto-detected ctx->dimension = 0; @@ -1394,11 +1541,10 @@ static int dbmem_reindex (dbmem_context *ctx) { ctx->vector_extension_available = false; // iterate temp table one row at a time - sqlite3_stmt *vm = NULL; rc = sqlite3_prepare_v2(db, "SELECT path, value, context FROM dbmem_reindex;", -1, &vm, NULL); if (rc != SQLITE_OK) goto cleanup; - while (sqlite3_step(vm) == SQLITE_ROW) { + while ((rc = sqlite3_step(vm)) == SQLITE_ROW) { const char *path = (const char *)sqlite3_column_text(vm, 0); const char *value = (const char *)sqlite3_column_text(vm, 1); int value_len = sqlite3_column_bytes(vm, 1); @@ -1408,18 +1554,34 @@ static int dbmem_reindex (dbmem_context *ctx) { ctx->context = context; if (path && dbmem_file_exists(path)) { - dbmem_process_file(ctx, path); + rc = dbmem_process_file(ctx, path); } else if (value && value_len > 0) { ctx->path = path; - dbmem_process_buffer(ctx, value, value_len); + rc = dbmem_process_buffer(ctx, value, value_len); + } else { + rc = SQLITE_OK; } - // else: skip entries that can't be rebuilt + if (rc != SQLITE_OK) goto cleanup; } - rc = SQLITE_OK; + if (rc == SQLITE_DONE) rc = SQLITE_OK; + + if (rc == SQLITE_OK && reindex_started) { + rc = sqlite3_exec(db, "RELEASE dbmem_reindex;", NULL, NULL, NULL); + reindex_started = false; + } cleanup: if (vm) sqlite3_finalize(vm); + if (rc != SQLITE_OK) { + if (reindex_started) { + sqlite3_exec(db, "ROLLBACK TO dbmem_reindex; RELEASE dbmem_reindex;", NULL, NULL, NULL); + } + ctx->dimension = saved_dimension; + ctx->dimension_saved = saved_dimension_saved; + ctx->vector_extension_available = saved_vector_extension_available; + } + dbmem_context_reset_temp_values(ctx); sqlite3_exec(db, "DROP TABLE IF EXISTS dbmem_reindex;", NULL, NULL, NULL); return rc; } @@ -1481,28 +1643,55 @@ static void dbmem_add_file (sqlite3_context *context, int argc, sqlite3_value ** (rc == 0) ? sqlite3_result_int(context, 1) : sqlite3_result_error(context, ctx->error_msg, -1); } +static bool dbmem_path_is_under_directory (const char *path, const char *dir_path) { + if (!path || !dir_path) return false; + + size_t dir_len = strlen(dir_path); + if (dir_len == 0) return false; + + while (dir_len > 1 && (dir_path[dir_len - 1] == '/' || dir_path[dir_len - 1] == '\\')) { + dir_len--; + } + + if (dir_len == 1 && (dir_path[0] == '/' || dir_path[0] == '\\')) { + return path[0] == dir_path[0]; + } + + if (strncmp(path, dir_path, dir_len) != 0) return false; + if (path[dir_len] == '\0') return true; + + return (path[dir_len] == '/' || path[dir_len] == '\\'); +} + static void dbmem_database_delete_missing_files (sqlite3 *db, const char *dir_path) { - char *sql = sqlite3_mprintf("SELECT hash, path FROM dbmem_content WHERE path LIKE '%q%%';", dir_path); - if (!sql) return; - - char **table = NULL; - int nrow = 0, ncol = 0; - int rc = sqlite3_get_table(db, sql, &table, &nrow, &ncol, NULL); - sqlite3_free(sql); - if (rc != SQLITE_OK || nrow == 0) { - if (table) sqlite3_free_table(table); + static const char *sql = "SELECT hash, path FROM dbmem_content WHERE path IS NOT NULL AND path != '';"; + sqlite3_stmt *vm = NULL; + int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL); + if (rc != SQLITE_OK) return; + + rc = dbmem_database_begin_transaction(db); + if (rc != SQLITE_OK) { + sqlite3_finalize(vm); return; } - dbmem_database_begin_transaction(db); - for (int i = 0; i < nrow; i++) { - const char *path = table[ncol + i * ncol + 1]; + while ((rc = sqlite3_step(vm)) == SQLITE_ROW) { + uint64_t hash = 0; + const char *hash_text = (const char *)sqlite3_column_text(vm, 0); + const char *path = (const char *)sqlite3_column_text(vm, 1); + if (!dbmem_path_is_under_directory(path, dir_path)) continue; if (dbmem_file_exists(path)) continue; - sqlite3_int64 hash = strtoll(table[ncol + i * ncol], NULL, 10); + if (!dbmem_hash_from_hex(hash_text, &hash)) continue; dbmem_database_delete_hash(db, hash); } - dbmem_database_commit_transaction(db); - sqlite3_free_table(table); + + sqlite3_finalize(vm); + + if (rc == SQLITE_DONE || rc == SQLITE_OK) { + dbmem_database_commit_transaction(db); + } else { + dbmem_database_rollback_transaction(db); + } } static void dbmem_add_directory (sqlite3_context *context, int argc, sqlite3_value **argv) { @@ -1569,7 +1758,15 @@ static void dbmem_sql_reindex (sqlite3_context *context, int argc, sqlite3_value if (rc != SQLITE_OK) break; int step = sqlite3_step(vm); - if (step != SQLITE_ROW) { sqlite3_finalize(vm); break; } + if (step == SQLITE_DONE) { + sqlite3_finalize(vm); + break; + } + if (step != SQLITE_ROW) { + sqlite3_finalize(vm); + rc = step; + break; + } // Copy row data before finalizing so we can write in the next step const char *path_raw = (const char *)sqlite3_column_text(vm, 0); @@ -1604,7 +1801,7 @@ static void dbmem_sql_reindex (sqlite3_context *context, int argc, sqlite3_value "UPDATE dbmem_content SET hash = ?1 WHERE path = ?2 AND hash != ?1;"; sqlite3_stmt *fix_vm = NULL; if (sqlite3_prepare_v2(db, fix_sql, -1, &fix_vm, NULL) == SQLITE_OK) { - sqlite3_bind_int64(fix_vm, 1, (sqlite3_int64)ctx->hash); + dbmem_bind_hash(fix_vm, 1, ctx->hash); sqlite3_bind_text(fix_vm, 2, path, -1, SQLITE_STATIC); sqlite3_step(fix_vm); sqlite3_finalize(fix_vm); @@ -1649,7 +1846,7 @@ static void dbmem_enable_sync (sqlite3_context *context, int argc, sqlite3_value } } - int rc = sqlite3_exec(db, "SELECT cloudsync_init('dbmem_content', 'cls', 3);", NULL, NULL, NULL); + int rc = sqlite3_exec(db, "SELECT cloudsync_init('dbmem_content');", NULL, NULL, NULL); if (rc != SQLITE_OK) { sqlite3_result_error(context, sqlite3_errmsg(db), -1); return; diff --git a/src/sqlite-memory.h b/src/sqlite-memory.h index f2ddfdc..f4aaefd 100644 --- a/src/sqlite-memory.h +++ b/src/sqlite-memory.h @@ -26,7 +26,7 @@ extern "C" { #endif -#define SQLITE_DBMEMORY_VERSION "0.9.0" +#define SQLITE_DBMEMORY_VERSION "1.0.0" // public API SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi); diff --git a/test/e2e.c b/test/e2e.c index cb1f707..41d97ed 100644 --- a/test/e2e.c +++ b/test/e2e.c @@ -13,6 +13,7 @@ #include #include #include +#include "dbmem-utils.h" #include "sqlite-memory.h" #ifdef _WIN32 @@ -332,12 +333,12 @@ TEST(memory_search) { ASSERT(rc == SQLITE_OK); ASSERT(sqlite3_step(stmt) == SQLITE_ROW); - int64_t hash = sqlite3_column_int64(stmt, 0); + const char *hash = (const char *)sqlite3_column_text(stmt, 0); const char *path = (const char *)sqlite3_column_text(stmt, 1); const char *snippet = (const char *)sqlite3_column_text(stmt, 3); double ranking = sqlite3_column_double(stmt, 4); - ASSERT(hash != 0); + ASSERT(hash != NULL && strlen(hash) == DBMEM_HASH_HEX_LEN); ASSERT(path != NULL && strlen(path) > 0); ASSERT(snippet != NULL && strlen(snippet) > 0); ASSERT(ranking > 0.0 && ranking <= 1.0); @@ -377,6 +378,52 @@ TEST(memory_search_ranking) { ASSERT_SQL_OK(db, "SELECT memory_set_option('min_score', 0.7);"); } +TEST(memory_search_statement_reuse) { + sqlite3_stmt *stmt = NULL; + int rc = sqlite3_prepare_v2(db, + "SELECT hash, snippet FROM memory_search(?1, ?2);", + -1, &stmt, NULL); + ASSERT(rc == SQLITE_OK); + + rc = sqlite3_bind_text(stmt, 1, "fox", -1, SQLITE_STATIC); + ASSERT(rc == SQLITE_OK); + rc = sqlite3_bind_int(stmt, 2, 5); + ASSERT(rc == SQLITE_OK); + + int first_count = 0; + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + const char *hash = (const char *)sqlite3_column_text(stmt, 0); + const char *snippet = (const char *)sqlite3_column_text(stmt, 1); + ASSERT(hash != NULL && strlen(hash) == DBMEM_HASH_HEX_LEN); + ASSERT(snippet != NULL && strlen(snippet) > 0); + first_count++; + } + ASSERT(rc == SQLITE_DONE); + ASSERT(first_count > 0); + + rc = sqlite3_reset(stmt); + ASSERT(rc == SQLITE_OK); + sqlite3_clear_bindings(stmt); + + rc = sqlite3_bind_text(stmt, 1, "SQL database engine", -1, SQLITE_STATIC); + ASSERT(rc == SQLITE_OK); + rc = sqlite3_bind_int(stmt, 2, 10); + ASSERT(rc == SQLITE_OK); + + int second_count = 0; + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + const char *hash = (const char *)sqlite3_column_text(stmt, 0); + const char *snippet = (const char *)sqlite3_column_text(stmt, 1); + ASSERT(hash != NULL && strlen(hash) == DBMEM_HASH_HEX_LEN); + ASSERT(snippet != NULL && strlen(snippet) > 0); + second_count++; + } + ASSERT(rc == SQLITE_DONE); + ASSERT(second_count > 0); + + sqlite3_finalize(stmt); +} + // ============================================================================ // Phase 5: Deletion // ============================================================================ @@ -384,11 +431,11 @@ TEST(memory_search_ranking) { // memory_delete: delete by hash TEST(memory_delete) { // Get a hash from a context-less entry - int64_t hash = 0; + char hash[DBMEM_HASH_STR_MAXLEN] = {0}; sqlite3_stmt *stmt; int rc = sqlite3_prepare_v2(db, "SELECT hash FROM dbmem_content WHERE context IS NULL LIMIT 1;", -1, &stmt, NULL); ASSERT(rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW); - hash = sqlite3_column_int64(stmt, 0); + snprintf(hash, sizeof(hash), "%s", (const char *)sqlite3_column_text(stmt, 0)); sqlite3_finalize(stmt); result_int = 0; @@ -396,7 +443,7 @@ TEST(memory_delete) { int before = result_int; char sql[128]; - snprintf(sql, sizeof(sql), "SELECT memory_delete(%lld);", (long long)hash); + snprintf(sql, sizeof(sql), "SELECT memory_delete('%s');", hash); ASSERT_SQL_OK(db, sql); // Verify count decreased by 1 @@ -494,6 +541,7 @@ int main(void) { // Phase 4: Search (network calls) RUN_TEST(memory_search); RUN_TEST(memory_search_ranking); + RUN_TEST(memory_search_statement_reuse); // Phase 5: Deletion RUN_TEST(memory_delete); diff --git a/test/sync/README.md b/test/sync/README.md index 7078958..009ffcd 100644 --- a/test/sync/README.md +++ b/test/sync/README.md @@ -13,10 +13,10 @@ The test runs in eight phases: | Phase | Description | |-------|-------------| -| 1 | Open separate in-memory databases for Agent A and Agent B | +| 1 | Open separate temporary file-backed databases for Agent A and Agent B | | 2 | Enable CRDT sync on both agents before ingesting content | -| 3 | Each agent ingests its own knowledge (embeddings computed locally) | -| 4 | Pre-sync isolation: each agent can answer its own topic, not the other's | +| 3 | Each agent ingests its own knowledge using the remote `vectors.space` embedding service | +| 4 | Pre-sync isolation: each agent can answer its own topic, not the other's, and context-filtered FTS works locally | | 5 | Connect both agents to the shared SQLiteCloud managed database | | 6 | Bidirectional sync: each agent pushes its content and pulls the other's | | 7 | Reindex: each agent generates embeddings for the newly received content | @@ -35,6 +35,8 @@ AI agents increasingly need to share knowledge without tight coupling. Each agen This is the foundation for distributed agent memory: a growing body of knowledge that no single agent owns but all agents contribute to and benefit from. +The sync test uses the remote `vectors.space` embedding API via `provider='llama'` and `model='embeddinggemma-300m'`. That keeps setup lightweight because users do not need to download a local GGUF model just to run the sync scenario. + ## Dependencies ### sqlite-vector @@ -62,7 +64,7 @@ The sync layer routes changes through a [SQLiteCloud](https://sqlitecloud.io/) m 3. **Create the memory table** — connect to your database and run: ```sql CREATE TABLE IF NOT EXISTS dbmem_content ( - hash INTEGER PRIMARY KEY NOT NULL, + hash TEXT PRIMARY KEY NOT NULL, path TEXT NOT NULL DEFAULT '' UNIQUE, value TEXT DEFAULT NULL, length INTEGER NOT NULL DEFAULT 0, @@ -74,7 +76,7 @@ The sync layer routes changes through a [SQLiteCloud](https://sqlitecloud.io/) m 4. **Enable OffSync** for the database: open the database, click **OffSync**, and enable synchronization. This provisions the CloudSync microservice that routes changes between agents. 5. **Enable OffSync for the table** — initialize sync on `dbmem_content` and configure the `value` column to use block-level LWW so that concurrent agent edits to different lines of the same entry are preserved rather than overwritten. This can be done from the dashboard UI (Database → OffSync section) or via SQL: ```sql - SELECT cloudsync_init('dbmem_content', 'cls', 1); + SELECT cloudsync_init('dbmem_content'); SELECT cloudsync_set_column('dbmem_content', 'value', 'algo', 'block'); ``` 6. **Copy the managed database ID** — shown on the OffSync page (format: `db_xxxxxxxxxxxx`). @@ -84,12 +86,12 @@ The sync layer routes changes through a [SQLiteCloud](https://sqlitecloud.io/) m | Variable | Description | |----------|-------------| +| `APIKEY` | vectors.space API key used for `memory_set_model('llama', 'embeddinggemma-300m')` | | `SYNC_DB_ID` | Managed database ID from the OffSync page | | `SYNC_APIKEY_A` | API key for Agent A | | `SYNC_APIKEY_B` | API key for Agent B | -| `APIKEY` | vectors.space API key for computing embeddings | -The Makefile has defaults baked in for the project's own test database. To use your own: +The Makefile has defaults baked in for the project's own test database. To use your own SQLiteCloud database and keys: ```sh make sync-test \ @@ -107,15 +109,15 @@ make sync-test \ make sync-test DEFINES="-DTEST_SQLITE_EXTENSION" APIKEY=your_vectors_space_key ``` -The test creates two temporary SQLite databases in `/tmp/`, runs the full eight-phase scenario, and cleans up on exit. Expected output: +The test creates two temporary SQLite database files in `/tmp/`, runs the full eight-phase scenario, and cleans up on exit. Expected output: ``` Sync integration test: JWST (Agent A) + Great Barrier Reef (Agent B) ======================================================================= ... === Sync Test Results === -Tests run: 26 -Tests passed: 26 +Tests run: 28 +Tests passed: 28 Tests failed: 0 Sync test passed! @@ -125,6 +127,6 @@ Sync test passed! The test uses `cloudsync_network_sync(500, 3)` called twice per agent in sequence. The first call pushes local changes to the cloud; the second call (after the other agent has pushed) pulls remote changes. No manual version resets or delays are needed; the two-round pattern is the standard approach for bidirectional peer exchange described in the [SQLite Sync documentation](https://github.com/sqliteai/sqlite-sync). -CRDT sync is enabled on `dbmem_content` with the `cls` algorithm. The `value` column (which stores the raw text) uses the `block` algorithm so that line-level changes from concurrent agents are preserved rather than replaced wholesale. +Sync is enabled on `dbmem_content`, and the `value` column (which stores the raw text) is configured with the `block` algorithm so that line-level changes from concurrent agents are preserved rather than replaced wholesale. After receiving content via sync, each agent calls `memory_reindex()` to generate embeddings for the newly arrived rows. Only rows not yet in the local embedding vault are processed, so existing embeddings are never duplicated. diff --git a/test/sync/test_sync.c b/test/sync/test_sync.c index 9db80c8..df8013c 100644 --- a/test/sync/test_sync.c +++ b/test/sync/test_sync.c @@ -11,7 +11,7 @@ // each agent never directly indexed. // // Required environment variables: -// APIKEY — vectors.space API key for embeddings +// APIKEY — vectors.space API key for remote embeddings // VECTOR_LIB — path to sqlite-vector shared library // SYNC_LIB — path to sqlite-sync (cloudsync) shared library // SYNC_DB_ID — SQLiteCloud managed database ID (shared by both agents) @@ -217,13 +217,41 @@ static int count_search_results(sqlite3 *db, const char *query, const char *cont static sqlite3 *db_a = NULL; static sqlite3 *db_b = NULL; -static const char *g_apikey = NULL; +static const char *g_api_key = NULL; static const char *g_vector_lib = NULL; static const char *g_sync_lib = NULL; static const char *g_sync_db_id = NULL; static const char *g_sync_apikey_a = NULL; static const char *g_sync_apikey_b = NULL; +static void step_context_filtered_fts(void) { + printf(" Agent A context-filtered FTS search works... "); + fflush(stdout); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('vector_weight', 0.0);"); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('text_weight', 1.0);"); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('min_score', 0.0);"); + int n = count_search_results(db_a, "James Webb Telescope", CONTEXT_SPACE); + ASSERT(n >= 1); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('vector_weight', 0.6);"); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('text_weight', 0.4);"); + ASSERT_SQL_OK(db_a, "SELECT memory_set_option('min_score', 0.7);"); + tests_run++; tests_passed++; + printf("PASSED (%d result(s))\n", n); + + printf(" Agent B context-filtered FTS search works... "); + fflush(stdout); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('vector_weight', 0.0);"); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('text_weight', 1.0);"); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('min_score', 0.0);"); + n = count_search_results(db_b, "Great Barrier Reef", CONTEXT_REEF); + ASSERT(n >= 1); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('vector_weight', 0.6);"); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('text_weight', 0.4);"); + ASSERT_SQL_OK(db_b, "SELECT memory_set_option('min_score', 0.7);"); + tests_run++; tests_passed++; + printf("PASSED (%d result(s))\n", n); +} + // Step 1: Open both agent databases static void step_open_databases(void) { printf(" Opening Agent A database... "); @@ -245,8 +273,8 @@ static void step_open_databases(void) { static void step_configure_agents(void) { printf(" Configuring Agent A (space context)... "); fflush(stdout); - char sql[512]; - snprintf(sql, sizeof(sql), "SELECT memory_set_apikey('%s');", g_apikey); + char sql[1024]; + snprintf(sql, sizeof(sql), "SELECT memory_set_apikey('%s');", g_api_key); ASSERT_SQL_OK(db_a, sql); ASSERT_SQL_OK(db_a, "SELECT memory_set_model('llama', 'embeddinggemma-300m');"); tests_run++; tests_passed++; @@ -254,7 +282,7 @@ static void step_configure_agents(void) { printf(" Configuring Agent B (reef context)... "); fflush(stdout); - snprintf(sql, sizeof(sql), "SELECT memory_set_apikey('%s');", g_apikey); + snprintf(sql, sizeof(sql), "SELECT memory_set_apikey('%s');", g_api_key); ASSERT_SQL_OK(db_b, sql); ASSERT_SQL_OK(db_b, "SELECT memory_set_model('llama', 'embeddinggemma-300m');"); tests_run++; tests_passed++; @@ -479,14 +507,14 @@ static void step_postsync_search(void) { // ============================================================================ int main(void) { - g_apikey = getenv("APIKEY"); + g_api_key = getenv("APIKEY"); g_vector_lib = getenv("VECTOR_LIB"); g_sync_lib = getenv("SYNC_LIB"); g_sync_db_id = getenv("SYNC_DB_ID"); g_sync_apikey_a = getenv("SYNC_APIKEY_A"); g_sync_apikey_b = getenv("SYNC_APIKEY_B"); - if (!g_apikey || strlen(g_apikey) == 0 || + if (!g_api_key || strlen(g_api_key) == 0 || !g_vector_lib || strlen(g_vector_lib) == 0 || !g_sync_lib || strlen(g_sync_lib) == 0 || !g_sync_db_id || !g_sync_apikey_a || !g_sync_apikey_b) { @@ -519,6 +547,7 @@ int main(void) { step_ingest_content(); printf("\nPhase 4: Pre-sync search (isolated knowledge)\n"); + step_context_filtered_fts(); step_presync_search(); printf("\nPhase 5: Connect both agents to cloud\n"); diff --git a/test/unittest.c b/test/unittest.c index 7b47fc2..cb0e041 100644 --- a/test/unittest.c +++ b/test/unittest.c @@ -612,6 +612,23 @@ TEST(dbmem_parse_heading_levels) { free_test_ctx(&ctx); } +TEST(dbmem_parse_heading_sections_stay_split) { + const char *input = "# One\nAlpha text.\n\n## Two\nBeta text."; + dbmem_parse_settings settings = default_settings(); + test_ctx_t ctx = {0}; + settings.overlay_tokens = 0; + settings.callback = test_callback; + settings.xdata = &ctx; + + int rc = dbmem_parse(input, strlen(input), &settings); + ASSERT_EQ(rc, 0); + ASSERT_EQ(ctx.count, 2); + ASSERT_STR_EQ(ctx.chunks[0], "One\nAlpha text."); + ASSERT_STR_EQ(ctx.chunks[1], "Two\nBeta text."); + + free_test_ctx(&ctx); +} + TEST(dbmem_parse_heading_trailing_hashes) { const char *input = "## Heading ##\n### Another ###"; dbmem_parse_settings settings = default_settings(); @@ -1355,6 +1372,19 @@ static int exec_get_text(sqlite3 *db, const char *sql, char *result, size_t max_ return rc; } +static dbmem_context *get_test_ctx(sqlite3 *db) { + sqlite3_stmt *stmt = NULL; + dbmem_context *ctx = NULL; + int rc = sqlite3_prepare_v2(db, "SELECT _memory_ctx_ptr();", -1, &stmt, NULL); + + if (rc == SQLITE_OK && sqlite3_step(stmt) == SQLITE_ROW) { + ctx = (dbmem_context *)sqlite3_value_pointer(sqlite3_column_value(stmt, 0), "dbmem_context_ptr"); + } + + if (stmt) sqlite3_finalize(stmt); + return ctx; +} + TEST(sqlite_memory_version) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); @@ -1385,7 +1415,7 @@ TEST(sqlite_memory_delete_nonexistent) { ASSERT(db != NULL); sqlite3_int64 result; - int rc = exec_get_int(db, "SELECT memory_delete(12345);", &result); + int rc = exec_get_int(db, "SELECT memory_delete(printf('%016x', 12345));", &result); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(result, 0); // Should return 0 (no rows deleted) @@ -1414,9 +1444,22 @@ TEST(sqlite_schema_has_timestamps) { "SELECT sql FROM sqlite_master WHERE name='dbmem_content';", sql, sizeof(sql)); ASSERT_EQ(rc, SQLITE_OK); + ASSERT(strstr(sql, "hash TEXT PRIMARY KEY NOT NULL") != NULL); ASSERT(strstr(sql, "created_at") != NULL); ASSERT(strstr(sql, "last_accessed") != NULL); + rc = exec_get_text(db, + "SELECT sql FROM sqlite_master WHERE name='dbmem_vault';", + sql, sizeof(sql)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT(strstr(sql, "hash TEXT NOT NULL") != NULL); + + rc = exec_get_text(db, + "SELECT sql FROM sqlite_master WHERE name='dbmem_cache';", + sql, sizeof(sql)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT(strstr(sql, "text_hash TEXT NOT NULL") != NULL); + sqlite3_close(db); } @@ -1428,7 +1471,7 @@ TEST(sqlite_direct_insert_with_timestamp) { // Insert a test record directly int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (123, 'test/path', 'test value', 10, 'ctx1', strftime('%s','now'));", + "VALUES (printf('%016x', 123), 'test/path', 'test value', 10, 'ctx1', strftime('%s','now'));", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1440,7 +1483,7 @@ TEST(sqlite_direct_insert_with_timestamp) { // Verify created_at was set sqlite3_int64 created_at; - rc = exec_get_int(db, "SELECT created_at FROM dbmem_content WHERE hash=123;", &created_at); + rc = exec_get_int(db, "SELECT created_at FROM dbmem_content WHERE hash = printf('%016x', 123);", &created_at); ASSERT_EQ(rc, SQLITE_OK); ASSERT(created_at > 0); // Should be a valid Unix timestamp @@ -1454,13 +1497,13 @@ TEST(sqlite_memory_delete_direct) { // Insert a test record directly int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (456, 'test/path2', 'test value 2', 12, 'ctx2', strftime('%s','now'));", + "VALUES (printf('%016x', 456), 'test/path2', 'test value 2', 12, 'ctx2', strftime('%s','now'));", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); // Delete it sqlite3_int64 result; - rc = exec_get_int(db, "SELECT memory_delete(456);", &result); + rc = exec_get_int(db, "SELECT memory_delete(printf('%016x', 456));", &result); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(result, 1); // Should have deleted 1 row @@ -1480,9 +1523,9 @@ TEST(sqlite_memory_delete_context_direct) { // Insert test records with different contexts int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) VALUES " - "(100, 'path1', 'v1', 2, 'ctx_a', 0), " - "(101, 'path2', 'v2', 2, 'ctx_a', 0), " - "(102, 'path3', 'v3', 2, 'ctx_b', 0);", + "(printf('%016x', 100), 'path1', 'v1', 2, 'ctx_a', 0), " + "(printf('%016x', 101), 'path2', 'v2', 2, 'ctx_a', 0), " + "(printf('%016x', 102), 'path3', 'v3', 2, 'ctx_b', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1514,8 +1557,8 @@ TEST(sqlite_memory_clear_direct) { // Insert test records int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) VALUES " - "(200, 'p1', 'v1', 2, 'c1', 0), " - "(201, 'p2', 'v2', 2, 'c2', 0);", + "(printf('%016x', 200), 'p1', 'v1', 2, 'c1', 0), " + "(printf('%016x', 201), 'p2', 'v2', 2, 'c2', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1541,36 +1584,36 @@ TEST(sqlite_memory_delete_with_vault_data) { // Insert into content and vault tables int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (300, 'path300', 'value', 5, 'ctx', 0);", + "VALUES (printf('%016x', 300), 'path300', 'value', 5, 'ctx', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); rc = sqlite3_exec(db, "INSERT INTO dbmem_vault (hash, seq, embedding, offset, length) " - "VALUES (300, 0, X'00000000', 0, 5), (300, 1, X'00000000', 5, 5);", + "VALUES (printf('%016x', 300), 0, X'00000000', 0, 5), (printf('%016x', 300), 1, X'00000000', 5, 5);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); // Verify vault has data sqlite3_int64 vault_count; - rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault WHERE hash=300;", &vault_count); + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault WHERE hash = printf('%016x', 300);", &vault_count); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(vault_count, 2); // Delete by hash sqlite3_int64 result; - rc = exec_get_int(db, "SELECT memory_delete(300);", &result); + rc = exec_get_int(db, "SELECT memory_delete(printf('%016x', 300));", &result); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(result, 1); // Verify content is gone sqlite3_int64 content_count; - rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_content WHERE hash=300;", &content_count); + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_content WHERE hash = printf('%016x', 300);", &content_count); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(content_count, 0); // Verify vault is also gone - rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault WHERE hash=300;", &vault_count); + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault WHERE hash = printf('%016x', 300);", &vault_count); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(vault_count, 0); @@ -1584,18 +1627,18 @@ TEST(sqlite_memory_delete_twice) { // Insert a record int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (400, 'path400', 'value', 5, 'ctx', 0);", + "VALUES (printf('%016x', 400), 'path400', 'value', 5, 'ctx', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); // Delete first time - should return 1 sqlite3_int64 result; - rc = exec_get_int(db, "SELECT memory_delete(400);", &result); + rc = exec_get_int(db, "SELECT memory_delete(printf('%016x', 400));", &result); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(result, 1); // Delete second time - should return 0 - rc = exec_get_int(db, "SELECT memory_delete(400);", &result); + rc = exec_get_int(db, "SELECT memory_delete(printf('%016x', 400));", &result); ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(result, 0); @@ -1609,9 +1652,9 @@ TEST(sqlite_memory_delete_context_null) { // Insert records - some with NULL context, some with context int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) VALUES " - "(500, 'p1', 'v1', 2, NULL, 0), " - "(501, 'p2', 'v2', 2, NULL, 0), " - "(502, 'p3', 'v3', 2, 'has_context', 0);", + "(printf('%016x', 500), 'p1', 'v1', 2, NULL, 0), " + "(printf('%016x', 501), 'p2', 'v2', 2, NULL, 0), " + "(printf('%016x', 502), 'p3', 'v3', 2, 'has_context', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1643,7 +1686,7 @@ TEST(sqlite_memory_delete_wrong_type) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); - // Try to call memory_delete with TEXT instead of INTEGER + // Try to call memory_delete with an invalid hash string sqlite3_stmt *stmt = NULL; int rc = sqlite3_prepare_v2(db, "SELECT memory_delete('not_a_number');", -1, &stmt, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1708,13 +1751,13 @@ TEST(sqlite_memory_created_at_valid_range) { // Insert with current timestamp int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (600, 'path600', 'value', 5, 'ctx', strftime('%s','now'));", + "VALUES (printf('%016x', 600), 'path600', 'value', 5, 'ctx', strftime('%s','now'));", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); // Get the created_at value sqlite3_int64 created_at; - rc = exec_get_int(db, "SELECT created_at FROM dbmem_content WHERE hash=600;", &created_at); + rc = exec_get_int(db, "SELECT created_at FROM dbmem_content WHERE hash = printf('%016x', 600);", &created_at); ASSERT_EQ(rc, SQLITE_OK); // Should be greater than 0 @@ -1739,19 +1782,19 @@ TEST(sqlite_memory_clear_with_vault_fts) { // Insert into all tables int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " - "VALUES (700, 'path700', 'value', 5, 'ctx', 0);", + "VALUES (printf('%016x', 700), 'path700', 'value', 5, 'ctx', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); rc = sqlite3_exec(db, "INSERT INTO dbmem_vault (hash, seq, embedding, offset, length) " - "VALUES (700, 0, X'00000000', 0, 5);", + "VALUES (printf('%016x', 700), 0, X'00000000', 0, 5);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); rc = sqlite3_exec(db, "INSERT INTO dbmem_vault_fts (content, hash, seq, context) " - "VALUES ('test content', 700, 0, 'ctx');", + "VALUES ('test content', printf('%016x', 700), 0, 'ctx');", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1779,13 +1822,14 @@ TEST(sqlite_memory_clear_with_vault_fts) { } // Helper to insert a fake dbmem_content entry with a known path, hash, and length -static int insert_fake_content(sqlite3 *db, sqlite3_int64 hash, const char *path, const char *context, sqlite3_int64 length) { +static int insert_fake_content(sqlite3 *db, uint64_t hash, const char *path, const char *context, sqlite3_int64 length) { sqlite3_stmt *vm = NULL; const char *sql = "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) " "VALUES (?1, ?2, 'fake', ?3, ?4, 0);"; int rc = sqlite3_prepare_v2(db, sql, -1, &vm, NULL); if (rc != SQLITE_OK) return rc; - sqlite3_bind_int64(vm, 1, hash); + char hash_text[DBMEM_HASH_STR_MAXLEN]; + sqlite3_bind_text(vm, 1, dbmem_hash_to_hex(hash, hash_text), -1, SQLITE_TRANSIENT); sqlite3_bind_text(vm, 2, path, -1, SQLITE_STATIC); sqlite3_bind_int64(vm, 3, length); if (context) sqlite3_bind_text(vm, 4, context, -1, SQLITE_STATIC); @@ -1820,7 +1864,7 @@ TEST(sqlite_sync_directory_removes_deleted) { uint64_t keep_hash = dbmem_hash_compute(buf, (size_t)len); dbmemory_free(buf); - int rc = insert_fake_content(db, (sqlite3_int64)keep_hash, file_keep, NULL, len); + int rc = insert_fake_content(db, keep_hash, file_keep, NULL, len); ASSERT_EQ(rc, SQLITE_OK); rc = insert_fake_content(db, 99999, "/tmp/dbmem_test_sync_del/gone.md", NULL, 4); @@ -1873,9 +1917,9 @@ TEST(sqlite_sync_directory_removes_all_deleted) { // Also insert vault entries to verify cascade delete rc = sqlite3_exec(db, "INSERT INTO dbmem_vault (hash, seq, embedding, offset, length) VALUES " - "(1001, 0, X'00000000', 0, 4), " - "(1002, 0, X'00000000', 0, 4), " - "(1003, 0, X'00000000', 0, 4);", + "(printf('%016x', 1001), 0, X'00000000', 0, 4), " + "(printf('%016x', 1002), 0, X'00000000', 0, 4), " + "(printf('%016x', 1003), 0, X'00000000', 0, 4);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -1918,7 +1962,7 @@ TEST(sqlite_sync_directory_skips_unchanged) { // Compute the hash and pre-insert the entry uint64_t hash = dbmem_hash_compute(content, strlen(content)); - int rc = insert_fake_content(db, (sqlite3_int64)hash, file, "notes", (sqlite3_int64)strlen(content)); + int rc = insert_fake_content(db, hash, file, "notes", (sqlite3_int64)strlen(content)); ASSERT_EQ(rc, SQLITE_OK); // Sync — file exists with matching hash, should be skipped @@ -1937,6 +1981,43 @@ TEST(sqlite_sync_directory_skips_unchanged) { sqlite3_close(db); } +TEST(sqlite_sync_directory_ignores_sibling_prefixes) { + sqlite3 *db = open_test_db(); + ASSERT(db != NULL); + + const char *test_dir = TEST_TMP_DIR "/dbmem_test_sync_prefix"; + const char *target_file = TEST_TMP_DIR "/dbmem_test_sync_prefix/gone.md"; + const char *sibling_file = TEST_TMP_DIR "/dbmem_test_sync_prefix2/gone.md"; + + remove_test_file(target_file); + remove_test_file(sibling_file); + rmdir_p(TEST_TMP_DIR "/dbmem_test_sync_prefix2"); + rmdir_p(test_dir); + mkdir_p(test_dir); + + int rc = insert_fake_content(db, 3001, target_file, NULL, 4); + ASSERT_EQ(rc, SQLITE_OK); + rc = insert_fake_content(db, 3002, sibling_file, NULL, 4); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_int64 result; + rc = exec_get_int(db, "SELECT memory_add_directory('" TEST_TMP_DIR "/dbmem_test_sync_prefix');", &result); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_int64 count = 0; + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_content;", &count); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(count, 1); + + char path[256]; + rc = exec_get_text(db, "SELECT path FROM dbmem_content;", path, sizeof(path)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_STR_EQ(path, sibling_file); + + rmdir_p(test_dir); + sqlite3_close(db); +} + TEST(sqlite_cache_table_exists) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); @@ -1948,6 +2029,7 @@ TEST(sqlite_cache_table_exists) { sql, sizeof(sql)); ASSERT_EQ(rc, SQLITE_OK); ASSERT(strstr(sql, "text_hash") != NULL); + ASSERT(strstr(sql, "text_hash TEXT NOT NULL") != NULL); ASSERT(strstr(sql, "provider") != NULL); ASSERT(strstr(sql, "model") != NULL); ASSERT(strstr(sql, "embedding") != NULL); @@ -1975,9 +2057,9 @@ TEST(sqlite_cache_clear_with_data) { // Insert some fake cache entries int rc = sqlite3_exec(db, "INSERT INTO dbmem_cache (text_hash, provider, model, embedding, dimension) VALUES " - "(111, 'openai', 'text-embedding-3-small', X'00000000', 1), " - "(222, 'openai', 'text-embedding-3-small', X'00000000', 1), " - "(333, 'local', 'nomic', X'00000000', 1);", + "(printf('%016x', 111), 'openai', 'text-embedding-3-small', X'00000000', 1), " + "(printf('%016x', 222), 'openai', 'text-embedding-3-small', X'00000000', 1), " + "(printf('%016x', 333), 'local', 'nomic', X'00000000', 1);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2003,9 +2085,9 @@ TEST(sqlite_cache_clear_by_provider_model) { // Insert cache entries for different provider/model combos int rc = sqlite3_exec(db, "INSERT INTO dbmem_cache (text_hash, provider, model, embedding, dimension) VALUES " - "(111, 'openai', 'text-embedding-3-small', X'00000000', 1), " - "(222, 'openai', 'text-embedding-3-small', X'00000000', 1), " - "(333, 'local', 'nomic', X'00000000', 1);", + "(printf('%016x', 111), 'openai', 'text-embedding-3-small', X'00000000', 1), " + "(printf('%016x', 222), 'openai', 'text-embedding-3-small', X'00000000', 1), " + "(printf('%016x', 333), 'local', 'nomic', X'00000000', 1);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2093,11 +2175,11 @@ TEST(sqlite_cache_eviction) { // Insert 5 entries (rowids 1-5) rc = sqlite3_exec(db, "INSERT INTO dbmem_cache (text_hash, provider, model, embedding, dimension) VALUES " - "(1, 'p', 'm', X'00000000', 1), " - "(2, 'p', 'm', X'00000000', 1), " - "(3, 'p', 'm', X'00000000', 1), " - "(4, 'p', 'm', X'00000000', 1), " - "(5, 'p', 'm', X'00000000', 1);", + "(printf('%016x', 1), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 2), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 3), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 4), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 5), 'p', 'm', X'00000000', 1);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2114,9 +2196,9 @@ TEST(sqlite_cache_eviction) { // Insert exactly 3 (at limit) rc = sqlite3_exec(db, "INSERT INTO dbmem_cache (text_hash, provider, model, embedding, dimension) VALUES " - "(10, 'p', 'm', X'00000000', 1), " - "(11, 'p', 'm', X'00000000', 1), " - "(12, 'p', 'm', X'00000000', 1);", + "(printf('%016x', 10), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 11), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 12), 'p', 'm', X'00000000', 1);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2135,11 +2217,11 @@ TEST(sqlite_cache_no_eviction_when_unlimited) { // Insert many entries, none should be evicted int rc = sqlite3_exec(db, "INSERT INTO dbmem_cache (text_hash, provider, model, embedding, dimension) VALUES " - "(1, 'p', 'm', X'00000000', 1), " - "(2, 'p', 'm', X'00000000', 1), " - "(3, 'p', 'm', X'00000000', 1), " - "(4, 'p', 'm', X'00000000', 1), " - "(5, 'p', 'm', X'00000000', 1);", + "(printf('%016x', 1), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 2), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 3), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 4), 'p', 'm', X'00000000', 1), " + "(printf('%016x', 5), 'p', 'm', X'00000000', 1);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2177,6 +2259,30 @@ TEST(sqlite_search_oversample_setting) { sqlite3_close(db); } +TEST(sqlite_search_zero_value_settings_apply_to_context) { + sqlite3 *db = open_test_db(); + ASSERT(db != NULL); + + sqlite3_int64 result = 0; + int rc = exec_get_int(db, "SELECT memory_set_option('max_results', 0);", &result); + ASSERT_EQ(rc, SQLITE_OK); + rc = exec_get_int(db, "SELECT memory_set_option('vector_weight', 0.0);", &result); + ASSERT_EQ(rc, SQLITE_OK); + rc = exec_get_int(db, "SELECT memory_set_option('text_weight', 0.0);", &result); + ASSERT_EQ(rc, SQLITE_OK); + rc = exec_get_int(db, "SELECT memory_set_option('min_score', 0.0);", &result); + ASSERT_EQ(rc, SQLITE_OK); + + dbmem_context *ctx = get_test_ctx(db); + ASSERT(ctx != NULL); + ASSERT_EQ(dbmem_context_max_results(ctx), 0); + ASSERT_EQ(dbmem_context_vector_weight(ctx), 0.0); + ASSERT_EQ(dbmem_context_text_weight(ctx), 0.0); + ASSERT_EQ(dbmem_context_min_score(ctx), 0.0); + + sqlite3_close(db); +} + TEST(sqlite_memory_delete_context_with_vault) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); @@ -2184,15 +2290,15 @@ TEST(sqlite_memory_delete_context_with_vault) { // Insert records with different contexts into content and vault int rc = sqlite3_exec(db, "INSERT INTO dbmem_content (hash, path, value, length, context, created_at) VALUES " - "(800, 'p1', 'v1', 2, 'delete_me', 0), " - "(801, 'p2', 'v2', 2, 'keep_me', 0);", + "(printf('%016x', 800), 'p1', 'v1', 2, 'delete_me', 0), " + "(printf('%016x', 801), 'p2', 'v2', 2, 'keep_me', 0);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); rc = sqlite3_exec(db, "INSERT INTO dbmem_vault (hash, seq, embedding, offset, length) VALUES " - "(800, 0, X'00000000', 0, 2), " - "(801, 0, X'00000000', 0, 2);", + "(printf('%016x', 800), 0, X'00000000', 0, 2), " + "(printf('%016x', 801), 0, X'00000000', 0, 2);", NULL, NULL, NULL); ASSERT_EQ(rc, SQLITE_OK); @@ -2213,9 +2319,10 @@ TEST(sqlite_memory_delete_context_with_vault) { ASSERT_EQ(rc, SQLITE_OK); ASSERT_EQ(count, 1); - rc = exec_get_int(db, "SELECT hash FROM dbmem_vault;", &result); + char hash[64]; + rc = exec_get_text(db, "SELECT hash FROM dbmem_vault;", hash, sizeof(hash)); ASSERT_EQ(rc, SQLITE_OK); - ASSERT_EQ(result, 801); + ASSERT_STR_EQ(hash, "0000000000000321"); sqlite3_close(db); } @@ -2271,6 +2378,26 @@ static void *dummy_init_fail(const char *model, const char *api_key, void *xdata return NULL; } +typedef struct { + int fail_after; + int calls; +} flaky_provider_state_t; + +static void *flaky_init(const char *model, const char *api_key, void *xdata, char err_msg[1024]) { + flaky_provider_state_t *state = (flaky_provider_state_t *)xdata; + if (state) state->calls = 0; + return dummy_init(model, api_key, NULL, err_msg); +} + +static int flaky_compute(void *engine, const char *text, int text_len, void *xdata, dbmem_embedding_result_t *result) { + flaky_provider_state_t *state = (flaky_provider_state_t *)xdata; + if (state) { + state->calls++; + if (state->calls >= state->fail_after) return -1; + } + return dummy_compute(engine, text, text_len, NULL, result); +} + TEST(sqlite_custom_provider_register) { sqlite3 *db = open_test_db(); ASSERT(db != NULL); @@ -2393,6 +2520,126 @@ TEST(sqlite_custom_provider_apikey_passed) { sqlite3_close(db); } +TEST(sqlite_set_model_failed_reindex_preserves_existing_rows) { + sqlite3 *db = open_test_db(); + ASSERT(db != NULL); + + dbmem_provider_t ok_prov = { .init = dummy_init, .compute = dummy_compute, .free = dummy_free }; + int rc = sqlite3_memory_register_provider(db, "dummy", &ok_prov); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_int64 result = 0; + rc = exec_get_int(db, "SELECT memory_set_model('dummy', 'test-model');", &result); + ASSERT_EQ(rc, SQLITE_OK); + rc = exec_get_int(db, "SELECT memory_add_text('Persist me through failed reindex.', 'keep');", &result); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_int64 count = 0; + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_content;", &count); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(count, 1); + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault;", &count); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT(count >= 1); + + flaky_provider_state_t state = { .fail_after = 1, .calls = 0 }; + dbmem_provider_t flaky_prov = { .init = flaky_init, .compute = flaky_compute, .free = dummy_free, .xdata = &state }; + rc = sqlite3_memory_register_provider(db, "flaky", &flaky_prov); + ASSERT_EQ(rc, SQLITE_OK); + + sqlite3_stmt *stmt = NULL; + rc = sqlite3_prepare_v2(db, "SELECT memory_set_model('flaky', 'test-model');", -1, &stmt, NULL); + ASSERT_EQ(rc, SQLITE_OK); + rc = sqlite3_step(stmt); + ASSERT_EQ(rc, SQLITE_ERROR); + sqlite3_finalize(stmt); + + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_content;", &count); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(count, 1); + rc = exec_get_int(db, "SELECT COUNT(*) FROM dbmem_vault;", &count); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT(count >= 1); + + char context[64]; + rc = exec_get_text(db, "SELECT context FROM dbmem_content;", context, sizeof(context)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_STR_EQ(context, "keep"); + + char provider[64]; + rc = exec_get_text(db, "SELECT memory_get_option('provider');", provider, sizeof(provider)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_STR_EQ(provider, "dummy"); + + char model[64]; + rc = exec_get_text(db, "SELECT memory_get_option('model');", model, sizeof(model)); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_STR_EQ(model, "test-model"); + + sqlite3_close(db); +} + +// Regression: when memory_set_model() switches from a custom provider to a +// remote provider (a different provider class), the previous custom engine +// must be released immediately — not leaked until the database is closed. +typedef struct { + int free_count; +} tracking_free_state_t; + +static void *tracking_init(const char *model, const char *api_key, void *xdata, char err_msg[1024]) { + UNUSED_PARAM(model); + UNUSED_PARAM(api_key); + UNUSED_PARAM(xdata); + UNUSED_PARAM(err_msg); + // any non-NULL pointer is fine; the test only cares about the free callback + return calloc(1, 1); +} + +static int tracking_compute(void *engine, const char *text, int text_len, void *xdata, dbmem_embedding_result_t *result) { + UNUSED_PARAM(engine); + UNUSED_PARAM(text); + UNUSED_PARAM(text_len); + UNUSED_PARAM(xdata); + UNUSED_PARAM(result); + return -1; +} + +static void tracking_free(void *engine, void *xdata) { + tracking_free_state_t *s = (tracking_free_state_t *)xdata; + if (s) s->free_count++; + free(engine); +} + +TEST(sqlite_set_model_releases_previous_engine_on_class_switch) { + sqlite3 *db = open_test_db(); + ASSERT(db != NULL); + + // remote engine init requires an api key to succeed + sqlite3_int64 result = 0; + int rc = exec_get_int(db, "SELECT memory_set_apikey('test-key');", &result); + ASSERT_EQ(rc, SQLITE_OK); + + tracking_free_state_t state = {0}; + dbmem_provider_t prov = { .init = tracking_init, .compute = tracking_compute, .free = tracking_free, .xdata = &state }; + rc = sqlite3_memory_register_provider(db, "tracker", &prov); + ASSERT_EQ(rc, SQLITE_OK); + + // activate the custom provider — ctx->custom_engine is now non-NULL + rc = exec_get_int(db, "SELECT memory_set_model('tracker', 'm1');", &result); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(state.free_count, 0); + + // switch to a provider from a different class (remote). The previous + // custom engine must be released during this call, not kept alive on ctx. + rc = exec_get_int(db, "SELECT memory_set_model('openai', 'text-embedding-3-small');", &result); + ASSERT_EQ(rc, SQLITE_OK); + ASSERT_EQ(state.free_count, 1); + + // closing the db must not double-free the already-released custom engine + sqlite3_close(db); + ASSERT_EQ(state.free_count, 1); +} + #endif // TEST_SQLITE_EXTENSION // ============================================================================ @@ -2434,6 +2681,7 @@ int main(int argc, char *argv[]) { RUN_TEST(dbmem_parse_shortcut_link); RUN_TEST(dbmem_parse_nested_blockquote); RUN_TEST(dbmem_parse_heading_levels); + RUN_TEST(dbmem_parse_heading_sections_stay_split); RUN_TEST(dbmem_parse_heading_trailing_hashes); RUN_TEST(dbmem_parse_multiline_html); RUN_TEST(dbmem_parse_blank_lines); @@ -2496,6 +2744,7 @@ int main(int argc, char *argv[]) { RUN_TEST(sqlite_sync_directory_removes_deleted); RUN_TEST(sqlite_sync_directory_removes_all_deleted); RUN_TEST(sqlite_sync_directory_skips_unchanged); + RUN_TEST(sqlite_sync_directory_ignores_sibling_prefixes); printf("\nSQLite extension advanced tests:\n"); RUN_TEST(sqlite_memory_delete_with_vault_data); @@ -2520,6 +2769,7 @@ int main(int argc, char *argv[]) { printf("\nSearch oversampling tests:\n"); RUN_TEST(sqlite_search_oversample_setting); + RUN_TEST(sqlite_search_zero_value_settings_apply_to_context); printf("\nCustom provider tests:\n"); RUN_TEST(sqlite_custom_provider_register); @@ -2528,6 +2778,8 @@ int main(int argc, char *argv[]) { RUN_TEST(sqlite_custom_provider_null_callbacks); RUN_TEST(sqlite_custom_provider_init_error); RUN_TEST(sqlite_custom_provider_apikey_passed); + RUN_TEST(sqlite_set_model_failed_reindex_preserves_existing_rows); + RUN_TEST(sqlite_set_model_releases_previous_engine_on_class_switch); #endif printf("\n=== Results ===\n");