From 93fdd5dcfc583fc427b479e289f01d59a5035af7 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Tue, 19 May 2026 12:14:18 -0700 Subject: [PATCH 01/23] fix(mem_wal): allow append-only tables without primary keys (#6848) ## Summary - allow MemWAL initialization on append-only tables without unenforced primary key metadata - keep bucket sharding constrained to the primary key when a primary key exists, but allow no-PK tables to bucket by a non-nested column - update Rust, Python, and Java tests plus MemWAL docs Fixes #6846 ## Tests - cargo fmt --all - cargo test -p lance test_initialize_mem_wal_bucket_sharding - uv run make build - uv run pytest python/tests/test_mem_wal.py::test_initialize_mem_wal_bucket_sharding_without_primary_key - git diff --check --- docs/pyproject.toml | 2 +- docs/src/format/index/system/mem_wal.md | 2 + docs/src/format/table/mem_wal.md | 24 +++++-- java/src/main/java/org/lance/Dataset.java | 5 +- .../java/org/lance/memwal/MemWalTest.java | 49 ++++++++++++++ python/python/tests/test_mem_wal.py | 33 ++++++++++ python/src/dataset.rs | 6 +- rust/lance/src/dataset/mem_wal/api.rs | 64 +++++++++++-------- rust/lance/src/dataset/mem_wal/write.rs | 53 +++++++++++++++ 9 files changed, 200 insertions(+), 38 deletions(-) diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 7ef4e59666c..4112230aec5 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -3,7 +3,7 @@ name = "lance-docs" version = "0.1.0" description = "Documentation for Lance - Modern columnar data format for ML and LLMs" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10,<3.11" dependencies = [ "mkdocs>=1.5.0", "mkdocs-material>=9.4.0", diff --git a/docs/src/format/index/system/mem_wal.md b/docs/src/format/index/system/mem_wal.md index ac8696aa5cc..44515f0af28 100644 --- a/docs/src/format/index/system/mem_wal.md +++ b/docs/src/format/index/system/mem_wal.md @@ -4,6 +4,8 @@ The MemWAL Index is a system index that serves as the centralized structure for It 
stores configuration (shard specs, indexes to maintain), merge progress, and shard state snapshots. A table has at most one MemWAL index. +The table may be a primary-key table or an append-only table without primary-key metadata. +Primary-key-dependent lookup and deduplication semantics only apply when a primary key is defined. For the complete specification, see: diff --git a/docs/src/format/table/mem_wal.md b/docs/src/format/table/mem_wal.md index 892b38d1d01..8a228721123 100644 --- a/docs/src/format/table/mem_wal.md +++ b/docs/src/format/table/mem_wal.md @@ -8,7 +8,9 @@ scan, point lookup, vector search and full-text search. ![MemWAL Overview](../../images/mem_wal_overview.png) A Lance table is called a **base table** under the context of the MemWAL spec. -It must have an [unenforced primary key](index.md#unenforced-primary-key) defined in the table schema. +It may have an [unenforced primary key](index.md#unenforced-primary-key) defined in the table schema. +Primary keys are required for primary-key lookups and last-write-wins upsert semantics, +but append-only MemWAL tables may omit them. On top of the base table, the MemWAL spec defines a set of shards. Writers write to shards, and data in each shard is merged into the base table asynchronously. @@ -22,7 +24,7 @@ Each shard has exactly one active writer at any time. Writers claim a shard and then write data to that shard. Data in each shard is expected to be merged into the base table asynchronously. -Rows of the same primary key must be written to one and only one shard. +For tables with a primary key, rows of the same primary key must be written to one and only one shard. If two shards contain rows with the same primary key, the following scenario can cause data corruption: 1. Shard A receives a write with primary key `pk=1` at time T1 @@ -34,6 +36,8 @@ If two shards contain rows with the same primary key, the following scenario can This violates the expected "last write wins" semantics. 
By ensuring each primary key is assigned to exactly one shard via the sharding spec, merge order between shards becomes irrelevant for correctness. +Append-only tables without a primary key do not rely on last-write-wins conflict resolution +and may shard by any deterministic append key or partitioning column. See [MemWAL Shard Architecture](#shard-architecture) for the complete shard architecture. @@ -162,8 +166,9 @@ The content within the generation directory follows the [Lance table storage lay Generation numbers determine merge order of flushed MemTable into base table: lower numbers represent older data and must be merged to the base table first to preserve correct upsert semantics. -Within a single flushed MemTable, if there are multiple rows of the same primary key, -the row that is last inserted wins. +Within a single flushed MemTable for a primary-key table, +if there are multiple rows of the same primary key, the row that is last inserted wins. +Append-only tables without a primary key retain all inserted rows. ### Shard Manifest @@ -479,7 +484,8 @@ The garbage collector removes obsolete data from shard directories. Flushed MemT ### LSM Tree Merging Read -Readers **MUST** merge results from multiple data sources (base table, flushed MemTables, in-memory MemTables) by primary key to ensure correctness. +For tables with a primary key, readers **MUST** merge results from multiple data sources +(base table, flushed MemTables, in-memory MemTables) by primary key to ensure correctness. When the same primary key exists in multiple sources, the reader must keep only the newest version based on: @@ -496,6 +502,10 @@ This deduplication is essential because: Without proper merging, queries would return duplicate or stale rows. +Append-only tables without a primary key do not perform primary-key deduplication. 
+Readers should include the relevant base table, flushed MemTables, and in-memory MemTables +according to the requested consistency level; duplicate values are treated as distinct appended rows. + ### Reader Consistency Reader consistency depends on two factors: @@ -526,7 +536,9 @@ Datasets come from: Each dataset is tagged with a generation number: 0 for the base table, and positive integers for MemTable generations. Within a shard, the generation number determines data freshness, with higher numbers representing newer data. -Rows from different shards do not need deduplication since each primary key maps to exactly one shard. +For primary-key tables, rows from different shards do not need deduplication +since each primary key maps to exactly one shard. +Append-only tables without a primary key do not require cross-shard primary-key deduplication. The planner also collects bloom filters from each generation for staleness detection during search queries. diff --git a/java/src/main/java/org/lance/Dataset.java b/java/src/main/java/org/lance/Dataset.java index 1647204582b..02d61fb708b 100644 --- a/java/src/main/java/org/lance/Dataset.java +++ b/java/src/main/java/org/lance/Dataset.java @@ -2030,8 +2030,9 @@ private native MergeInsertResult nativeMergeInsert( /** * Initialize MemWAL on this dataset. * - *

<p>Must be called once before any call to {@link #memWalWriter}. The dataset schema must have - * at least one field carrying the {@code lance-schema:unenforced-primary-key} metadata. + *

<p>Must be called once before any call to {@link #memWalWriter}. Append-only tables may omit + * primary-key metadata; primary keys are only required for primary-key lookup and last-write-wins + * deduplication workflows. * * @param params MemWAL initialization parameters */ diff --git a/java/src/test/java/org/lance/memwal/MemWalTest.java b/java/src/test/java/org/lance/memwal/MemWalTest.java index 6e1bd6bd4f7..d3aaebf8891 100644 --- a/java/src/test/java/org/lance/memwal/MemWalTest.java +++ b/java/src/test/java/org/lance/memwal/MemWalTest.java @@ -72,6 +72,14 @@ public class MemWalTest { new Field( "id", new FieldType(false, new ArrowType.Int(64, true), null, PK_META), null), Field.nullable("name", new ArrowType.Utf8()))); + private static final Schema APPEND_ONLY_SCHEMA = + new Schema( + Arrays.asList( + new Field( + "id", + new FieldType(false, new ArrowType.Int(64, true), null, Collections.emptyMap()), + null), + Field.nullable("name", new ArrowType.Utf8()))); /** Build a single-batch root where {@code name = "{prefix}_{id}"}. */ private static VectorSchemaRoot lookupRoot(BufferAllocator allocator, long[] ids, String prefix) { @@ -88,6 +96,22 @@ private static VectorSchemaRoot lookupRoot(BufferAllocator allocator, long[] ids return root; } + /** Build a single-batch append-only root without primary-key metadata.
*/ + private static VectorSchemaRoot appendOnlyRoot( + BufferAllocator allocator, long[] ids, String prefix) { + VectorSchemaRoot root = VectorSchemaRoot.create(APPEND_ONLY_SCHEMA, allocator); + BigIntVector idVector = (BigIntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + idVector.allocateNew(ids.length); + nameVector.allocateNew(); + for (int i = 0; i < ids.length; i++) { + idVector.set(i, ids[i]); + nameVector.setSafe(i, (prefix + "_" + ids[i]).getBytes(StandardCharsets.UTF_8)); + } + root.setRowCount(ids.length); + return root; + } + /** Wrap an in-memory root into an {@link ArrowReader} via the Arrow IPC stream format. */ private static ArrowReader toReader(BufferAllocator allocator, VectorSchemaRoot root) throws Exception { @@ -109,6 +133,15 @@ private static Dataset writeLookupDataset( } } + /** Write an append-only base dataset of `(id, name)` rows at {@code path}. */ + private static Dataset writeAppendOnlyDataset( + BufferAllocator allocator, String path, long[] ids, String prefix) throws Exception { + try (VectorSchemaRoot root = appendOnlyRoot(allocator, ids, prefix); + ArrowReader reader = toReader(allocator, root)) { + return Dataset.write().allocator(allocator).reader(reader).uri(path).execute(); + } + } + /** Read an LSM scanner fully into an {@code id -> name} map. 
*/ private static Map readByName(ArrowReader reader) throws Exception { Map byId = new HashMap<>(); @@ -145,6 +178,22 @@ void testInitializeMemWalUnsharded(@TempDir Path tempDir) throws Exception { } } + @Test + void testInitializeMemWalBucketShardingWithoutPrimaryKey(@TempDir Path tempDir) throws Exception { + String path = tempDir.resolve("append_only").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeAppendOnlyDataset(allocator, path, new long[] {1, 2, 3}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("id", 4)); + + Optional details = dataset.memWalIndexDetails(); + assertTrue(details.isPresent()); + assertEquals(4L, details.get().numShards()); + ShardingField field = details.get().shardingSpecs().get(0).fields().get(0); + assertEquals("bucket", field.transform().get()); + assertEquals("4", field.parameters().get("num_buckets")); + } + } + @Test void testInitializeMemWalRejectsConflictingSharding(@TempDir Path tempDir) throws Exception { String path = tempDir.resolve("base").toString(); diff --git a/python/python/tests/test_mem_wal.py b/python/python/tests/test_mem_wal.py index e63aacff57b..88397e94167 100644 --- a/python/python/tests/test_mem_wal.py +++ b/python/python/tests/test_mem_wal.py @@ -21,6 +21,12 @@ pa.field("name", pa.utf8()), ] ) +_APPEND_ONLY_SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.utf8()), + ] +) def _lookup_table(ids, prefix: str) -> pa.Table: @@ -34,6 +40,17 @@ def _lookup_table(ids, prefix: str) -> pa.Table: ) +def _append_only_table(ids, prefix: str) -> pa.Table: + """Build a table without primary-key metadata.""" + return pa.table( + { + "id": pa.array(ids, pa.int64()), + "name": pa.array([f"{prefix}_{i}" for i in ids], pa.utf8()), + }, + schema=_APPEND_ONLY_SCHEMA, + ) + + def _write_flushed_gen(base_path: str, region_id: str, gen_folder: str, data: pa.Table): """Write a flushed-generation Lance dataset at 
the expected sub-path. @@ -367,6 +384,22 @@ def test_initialize_mem_wal_bucket_sharding(tmp_path): assert field["parameters"]["num_buckets"] == "8" +def test_initialize_mem_wal_bucket_sharding_without_primary_key(tmp_path): + ds_path = str(tmp_path / "append_only") + ds = lance.write_dataset( + _append_only_table([1, 2, 3], "base"), + ds_path, + schema=_APPEND_ONLY_SCHEMA, + ) + ds.initialize_mem_wal(bucket_column="id", num_buckets=8) + + details = ds.mem_wal_index_details() + assert details["num_shards"] == 8 + field = details["sharding_specs"][0]["fields"][0] + assert field["transform"] == "bucket" + assert field["parameters"]["num_buckets"] == "8" + + def test_initialize_mem_wal_identity_sharding(tmp_path): ds = _mem_wal_dataset(tmp_path) ds.initialize_mem_wal(identity_column="name") diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 128f3f38b53..ec3214f011b 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -3228,9 +3228,9 @@ impl Dataset { /// Initialize MemWAL on this dataset. /// - /// Must be called once before any `mem_wal_writer()` calls. Requires the - /// dataset schema to have at least one field with the - /// `lance-schema:unenforced-primary-key` metadata. + /// Must be called once before any `mem_wal_writer()` calls. Append-only + /// tables may omit primary-key metadata; primary keys are only required + /// for primary-key lookup and last-write-wins deduplication workflows. /// /// At most one sharding mode may be selected: bucket sharding /// (`bucket_column` + `num_buckets`), identity sharding (`identity_column`), diff --git a/rust/lance/src/dataset/mem_wal/api.rs b/rust/lance/src/dataset/mem_wal/api.rs index a00337dd131..65623c4f5e0 100644 --- a/rust/lance/src/dataset/mem_wal/api.rs +++ b/rust/lance/src/dataset/mem_wal/api.rs @@ -63,8 +63,7 @@ enum Sharding { Manual, /// A single shard; every row is routed to it. Unsharded, - /// Hash-bucket the single-column unenforced primary key into `num_buckets` - /// shards. 
+ /// Hash-bucket a shard key into `num_buckets` shards. Bucket { column: String, num_buckets: u32 }, /// Shard by the raw value of `column` (identity transform). Identity { column: String }, @@ -111,11 +110,13 @@ impl<'a> InitializeMemWalBuilder<'a> { self } - /// Hash-bucket the unenforced primary key into `num_buckets` shards. + /// Hash-bucket `column` into `num_buckets` shards. /// - /// `column` must name the dataset's single-column unenforced primary key; - /// `num_buckets` must be in `[1, 1024]`. Both are validated by - /// [`execute`](Self::execute). + /// For primary-key tables, `column` must name the dataset's single-column + /// unenforced primary key so every update for the same key routes to the + /// same shard. Append-only tables without a primary key may use any scalar + /// column. `num_buckets` must be in `[1, 1024]`. These constraints are + /// validated by [`execute`](Self::execute). pub fn bucket_sharding(mut self, column: impl Into<String>, num_buckets: u32) -> Self { self.sharding = Sharding::Bucket { column: column.into(), @@ -127,10 +128,10 @@ impl<'a> InitializeMemWalBuilder<'a> { /// Shard by the raw value of `column` (the identity transform). /// /// Each distinct value of `column` becomes its own shard; use this when the - /// data is already partitioned by that column. The caller is responsible - /// for ensuring every primary key maps consistently to a single value of - /// `column`. `column` must be a scalar column that exists on the dataset; - /// it is validated by [`execute`](Self::execute). + /// data is already partitioned by that column. For primary-key tables, the + /// caller is responsible for ensuring every primary key maps consistently + /// to a single value of `column`. `column` must be a scalar column that + /// exists on the dataset; it is validated by [`execute`](Self::execute).
pub fn identity_sharding(mut self, column: impl Into<String>) -> Self { self.sharding = Sharding::Identity { column: column.into(), @@ -142,7 +143,8 @@ impl<'a> InitializeMemWalBuilder<'a> { /// previously set list. /// /// Each name must reference an index that already exists on the dataset. - /// The primary key btree is maintained implicitly and must not be listed. + /// The primary key btree, when present, is maintained implicitly and must + /// not be listed. pub fn maintained_indexes(mut self, indexes: I) -> Self where I: IntoIterator, @@ -184,8 +186,8 @@ impl<'a> InitializeMemWalBuilder<'a> { /// Initialize MemWAL on the dataset, committing the MemWAL system index. /// - /// Fails if the dataset has no unenforced primary key, if any maintained - /// index does not exist, or if MemWAL is already initialized. + /// Fails if any maintained index does not exist, if the selected sharding + /// configuration is invalid, or if MemWAL is already initialized. pub async fn execute(self) -> Result<()> { let Self { dataset, @@ -194,13 +196,6 @@ impl<'a> InitializeMemWalBuilder<'a> { writer_config_defaults, } = self; - if dataset.schema().unenforced_primary_key().is_empty() { - return Err(Error::invalid_input( - "MemWAL requires a primary key on the dataset. \ - Define a primary key using the 'lance-schema:unenforced-primary-key' Arrow field metadata.", - )); - } - // Resolve (and validate) the sharding choice before any I/O.
let (sharding_specs, num_shards) = resolve_sharding(dataset, sharding)?; @@ -288,8 +283,23 @@ fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Re } let pk_fields = dataset.schema().unenforced_primary_key(); - let pk = match pk_fields.as_slice() { - [single] => *single, + let source_field = match pk_fields.as_slice() { + [single] => { + let pk = *single; + if pk.name.as_str() != column { + return Err(Error::invalid_input(format!( + "bucket_sharding: column '{}' does not match the unenforced primary key column '{}'", + column, pk.name + ))); + } + pk + } + [] => dataset.schema().field(column).ok_or_else(|| { + Error::invalid_input(format!( + "bucket_sharding: column '{}' not found on the dataset", + column + )) + })?, _ => { return Err(Error::invalid_input( "bucket_sharding requires a single-column unenforced primary key; \ @@ -297,10 +307,12 @@ fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Re )); } }; - if pk.name.as_str() != column { + + let data_type = source_field.data_type(); + if data_type.is_nested() || data_type.is_null() { return Err(Error::invalid_input(format!( - "bucket_sharding: column '{}' does not match the unenforced primary key column '{}'", - column, pk.name + "bucket_sharding: column '{}' has type {:?}, which cannot be used as a shard key", + column, data_type ))); } @@ -308,7 +320,7 @@ fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Re spec_id: SHARDING_SPEC_ID, fields: vec![ShardingField { field_id: SHARDING_FIELD_ID.to_string(), - source_ids: vec![pk.id], + source_ids: vec![source_field.id], transform: Some(BUCKET_TRANSFORM.to_string()), expression: None, result_type: SHARDING_RESULT_TYPE.to_string(), diff --git a/rust/lance/src/dataset/mem_wal/write.rs b/rust/lance/src/dataset/mem_wal/write.rs index d8aaddca4ad..5f6a2af7243 100644 --- a/rust/lance/src/dataset/mem_wal/write.rs +++ b/rust/lance/src/dataset/mem_wal/write.rs @@ -4213,6 +4213,21 @@ mod 
shard_writer_tests { ])) } + fn create_append_only_schema(vector_dim: i32) -> Arc<ArrowSchema> { + Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + vector_dim, + ), + true, + ), + Field::new("text", DataType::Utf8, true), + ])) + } + fn create_test_batch( schema: &ArrowSchema, start_id: i64, @@ -4363,6 +4378,44 @@ mod shard_writer_tests { ); } + #[tokio::test] + async fn test_initialize_mem_wal_bucket_sharding_without_primary_key() { + let vector_dim = 128; + let schema = create_append_only_schema(vector_dim); + let uri = format!( + "memory://test_bucket_sharding_no_primary_key_{}", + Uuid::new_v4() + ); + + let initial_batch = create_test_batch(&schema, 0, 100, vector_dim); + let batches = RecordBatchIterator::new([Ok(initial_batch)], schema.clone()); + let mut dataset = Dataset::write(batches, &uri, Some(WriteParams::default())) + .await + .expect("Failed to create dataset"); + + dataset + .initialize_mem_wal() + .bucket_sharding("id", 8) + .execute() + .await + .expect("Failed to initialize append-only MemWAL"); + + let details = dataset + .mem_wal_index_details() + .await + .expect("Failed to read MemWAL index details") + .expect("MemWAL index details should exist"); + + assert_eq!(details.num_shards, 8); + assert_eq!(details.sharding_specs.len(), 1); + let field = &details.sharding_specs[0].fields[0]; + assert_eq!(field.transform.as_deref(), Some("bucket")); + assert_eq!( + field.parameters.get("num_buckets").map(String::as_str), + Some("8") + ); + } + #[tokio::test] async fn test_initialize_mem_wal_unsharded() { let vector_dim = 128; From dd887ec76845fae53795cceb098bec9316a05c5d Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Tue, 19 May 2026 13:13:50 -0700 Subject: [PATCH 02/23] feat: add manifest version hint for fast latest-version lookup (#6752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit
Carries on #5997 (and the benchmarking in discussion #5947), and follows up on #6728 where moving S3 Express away from O(n) manifest listing to a version hint was raised — picking that up here. ## What On object stores where `list` is **not** lexicographically ordered (e.g. S3 Express, the local filesystem), resolving the latest manifest version is O(n) in the number of versions. To avoid this, after every successful commit on such a store we write a small JSON file `_versions/latest_version_hint.json` with content `{"version":N}`. A reader then does a GET on the hint file plus a few HEAD probes (O(k), where k = versions added since the hint was written), and falls back to a full listing if the hint is missing (older datasets) or stale. - The hint is written/read **only on non-lexically-ordered stores**. On S3 Standard / GCS / Azure / OSS / Tencent / DynamoDB / memory the ordered listing already resolves the latest version in roughly one request, so the hint would only add a PUT per commit for nothing. - `current_manifest_path` uses the hint for non-lexically-ordered, non-local stores (the local filesystem keeps its existing single-directory-read fast path); `CommitHandler::list_manifest_locations_since` (used by `load_new_transactions`) follows the same strategy. - The hint write is **awaited** as part of the commit (no fire-and-forget mode). It is best-effort: failures are logged and ignored, since the hint only accelerates reads and never affects correctness — readers always verify the hinted version and probe upward from it. Detached versions are never written to the hint. - A transient (non-`NotFound`) object-store error while probing abandons the hint path so the caller falls back to a full listing rather than trust a possibly-stale or incomplete result. The gap-fill HEADs are bounded by `io_parallelism()`, and a far-behind reader (gap > 1000) falls back to a single paginated listing. 
## Differences from #5997 - Only the JSON hint format is kept (the alternative file-size-encoded format and its env var are dropped). - The fire-and-forget / async hint-write mode is removed — the hint is always written synchronously, which keeps concurrent writes simpler with no meaningful latency cost. - The hint is gated to non-lexically-ordered stores, where it's actually read. - `current_manifest_path` picks one strategy based on the store rather than racing a HEAD-probe against a listing, keeping IO behavior deterministic. A `manifest_commit` benchmark is included to measure commit/load latency growth with many small fragments. Co-Authored-By: Jack Ye --- docs/src/format/table/layout.md | 15 +- java/src/test/java/org/lance/DatasetTest.java | 7 +- python/python/tests/test_dataset.py | 17 +- rust/lance-namespace-impls/src/dir.rs | 17 +- rust/lance-table/src/io/commit.rs | 531 +++++++++++++++++- .../src/io/commit/external_manifest.rs | 7 +- rust/lance/Cargo.toml | 8 + rust/lance/benches/concurrent_append.rs | 454 +++++++++++++++ rust/lance/benches/manifest_commit.rs | 371 ++++++++++++ rust/lance/src/dataset.rs | 17 +- .../src/dataset/tests/dataset_versioning.rs | 2 + rust/lance/src/dataset/write/commit.rs | 81 ++- rust/lance/src/io/commit/external_manifest.rs | 18 + 13 files changed, 1516 insertions(+), 29 deletions(-) create mode 100644 rust/lance/benches/concurrent_append.rs create mode 100644 rust/lance/benches/manifest_commit.rs diff --git a/docs/src/format/table/layout.md b/docs/src/format/table/layout.md index 46efa56a908..7b08ce0a0dd 100644 --- a/docs/src/format/table/layout.md +++ b/docs/src/format/table/layout.md @@ -20,7 +20,8 @@ A Lance dataset in its basic form stores all files within the dataset root direc data/ *.lance -- Data files containing column data _versions/ - *.manifest -- Manifest files (one per version) + *.manifest -- Manifest files (one per version) + latest_version_hint.json -- Optional hint of the latest version (see below) 
_transactions/ *.txn -- Transaction files for commit coordination _deletions/ @@ -201,3 +202,15 @@ Manifest files are stored in the `_versions/` directory with naming schemes that See [Manifest Naming Schemes](transaction.md#manifest-naming-schemes) for details on the V1 and V2 patterns and their implications for version discovery. +### Version Hint + +The optional file `_versions/latest_version_hint.json` records the latest committed version as JSON: + +```json +{"version": 42} +``` + +It exists to accelerate latest-version discovery on stores where listing `_versions/` is expensive: a reader can read the hint and probe higher versions with HEAD requests instead of listing the whole directory, falling back to a full listing if the hint is missing or stale. + +The hint is purely an optimization. It is always safe to delete, never affects correctness, and can be ignored by readers that don't understand it. Writers may choose not to write it. + diff --git a/java/src/test/java/org/lance/DatasetTest.java b/java/src/test/java/org/lance/DatasetTest.java index c01154b0b71..315b010da1e 100644 --- a/java/src/test/java/org/lance/DatasetTest.java +++ b/java/src/test/java/org/lance/DatasetTest.java @@ -528,7 +528,12 @@ void testOpenSerializedManifest(@TempDir Path tempDir) throws IOException { assertEquals(1, dataset1.version()); Path manifestPath = datasetPath.resolve("_versions"); try (Stream fileStream = Files.list(manifestPath)) { - assertEquals(1, fileStream.count()); + // Ignore the version hint file, which is not a manifest. 
+ assertEquals( + 1, + fileStream + .filter(p -> !p.getFileName().toString().startsWith("latest_version_hint")) + .count()); ByteBuffer manifestBuffer = readManifest(manifestPath.resolve("1.manifest")); try (Dataset dataset2 = testDataset.write(1, 5)) { assertEquals(2, dataset2.version()); diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 428b85095ce..dc2030c4b2b 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -436,18 +436,27 @@ def test_has_stable_row_ids_property(tmp_path: Path): assert lance.dataset(non_stable_path).has_stable_row_ids is False +def _list_manifests(versions_dir): + # Ignore the version hint file, which is not a manifest. + return [ + name + for name in os.listdir(versions_dir) + if not name.startswith("latest_version_hint") + ] + + def test_v2_manifest_paths(tmp_path: Path): lance.write_dataset( pa.table({"a": range(100)}), tmp_path, enable_v2_manifest_paths=True ) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) def test_default_v2_manifest_paths(tmp_path: Path): lance.write_dataset(pa.table({"a": range(100)}), tmp_path) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) @@ -457,12 +466,12 @@ def test_v2_manifest_paths_migration(tmp_path: Path): lance.write_dataset( pa.table({"a": range(100)}), tmp_path, enable_v2_manifest_paths=False ) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert manifest_path == ["1.manifest"] # Migrate to v2 manifest paths lance.dataset(tmp_path).migrate_manifest_paths_v2() - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / 
"_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) diff --git a/rust/lance-namespace-impls/src/dir.rs b/rust/lance-namespace-impls/src/dir.rs index 425993e3956..4b9c69b739a 100644 --- a/rust/lance-namespace-impls/src/dir.rs +++ b/rust/lance-namespace-impls/src/dir.rs @@ -9546,8 +9546,9 @@ mod tests { .await .unwrap(); - // table_exists first checks __manifest (one list on __manifest/_versions), - // then falls back to the table directory (one list_with_delimiter on test_table.lance). + // table_exists first checks __manifest (which on local FS uses the + // version hint and does no list call), then falls back to the table + // directory (one list_with_delimiter on test_table.lance). listing_count.store(0, Ordering::SeqCst); let mut exists_req = TableExistsRequest::new(); @@ -9556,9 +9557,9 @@ mod tests { let count = listing_count.load(Ordering::SeqCst); assert_eq!( - count, 2, - "Expected exactly 2 listing calls for table_exists with migration mode \ - (manifest reload + table directory fallback), but got {}", + count, 1, + "Expected exactly 1 listing call for table_exists with migration mode \ + (table directory fallback; manifest reload uses the version hint), but got {}", count ); @@ -9571,9 +9572,9 @@ mod tests { let count = listing_count.load(Ordering::SeqCst); assert_eq!( - count, 2, - "Expected exactly 2 listing calls for describe_table with migration mode \ - (manifest reload + table directory fallback), but got {}", + count, 1, + "Expected exactly 1 listing call for describe_table with migration mode \ + (table directory fallback; manifest reload uses the version hint), but got {}", count ); } diff --git a/rust/lance-table/src/io/commit.rs b/rust/lance-table/src/io/commit.rs index c90909d7db5..5dbf62002fa 100644 --- a/rust/lance-table/src/io/commit.rs +++ b/rust/lance-table/src/io/commit.rs @@ -70,6 +70,13 @@ use { pub const VERSIONS_DIR: &str = "_versions"; const MANIFEST_EXTENSION: &str = "manifest"; const 
DETACHED_VERSION_PREFIX: &str = "d"; +/// File name for the JSON version hint file, stored under `_versions/`. +/// +/// The file contains `{"version":N}` where `N` is the latest committed version +/// at the time of writing. It enables O(1)/O(k) latest-version lookup via HEAD +/// requests on object stores where listing is not lexicographically ordered +/// (e.g. S3 Express, local filesystem) instead of an O(n) listing. +const VERSION_HINT_FILE: &str = "latest_version_hint.json"; /// How manifest files should be named. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -260,17 +267,292 @@ impl TryFrom for ManifestLocation { } } -/// Get the latest manifest path +/// Get the latest manifest path. +/// +/// - Local filesystem: a single directory read. +/// - Stores where listing is not lexicographically ordered (e.g. S3 Express): +/// the version hint (read the hint file, then probe higher versions with +/// HEADs), falling back to a listing if the hint is missing or stale. A full +/// listing on these stores is O(n) in the number of versions. +/// - Lexicographically ordered stores (e.g. S3 Standard, GCS): the listing +/// already resolves the latest version in roughly one request. async fn current_manifest_path( object_store: &ObjectStore, base: &Path, ) -> Result<ManifestLocation> { - if object_store.is_local() - && let Ok(Some(location)) = current_manifest_local(base) + if object_store.is_local() { + if let Ok(Some(location)) = current_manifest_local(base) { + return Ok(location); + } + } else if uses_version_hint(object_store) + && let Some(location) = read_version_hint_and_probe(object_store, base).await { return Ok(location); } + resolve_version_from_listing(object_store, base).await +} + +/// JSON body of the version hint file: `{"version":N}`.
+#[derive(serde::Serialize, serde::Deserialize)] +struct VersionHint { + version: u64, +} + +/// Set `LANCE_USE_VERSION_HINT` to `0`, `false`, or `off` (case-insensitive) to globally disable the version +/// hint — writers stop emitting the hint file and readers stop consulting it, +/// falling back to plain listing. Intended as a benchmark/escape-hatch knob; +/// the hint is on by default. +const VERSION_HINT_ENV: &str = "LANCE_USE_VERSION_HINT"; + +fn version_hint_globally_enabled() -> bool { + static ENABLED: std::sync::OnceLock = std::sync::OnceLock::new(); + *ENABLED.get_or_init(|| match std::env::var(VERSION_HINT_ENV) { + Ok(v) => !matches!( + v.trim().to_ascii_lowercase().as_str(), + "0" | "false" | "off" + ), + Err(_) => true, + }) +} + +/// Whether this object store benefits from a version hint. +/// +/// On stores where listing is lexicographically ordered (S3 Standard, GCS, +/// Azure, ...) the latest version is already resolved in roughly one request, +/// so the hint would only add a write per commit for nothing. We write (and +/// read) it only on stores where listing is not lexicographically ordered — +/// S3 Express and the local filesystem. Can be force-disabled with the +/// `LANCE_USE_VERSION_HINT=0` environment variable. +pub fn uses_version_hint(object_store: &ObjectStore) -> bool { + version_hint_globally_enabled() && !object_store.list_is_lexically_ordered +} + +/// Path to the JSON version hint file for a dataset. +fn version_hint_path(base: &Path) -> Path { + base.clone().join(VERSIONS_DIR).join(VERSION_HINT_FILE) +} + +/// Write the version hint file after a successful commit. +/// +/// The hint is stored as JSON: `{"version":N}`. This write is best-effort — +/// failures are logged and ignored, since the hint only accelerates reads and +/// never affects correctness (readers verify the hinted version and probe +/// upward from there). It is a no-op for detached versions and for stores that +/// do not benefit from a hint (see [`uses_version_hint`]).
+pub async fn write_version_hint(object_store: &ObjectStore, base: &Path, version: u64) { + if is_detached_version(version) || !uses_version_hint(object_store) { + return; + } + let hint_path = version_hint_path(base); + let content = serde_json::to_vec(&VersionHint { version }).expect("serialize version hint"); + if let Err(e) = object_store.put(&hint_path, content.as_slice()).await { + warn!("Failed to write version hint file for version {version}: {e}"); + } +} + +/// Read the latest version from the hint file, or `None` if it does not exist +/// or cannot be parsed. +async fn read_version_from_hint(object_store: &ObjectStore, base: &Path) -> Option { + let bytes = object_store + .inner + .get(&version_hint_path(base)) + .await + .ok()? + .bytes() + .await + .ok()?; + Some(serde_json::from_slice::(&bytes).ok()?.version) +} + +/// Read the version hint and probe upward to find the true latest manifest. +/// +/// Returns `None` if the hint file is missing, the hinted version no longer +/// exists, or any error occurred — callers should fall back to listing. +async fn read_version_hint_and_probe( + object_store: &ObjectStore, + base: &Path, +) -> Option { + let hint_version = read_version_from_hint(object_store, base).await?; + let (version, scheme, mut probed) = probe_versions_upward(object_store, base, hint_version) + .await + .ok() + .flatten()?; + // `probed` is non-empty and its last entry is the highest version found. + let (_, meta) = probed.pop()?; + Some(ManifestLocation { + version, + path: scheme.manifest_path(base, version), + size: Some(meta.size), + naming_scheme: scheme, + e_tag: meta.e_tag, + }) +} + +/// Maximum version gap between the hint and the read version for which we use +/// the hint-based parallel-HEAD path; beyond this a single (paginated) listing +/// is cheaper, so callers fall back to it. +const MAX_HINT_PROBE_GAP: u64 = 1000; + +/// Probe `from_version`, then `from_version + 1`, `+ 2`, ... 
with HEAD requests +/// until one is not found. +/// +/// Assumes attached versions are contiguous above `from_version` (true in +/// practice: every commit increments by one, and cleanup only removes *old* +/// versions, never ones newer than the latest). A `NotFound` therefore marks +/// the end of the history. +/// +/// - `Ok(Some((true_latest_version, naming_scheme, [(version, meta), ...])))`: +/// the vec covers every version from `from_version` through the true latest +/// in ascending order. +/// - `Ok(None)`: `from_version` itself does not exist (a `NotFound` for both +/// naming schemes) — i.e. the hint pointed past the end. +/// - `Err(_)`: a transient object-store error was hit, so the probed range may +/// be incomplete; callers should fall back to a full listing rather than +/// trust a possibly-stale result. +async fn probe_versions_upward( + object_store: &ObjectStore, + base: &Path, + from_version: u64, +) -> Result< + Option<( + u64, + ManifestNamingScheme, + Vec<(u64, object_store::ObjectMeta)>, + )>, +> { + // Newer datasets use V2; fall back to V1 if the V2 path is not found. + let mut scheme = ManifestNamingScheme::V2; + let meta = match object_store + .inner + .head(&scheme.manifest_path(base, from_version)) + .await + { + Ok(meta) => meta, + Err(ObjectStoreError::NotFound { .. }) => { + scheme = ManifestNamingScheme::V1; + match object_store + .inner + .head(&scheme.manifest_path(base, from_version)) + .await + { + Ok(meta) => meta, + Err(ObjectStoreError::NotFound { .. }) => return Ok(None), + Err(e) => return Err(e.into()), + } + } + Err(e) => return Err(e.into()), + }; + + let mut probed = vec![(from_version, meta)]; + let mut version = from_version; + loop { + let next = version + 1; + match object_store + .inner + .head(&scheme.manifest_path(base, next)) + .await + { + Ok(meta) => { + probed.push((next, meta)); + version = next; + } + // NotFound means we found the latest version. + Err(ObjectStoreError::NotFound { .. 
}) => break, + // A transient error means a newer version might exist that we + // failed to observe — surface it so callers fall back to listing. + Err(e) => return Err(e.into()), + } + } + Ok(Some((version, scheme, probed))) +} + +/// List manifest locations with version `> since_version` using the version +/// hint, in descending order of version. +/// +/// Returns `None` if the hint is missing or stale enough that this is not +/// usable — callers should fall back to a full listing. `Some(vec![])` is the +/// fast path where the hint confirms there are no new versions. +async fn list_manifests_since_version_with_hint( + object_store: &ObjectStore, + base: &Path, + since_version: u64, +) -> Option> { + let hint_version = read_version_from_hint(object_store, base).await?; + + // A reader that is very far behind is cheaper to serve with one paginated + // listing than with thousands of HEADs. + if hint_version.saturating_sub(since_version) > MAX_HINT_PROBE_GAP { + return None; + } + + // If the hint is not newer than the read version, the only versions that + // could exist are right above it; otherwise start at the hint. + let probe_from = if hint_version > since_version { + hint_version + } else { + since_version + 1 + }; + + let (scheme, probed) = match probe_versions_upward(object_store, base, probe_from).await { + Ok(Some((_true_latest, scheme, probed))) => (scheme, probed), + // Nothing at `probe_from`. If we were probing from the hint, the hint + // is stale — bail to a full listing. If we were probing from + // `since_version + 1`, there are simply no new versions. + Ok(None) if hint_version > since_version => return None, + Ok(None) => return Some(Vec::new()), + // Transient error: don't trust the hint path, fall back to listing. 
+ Err(_) => return None, + }; + + let mut locations: Vec = probed + .into_iter() + .filter(|(v, _)| *v > since_version) + .map(|(version, meta)| ManifestLocation { + version, + path: scheme.manifest_path(base, version), + size: Some(meta.size), + naming_scheme: scheme, + e_tag: meta.e_tag, + }) + .collect(); + + // Fill the gap between `since_version` and the hint with HEADs (the probe + // above already covered `hint_version` and up). The range is contiguous, so + // any error here (including a `NotFound`) means we can't trust the hint path + // — fall back to a full listing. + if hint_version > since_version + 1 { + let gap_locations: Vec = + futures::stream::iter((since_version + 1)..hint_version) + .map(|version| async move { + object_store + .inner + .head(&scheme.manifest_path(base, version)) + .await + .map(|meta| ManifestLocation { + version, + path: scheme.manifest_path(base, version), + size: Some(meta.size), + naming_scheme: scheme, + e_tag: meta.e_tag, + }) + }) + .buffer_unordered(object_store.io_parallelism()) + .try_collect() + .await + .ok()?; + locations.extend(gap_locations); + } + + locations.sort_by_key(|loc| std::cmp::Reverse(loc.version)); + Some(locations) +} + +/// Resolve the latest manifest by listing the versions directory. +async fn resolve_version_from_listing( + object_store: &ObjectStore, + base: &Path, +) -> Result { let manifest_files = object_store.list(Some(base.clone().join(VERSIONS_DIR))); let mut valid_manifests = manifest_files.try_filter_map(|res| { @@ -588,6 +870,51 @@ pub trait CommitHandler: Debug + Send + Sync { } } + /// List manifest locations with version `> since_version`, in descending + /// order of version. + /// + /// On lexically-ordered stores this is the standard listing with early + /// termination. On non-lexically-ordered stores (e.g. S3 Express) it uses + /// the version hint to avoid an O(n) listing, falling back to a full + /// listing if the hint is missing or stale. 
+ fn list_manifest_locations_since<'a>( + &self, + base_path: &Path, + object_store: &'a ObjectStore, + since_version: u64, + ) -> BoxStream<'a, Result> { + if !uses_version_hint(object_store) { + return self + .list_manifest_locations(base_path, object_store, true) + .try_take_while(move |loc| future::ready(Ok(loc.version > since_version))) + .boxed(); + } + + let base_path = base_path.clone(); + futures::stream::once(async move { + let locations = match list_manifests_since_version_with_hint( + object_store, + &base_path, + since_version, + ) + .await + { + Some(locations) => locations, + None => { + let mut locations = list_manifests(&base_path, &object_store.inner) + .try_collect::>() + .await?; + locations.retain(|loc| loc.version > since_version); + locations.sort_by_key(|loc| std::cmp::Reverse(loc.version)); + locations + } + }; + Ok::<_, Error>(futures::stream::iter(locations.into_iter().map(Ok))) + }) + .try_flatten() + .boxed() + } + /// Commit a manifest. /// /// This function should return an [CommitError::CommitConflict] if another @@ -877,6 +1204,8 @@ impl CommitHandler for UnsafeCommitHandler { let res = manifest_writer(object_store, manifest, indices, &version_path, transaction).await?; + write_version_hint(object_store, base_path, manifest.version).await; + Ok(ManifestLocation { version: manifest.version, size: Some(res.size as u64), @@ -960,6 +1289,9 @@ impl CommitHandler for T { lease.release(res.is_ok()).await?; let res = res?; + + write_version_hint(object_store, base_path, manifest.version).await; + Ok(ManifestLocation { version: manifest.version, size: Some(res.size as u64), @@ -1028,6 +1360,7 @@ impl CommitHandler for RenameCommitHandler { { Ok(_) => { // Successfully committed + write_version_hint(object_store, base_path, manifest.version).await; Ok(ManifestLocation { version: manifest.version, path, @@ -1103,6 +1436,8 @@ impl CommitHandler for ConditionalPutCommitHandler { _ => CommitError::OtherError(err.into()), })?; + 
write_version_hint(object_store, base_path, manifest.version).await; + Ok(ManifestLocation { version: manifest.version, path, @@ -1282,6 +1617,196 @@ mod tests { assert_eq!(location.path, naming_scheme.manifest_path(&base, 11)); } + /// A memory store that reports `list_is_lexically_ordered == false`, like + /// S3 Express, so the version-hint paths are exercised. + fn non_lexical_memory_store() -> Box { + let mut object_store = ObjectStore::memory(); + object_store.list_is_lexically_ordered = false; + Box::new(object_store) + } + + #[tokio::test] + async fn test_write_version_hint() { + let base = Path::from("base"); + + // No hint is written on lexically-ordered stores (it would not be read). + let lexical = ObjectStore::memory(); + write_version_hint(&lexical, &base, 42).await; + assert_eq!(read_version_from_hint(&lexical, &base).await, None); + + let object_store = non_lexical_memory_store(); + write_version_hint(&object_store, &base, 42).await; + assert_eq!(read_version_from_hint(&object_store, &base).await, Some(42)); + + // A later commit overwrites the hint. + write_version_hint(&object_store, &base, 100).await; + assert_eq!( + read_version_from_hint(&object_store, &base).await, + Some(100) + ); + + // Detached versions are never written to the hint. + write_version_hint( + &object_store, + &base, + crate::format::DETACHED_VERSION_MASK | 7, + ) + .await; + assert_eq!( + read_version_from_hint(&object_store, &base).await, + Some(100) + ); + + // A corrupt / non-JSON hint file is treated as missing. 
+ let hint_path = version_hint_path(&base); + object_store + .put(&hint_path, b"not json".as_slice()) + .await + .unwrap(); + assert_eq!(read_version_from_hint(&object_store, &base).await, None); + } + + #[tokio::test] + #[rstest::rstest] + async fn test_read_version_hint_and_probe( + #[values(ManifestNamingScheme::V1, ManifestNamingScheme::V2)] + naming_scheme: ManifestNamingScheme, + ) { + let object_store = non_lexical_memory_store(); + let base = Path::from("base"); + + // No hint file yet. + assert!( + read_version_hint_and_probe(&object_store, &base) + .await + .is_none() + ); + + for version in 1..=5 { + object_store + .put(&naming_scheme.manifest_path(&base, version), b"".as_slice()) + .await + .unwrap(); + } + + // Stale hint: should probe forward and find version 5. + write_version_hint(&object_store, &base, 3).await; + let location = read_version_hint_and_probe(&object_store, &base) + .await + .unwrap(); + assert_eq!(location.version, 5); + assert_eq!(location.naming_scheme, naming_scheme); + + // Up-to-date hint: returns version 5 directly. + write_version_hint(&object_store, &base, 5).await; + let location = read_version_hint_and_probe(&object_store, &base) + .await + .unwrap(); + assert_eq!(location.version, 5); + + // Hint points past the latest version: not usable. + write_version_hint(&object_store, &base, 10).await; + assert!( + read_version_hint_and_probe(&object_store, &base) + .await + .is_none() + ); + } + + #[tokio::test] + async fn test_list_manifests_since_version_with_hint() { + let object_store = non_lexical_memory_store(); + let base = Path::from("base"); + let scheme = ManifestNamingScheme::V2; + + for version in 1..=10 { + object_store + .put(&scheme.manifest_path(&base, version), b"".as_slice()) + .await + .unwrap(); + } + + // No hint yet -> not usable, caller must fall back. 
+ assert!( + list_manifests_since_version_with_hint(&object_store, &base, 7) + .await + .is_none() + ); + + // Hint exactly at the read version -> fast path, nothing new. + write_version_hint(&object_store, &base, 10).await; + assert!(matches!( + list_manifests_since_version_with_hint(&object_store, &base, 10).await, + Some(v) if v.is_empty() + )); + + // Hint ahead of the read version, with a gap to fill (8, 9) plus probing + // from the hint (10). Results are descending by version. + let locations = list_manifests_since_version_with_hint(&object_store, &base, 7) + .await + .unwrap(); + assert_eq!( + locations.iter().map(|l| l.version).collect::>(), + vec![10, 9, 8] + ); + + // Slightly stale hint (points at 8) still probes up to the true latest. + write_version_hint(&object_store, &base, 8).await; + let locations = list_manifests_since_version_with_hint(&object_store, &base, 7) + .await + .unwrap(); + assert_eq!( + locations.iter().map(|l| l.version).collect::>(), + vec![10, 9, 8] + ); + + // Hint points past the latest -> not usable, caller falls back. + write_version_hint(&object_store, &base, 20).await; + assert!( + list_manifests_since_version_with_hint(&object_store, &base, 7) + .await + .is_none() + ); + } + + #[tokio::test] + async fn test_current_manifest_path_with_hint_non_lexical() { + // Simulate S3 Express (non-lexically ordered list) with many versions. + let object_store = non_lexical_memory_store(); + let base = Path::from("base"); + let naming_scheme = ManifestNamingScheme::V2; + + for version in 1..=100 { + object_store + .put(&naming_scheme.manifest_path(&base, version), b"".as_slice()) + .await + .unwrap(); + } + + // Slightly stale hint: probing from 98 still resolves the true latest. 
+ write_version_hint(&object_store, &base, 98).await; + let location = current_manifest_path(&object_store, &base).await.unwrap(); + assert_eq!(location.version, 100); + } + + #[tokio::test] + async fn test_current_manifest_path_with_stale_hint_falls_back_to_listing() { + let object_store = non_lexical_memory_store(); + let base = Path::from("base"); + let naming_scheme = ManifestNamingScheme::V2; + + // Only version 5 exists, but the hint claims version 10. + object_store + .put(&naming_scheme.manifest_path(&base, 5), b"".as_slice()) + .await + .unwrap(); + write_version_hint(&object_store, &base, 10).await; + + // The stale hint is ignored; listing finds version 5. + let location = current_manifest_path(&object_store, &base).await.unwrap(); + assert_eq!(location.version, 5); + } + #[test] fn test_parse_detached_version() { // Valid detached version filenames diff --git a/rust/lance-table/src/io/commit/external_manifest.rs b/rust/lance-table/src/io/commit/external_manifest.rs index 49d651fe6eb..75993ca8d1f 100644 --- a/rust/lance-table/src/io/commit/external_manifest.rs +++ b/rust/lance-table/src/io/commit/external_manifest.rs @@ -21,7 +21,7 @@ use tracing::info; use super::{ MANIFEST_EXTENSION, ManifestLocation, ManifestNamingScheme, current_manifest_path, - default_resolve_version, make_staging_manifest_path, + default_resolve_version, make_staging_manifest_path, write_version_hint, }; use crate::format::{IndexMetadata, Manifest, Transaction}; use crate::io::commit::{CommitError, CommitHandler}; @@ -490,7 +490,10 @@ impl CommitHandler for ExternalManifestCommitHandler { .await; match result { - Ok(location) => Ok(location), + Ok(location) => { + write_version_hint(object_store, base_path, manifest.version).await; + Ok(location) + } Err(_) => { // delete the staging manifest match object_store.inner.delete(&staging_path).await { diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 9f055fc9154..1f7c4af0bd1 100644 --- a/rust/lance/Cargo.toml +++ 
b/rust/lance/Cargo.toml @@ -247,5 +247,13 @@ name = "mem_wal_fts_bench" path = "benches/mem_wal/fts/mem_wal_fts_bench.rs" harness = false +[[bench]] +name = "manifest_commit" +harness = false + +[[bench]] +name = "concurrent_append" +harness = false + [lints] workspace = true diff --git a/rust/lance/benches/concurrent_append.rs b/rust/lance/benches/concurrent_append.rs new file mode 100644 index 00000000000..ac7cf3f610f --- /dev/null +++ b/rust/lance/benches/concurrent_append.rs @@ -0,0 +1,454 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmark for concurrent-append throughput against S3 / S3 Express. +//! +//! Many writers append to the same dataset at once. The output measures how +//! the version-hint optimization affects conflict resolution and overall +//! commit rate as the version count grows. Designed to be run on a single +//! large EC2 instance so the writer count itself isn't the bottleneck. +//! +//! ## Running against S3 Standard +//! +//! ```bash +//! export AWS_REGION=us-east-1 +//! export DATASET_URI=s3://jack-devland-build/bench/concurrent_append +//! export NUM_WRITERS=64 +//! export APPENDS_PER_WRITER=200 +//! cargo bench --bench concurrent_append +//! ``` +//! +//! ## Running against S3 Express +//! +//! ```bash +//! export AWS_REGION=us-east-1 +//! export DATASET_URI=s3://jack-lancedb-devland--use1-az24--x-s3/bench/concurrent_append +//! export NUM_WRITERS=64 +//! export APPENDS_PER_WRITER=200 +//! cargo bench --bench concurrent_append +//! ``` +//! +//! ## Configuration +//! +//! - `DATASET_URI`: base URI under which a uniquely-named dataset is created. +//! Required. +//! - `NUM_WRITERS`: number of concurrent writers (default 64). +//! - `APPENDS_PER_WRITER`: appends each writer attempts (default 200). +//! - `ROWS_PER_APPEND`: rows per appended batch (default 100). +//! - `BASE_ROWS`: rows in the initial table before concurrent writes begin +//!
(default 100_000). +//! - `KEEP_DATASET`: when set to `true`, leaves the dataset in place after +//! the run (default: deleted on local, kept on remote stores). + +#![allow(clippy::print_stdout, clippy::print_stderr)] + +use arrow_array::{Int64Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use lance::dataset::{Dataset, InsertBuilder, WriteMode, WriteParams, builder::DatasetBuilder}; +use lance::session::Session; +use lance_io::object_store::{ObjectStoreParams, ObjectStoreRegistry, StorageOptionsAccessor}; +use std::collections::HashMap; +use std::env; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use uuid::Uuid; + +const DEFAULT_NUM_WRITERS: usize = 64; +const DEFAULT_APPENDS_PER_WRITER: usize = 200; +const DEFAULT_ROWS_PER_APPEND: usize = 100; +const DEFAULT_BASE_ROWS: usize = 100_000; + +fn env_usize(key: &str, default: usize) -> usize { + env::var(key) + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +fn env_bool(key: &str) -> bool { + env::var(key) + .map(|s| s.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +fn storage_label(uri: &str) -> &'static str { + if uri.contains("--x-s3") { + "s3express" + } else if uri.starts_with("s3://") { + "s3" + } else if uri.starts_with("gs://") { + "gcs" + } else if uri.starts_with("az://") { + "azure" + } else { + "local" + } +} + +fn schema() -> Arc { + Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])) +} + +fn batch(start_id: usize, num_rows: usize) -> RecordBatch { + let ids = Int64Array::from_iter_values((start_id as i64)..((start_id + num_rows) as i64)); + let names = StringArray::from_iter_values( + (start_id..(start_id + num_rows)).map(|i| format!("name_{i}")), + ); + RecordBatch::try_new(schema(), vec![Arc::new(ids), Arc::new(names)]).expect("build batch") +} + +/// Storage options that
turn on S3 Express when the URI advertises it. +/// +/// S3 Express directory buckets don't support GetBucketLocation, so we also +/// require the caller to set `AWS_REGION` and forward it explicitly. +fn store_params_for(uri: &str) -> Option { + if !uri.contains("--x-s3") { + return None; + } + let region = env::var("AWS_REGION") + .or_else(|_| env::var("AWS_DEFAULT_REGION")) + .expect("AWS_REGION is required when DATASET_URI points at S3 Express"); + let storage_options: HashMap = + [("s3_express", "true"), ("region", region.as_str())] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Some(ObjectStoreParams { + storage_options_accessor: Some(Arc::new(StorageOptionsAccessor::with_static_options( + storage_options, + ))), + ..Default::default() + }) +} + +fn write_params(session: Arc, store_params: Option) -> WriteParams { + WriteParams { + mode: WriteMode::Append, + session: Some(session), + store_params, + skip_auto_cleanup: true, + ..Default::default() + } +} + +async fn create_base_dataset( + uri: &str, + base_rows: usize, + rows_per_append: usize, + session: Arc, + store_params: Option, +) -> Dataset { + // When `base_rows == 0` the dataset starts empty: one create commit with a + // zero-row batch so the writers begin at version 1 with no data. + let initial_rows = if base_rows == 0 { + 0 + } else { + rows_per_append.min(base_rows) + }; + let initial = batch(0, initial_rows); + let reader = RecordBatchIterator::new(vec![Ok(initial)], schema()); + let create_params = WriteParams { + mode: WriteMode::Create, + session: Some(session.clone()), + store_params: store_params.clone(), + skip_auto_cleanup: true, + ..Default::default() + }; + let mut dataset = Dataset::write(reader, uri, Some(create_params)) + .await + .expect("create base dataset"); + + // Top up to BASE_ROWS in chunks so we don't allocate one huge batch. 
+ let chunk = 10_000.min(base_rows); + let mut written = initial_rows; + while written < base_rows { + let to_write = chunk.min(base_rows - written); + let batch = batch(written, to_write); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema()); + let params = write_params(session.clone(), store_params.clone()); + dataset = Dataset::write(reader, uri, Some(params)) + .await + .expect("seed appends"); + written += to_write; + } + dataset +} + +struct WriterStats { + successes: usize, + failures: usize, + latencies: Vec, +} + +#[allow(clippy::too_many_arguments)] +async fn run_writer( + writer_id: usize, + uri: String, + appends: usize, + rows_per_append: usize, + deadline: Option, + per_attempt_timeout: Option, + session: Arc, + store_params: Option, +) -> WriterStats { + // Each writer keeps its own dataset handle; CommitBuilder rebases on + // conflict so we don't need to manually reload between appends. + let mut dataset = Arc::new( + DatasetBuilder::from_uri(&uri) + .with_session(session.clone()) + .load() + .await + .expect("writer load"), + ); + + let mut stats = WriterStats { + successes: 0, + failures: 0, + latencies: Vec::with_capacity(appends), + }; + + // Disjoint id ranges per writer so the data inserted is identifiable. + let id_base = 1_000_000 + writer_id * appends * rows_per_append; + for i in 0..appends { + if let Some(d) = deadline + && Instant::now() >= d + { + break; + } + let batch = batch(id_base + i * rows_per_append, rows_per_append); + let params = write_params(session.clone(), store_params.clone()); + let start = Instant::now(); + // Per-attempt cap keeps the slow-tail commits from extending the run + // far past the writer-side deadline at high concurrency. 
+ let result = match per_attempt_timeout { + Some(t) => { + let ds = dataset.clone(); + let params_ref = &params; + match tokio::time::timeout(t, async move { + InsertBuilder::new(ds) + .with_params(params_ref) + .execute(vec![batch]) + .await + }) + .await + { + Ok(r) => r, + Err(_) => Err(lance_core::Error::io_source(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "per-attempt timeout", + )))), + } + } + None => { + InsertBuilder::new(dataset.clone()) + .with_params(&params) + .execute(vec![batch]) + .await + } + }; + let elapsed = start.elapsed(); + match result { + Ok(new_ds) => { + stats.successes += 1; + stats.latencies.push(elapsed); + dataset = Arc::new(new_ds); + } + Err(e) => { + stats.failures += 1; + eprintln!("writer {writer_id} append {i} failed after {elapsed:?}: {e}"); + // Reload and keep going so a single failure doesn't end the run. + dataset = Arc::new( + DatasetBuilder::from_uri(&uri) + .with_session(session.clone()) + .load() + .await + .expect("writer reload after error"), + ); + } + } + } + stats +} + +fn percentile(sorted: &[Duration], p: f64) -> Duration { + if sorted.is_empty() { + return Duration::ZERO; + } + let idx = ((sorted.len() as f64 - 1.0) * p).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +fn ms(d: Duration) -> f64 { + d.as_secs_f64() * 1000.0 +} + +fn bench_concurrent_append(c: &mut Criterion) { + let dataset_base = + env::var("DATASET_URI").expect("DATASET_URI is required for concurrent_append bench"); + let num_writers = env_usize("NUM_WRITERS", DEFAULT_NUM_WRITERS); + let appends_per_writer = env_usize("APPENDS_PER_WRITER", DEFAULT_APPENDS_PER_WRITER); + let rows_per_append = env_usize("ROWS_PER_APPEND", DEFAULT_ROWS_PER_APPEND); + let base_rows = env_usize("BASE_ROWS", DEFAULT_BASE_ROWS); + let keep_dataset = env_bool("KEEP_DATASET"); + // Per-writer wall-clock budget.
When non-zero, each writer stops looping + // once this many seconds have elapsed since the run started, even if it + // hasn't issued `APPENDS_PER_WRITER` commits yet. Lets us bound run time + // at high concurrency where conflict retries make commits arbitrarily slow. + let max_wall_secs = env_usize("MAX_WALL_SECS", 0); + // Per-attempt timeout. Caps any single commit attempt (including its + // internal retries) so the slow-tail of an under-contention commit doesn't + // extend the run past the writer deadline. 0 disables it. + let per_attempt_timeout_secs = env_usize("PER_ATTEMPT_TIMEOUT_SECS", 0); + + let uri = format!( + "{}/concurrent_append_{}", + dataset_base.trim_end_matches('/'), + &Uuid::new_v4().to_string()[..8] + ); + let label = storage_label(&uri); + let store_params = store_params_for(&uri); + + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("build tokio runtime"); + + println!("=== Concurrent Append Benchmark ==="); + println!("Storage: {uri} ({label})"); + println!( + "Writers: {num_writers}, appends/writer: {appends_per_writer}, rows/append: {rows_per_append}" + ); + println!("Base rows: {base_rows}, keep_dataset: {keep_dataset}"); + println!(); + + // Share one ObjectStoreRegistry so all writers reuse warm TCP/TLS sessions. 
+ let registry = Arc::new(ObjectStoreRegistry::default()); + let session = Arc::new(Session::new(0, 0, registry)); + + println!("Seeding base dataset ({base_rows} rows)..."); + let seed_start = Instant::now(); + let base_dataset = runtime.block_on(create_base_dataset( + &uri, + base_rows, + rows_per_append, + session.clone(), + store_params.clone(), + )); + let starting_version = base_dataset.manifest().version; + println!( + "Base dataset ready in {:.2}s at version {starting_version}", + seed_start.elapsed().as_secs_f64() + ); + + println!("Starting {num_writers} concurrent writers..."); + let wall_start = Instant::now(); + let deadline = + (max_wall_secs > 0).then(|| wall_start + Duration::from_secs(max_wall_secs as u64)); + if let Some(d) = deadline { + println!( + "Per-writer wall budget: {max_wall_secs}s (deadline {:?} from now)", + d.duration_since(wall_start) + ); + } + let per_attempt_timeout = (per_attempt_timeout_secs > 0) + .then(|| Duration::from_secs(per_attempt_timeout_secs as u64)); + if let Some(t) = per_attempt_timeout { + println!("Per-attempt timeout: {:?}", t); + } + let all_stats: Vec = runtime.block_on(async { + let mut tasks = Vec::with_capacity(num_writers); + for writer_id in 0..num_writers { + let uri = uri.clone(); + let session = session.clone(); + let store_params = store_params.clone(); + tasks.push(tokio::spawn(async move { + run_writer( + writer_id, + uri, + appends_per_writer, + rows_per_append, + deadline, + per_attempt_timeout, + session, + store_params, + ) + .await + })); + } + let mut out = Vec::with_capacity(num_writers); + for t in tasks { + out.push(t.await.expect("writer task panicked")); + } + out + }); + let wall = wall_start.elapsed(); + + let total_attempts = all_stats + .iter() + .map(|s| s.successes + s.failures) + .sum::(); + let total_success = all_stats.iter().map(|s| s.successes).sum::(); + let total_failed = all_stats.iter().map(|s| s.failures).sum::(); + let mut latencies: Vec = all_stats + .into_iter() + 
.flat_map(|s| s.latencies.into_iter()) + .collect(); + latencies.sort(); + + let throughput = total_success as f64 / wall.as_secs_f64(); + + println!(); + println!("=== Results ==="); + println!("Wall time: {:.2}s", wall.as_secs_f64()); + println!( + "Commits: {total_success} succeeded, {total_failed} failed out of {total_attempts} attempts" + ); + println!("Throughput: {throughput:.2} commits/sec"); + if !latencies.is_empty() { + let mean = latencies.iter().map(|d| d.as_secs_f64()).sum::() / latencies.len() as f64; + println!( + "Commit latency (per writer, includes any retries): \ + p50={:.2}ms p90={:.2}ms p95={:.2}ms p99={:.2}ms max={:.2}ms mean={:.2}ms", + ms(percentile(&latencies, 0.50)), + ms(percentile(&latencies, 0.90)), + ms(percentile(&latencies, 0.95)), + ms(percentile(&latencies, 0.99)), + ms(*latencies.last().unwrap()), + mean * 1000.0, + ); + } + + let final_dataset = runtime.block_on(async { + DatasetBuilder::from_uri(&uri) + .with_session(session.clone()) + .load() + .await + .expect("final load") + }); + println!( + "Final dataset version: {} (started at {})", + final_dataset.manifest().version, + starting_version + ); + + // Pin the numbers into criterion so they show up in regression tracking. 
+ let mut group = c.benchmark_group(format!("concurrent_append_{label}")); + group.bench_function("commits_per_sec", |b| b.iter(|| throughput)); + group.bench_function("p50_ms", |b| b.iter(|| ms(percentile(&latencies, 0.50)))); + group.bench_function("p99_ms", |b| b.iter(|| ms(percentile(&latencies, 0.99)))); + group.finish(); + + if !keep_dataset && label == "local" { + let _ = std::fs::remove_dir_all(&uri); + println!("Local dataset removed: {uri}"); + } else { + println!("Dataset preserved: {uri}"); + } +} + +criterion_group!(benches, bench_concurrent_append); +criterion_main!(benches); diff --git a/rust/lance/benches/manifest_commit.rs b/rust/lance/benches/manifest_commit.rs new file mode 100644 index 00000000000..2a98a37a498 --- /dev/null +++ b/rust/lance/benches/manifest_commit.rs @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmark for manifest commit performance with many small fragments. +//! +//! This benchmark tests how performance degrades as the number of small fragments +//! grows. Each fragment contains only 10 rows, and we measure both: +//! - Commit time (manifest write only, excludes fragment data writing) +//! - Load time (manifest read from storage, using checkout_latest) +//! +//! Key optimizations: +//! - Uses shared ObjectStoreRegistry to reuse TCP/TLS connections +//! - Disables auto-cleanup to avoid background cleanup overhead +//! - Separates fragment writing from commit measurement +//! +//! ## Running against S3 Express +//! +//! ```bash +//! export AWS_REGION=us-east-1 +//! export DATASET_PREFIX=s3://your-bucket--use1-az4--x-s3/bench/manifest_commit +//! export NUM_ITERATIONS=100 +//! cargo bench --bench manifest_commit +//! ``` +//! +//! ## Running against local filesystem (with temp directory) +//! +//! ```bash +//! cargo bench --bench manifest_commit +//! ``` +//! +//! ## Configuration +//! +//! - `DATASET_PREFIX`: Base URI for datasets (e.g. 
s3://bucket/prefix or /tmp/bench). +//! If not set, uses a temporary directory. +//! - `NUM_ITERATIONS`: Number of small fragment writes to perform (default: 100). +//! - `ROWS_PER_FRAGMENT`: Number of rows per fragment (default: 10). +//! - `DELETE_DATASET`: When "true", delete the dataset after benchmark completes. +//! - `ENABLE_CACHE`: When "true", enable manifest caching for load measurements. +//! Default is "false" to measure actual storage read latency. + +#![allow(clippy::print_stdout)] + +use arrow_array::{Int64Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use lance::dataset::builder::DatasetBuilder; +use lance::dataset::{CommitBuilder, Dataset, InsertBuilder, WriteMode, WriteParams}; +use lance::session::Session; +use lance_io::object_store::ObjectStoreRegistry; +use std::sync::Arc; +use std::time::Instant; +use tokio::runtime::Runtime; +use uuid::Uuid; + +const DEFAULT_ROWS_PER_FRAGMENT: usize = 10; +const DEFAULT_NUM_ITERATIONS: usize = 100; + +fn get_rows_per_fragment() -> usize { + std::env::var("ROWS_PER_FRAGMENT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_ROWS_PER_FRAGMENT) +} + +fn get_num_iterations() -> usize { + std::env::var("NUM_ITERATIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_NUM_ITERATIONS) +} + +fn get_delete_dataset() -> bool { + std::env::var("DELETE_DATASET") + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false) +} + +fn get_enable_cache() -> bool { + std::env::var("ENABLE_CACHE") + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false) +} + +fn get_dataset_prefix() -> String { + std::env::var("DATASET_PREFIX").unwrap_or_else(|_| { + let temp_dir = std::env::temp_dir().join(format!("lance_bench_{}", Uuid::new_v4())); + std::fs::create_dir_all(&temp_dir).expect("Failed to create temp directory"); + temp_dir.to_string_lossy().to_string() + }) +} + +fn 
get_storage_label(prefix: &str) -> &'static str { + if prefix.starts_with("s3://") { + "s3" + } else if prefix.starts_with("gs://") { + "gcs" + } else if prefix.starts_with("az://") { + "azure" + } else if prefix.starts_with("memory://") { + "memory" + } else { + "local" + } +} + +async fn create_initial_dataset( + uri: &str, + rows_per_fragment: usize, + session: Arc, +) -> Dataset { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = create_batch(schema.clone(), 0, rows_per_fragment); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + + std::fs::remove_dir_all(uri).ok(); + + let params = WriteParams { + session: Some(session), + skip_auto_cleanup: true, + ..Default::default() + }; + + Dataset::write(reader, uri, Some(params)) + .await + .expect("failed to create initial dataset") +} + +fn create_batch(schema: Arc, start_id: usize, num_rows: usize) -> RecordBatch { + let ids = Int64Array::from_iter_values((start_id as i64)..((start_id + num_rows) as i64)); + let names = StringArray::from_iter_values( + (start_id..(start_id + num_rows)).map(|i| format!("name_{}", i)), + ); + + RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(names)]) + .expect("failed to create batch") +} + +fn bench_manifest_commit(c: &mut Criterion) { + let runtime = Runtime::new().expect("failed to build tokio runtime"); + + let dataset_prefix = get_dataset_prefix(); + let num_iterations = get_num_iterations(); + let rows_per_fragment = get_rows_per_fragment(); + let delete_dataset = get_delete_dataset(); + let enable_cache = get_enable_cache(); + let storage_label = get_storage_label(&dataset_prefix); + + let short_id = &Uuid::new_v4().to_string()[..8]; + let uri = format!( + "{}/manifest_commit_{}", + dataset_prefix.trim_end_matches('/'), + short_id + ); + + println!("=== Manifest Commit Benchmark Setup ==="); + println!("Storage: {} ({})", uri, storage_label); + 
println!("Rows per fragment: {}", rows_per_fragment); + println!("Number of iterations: {}", num_iterations); + println!( + "Total fragments (including initial): {}", + num_iterations + 1 + ); + println!("Delete dataset: {}", delete_dataset); + println!( + "Cache enabled: {} ({})", + enable_cache, + if enable_cache { + "using default cache size" + } else { + "zero cache size - measures actual storage read" + } + ); + println!(); + + // Create a shared session for both commit and load operations + // When cache is disabled, use zero cache size to measure actual storage read latency + // When cache is enabled, use default cache sizes (6GB index, 1GB metadata) + let shared_store_registry = Arc::new(ObjectStoreRegistry::default()); + let session = if enable_cache { + Arc::new(Session::default()) + } else { + Arc::new(Session::new(0, 0, shared_store_registry)) + }; + + let initial_dataset = runtime.block_on(create_initial_dataset( + &uri, + rows_per_fragment, + session.clone(), + )); + + let uri_clone = uri.clone(); + let mut load_dataset = runtime.block_on(async { + DatasetBuilder::from_uri(&uri_clone) + .with_session(session.clone()) + .load() + .await + .expect("failed to load dataset for load measurements") + }); + + let mut current_dataset = Arc::new(initial_dataset); + + let mut commit_latencies = Vec::with_capacity(num_iterations); + let mut load_latencies = Vec::with_capacity(num_iterations); + + println!("Running commit and load benchmarks..."); + println!("fragments,commit_ms,load_ms"); + + for i in 1..=num_iterations { + let num_fragments = i + 1; + + let (commit_time, new_dataset) = { + let dataset = current_dataset.clone(); + let session_clone = session.clone(); + runtime.block_on(async move { + let schema: Arc = Arc::new((&dataset.schema().clone()).into()); + let start_id = dataset.count_rows(None).await.unwrap() as usize; + let batch = create_batch(schema.clone(), start_id, rows_per_fragment); + + let write_params = WriteParams { + mode: 
WriteMode::Append, + session: Some(session_clone.clone()), + skip_auto_cleanup: true, + ..Default::default() + }; + + let transaction = InsertBuilder::new(dataset.clone()) + .with_params(&write_params) + .execute_uncommitted(vec![batch]) + .await + .expect("failed to write fragment"); + + let start = Instant::now(); + let new_ds = CommitBuilder::new(dataset) + .with_session(session_clone) + .with_skip_auto_cleanup(true) + .execute(transaction) + .await + .expect("failed to commit"); + (start.elapsed(), Arc::new(new_ds)) + }) + }; + + let load_time = runtime.block_on(async { + let start = Instant::now(); + load_dataset + .checkout_latest() + .await + .expect("failed to checkout latest"); + let elapsed = start.elapsed(); + + assert_eq!( + load_dataset.manifest().fragments.len(), + num_fragments, + "Expected {} fragments", + num_fragments + ); + elapsed + }); + + current_dataset = new_dataset; + + commit_latencies.push(commit_time); + load_latencies.push(load_time); + + println!( + "{},{:.2},{:.2}", + num_fragments, + commit_time.as_secs_f64() * 1000.0, + load_time.as_secs_f64() * 1000.0 + ); + } + + println!(); + println!("=== Summary Statistics ==="); + + let avg_commit: f64 = commit_latencies + .iter() + .map(|d| d.as_secs_f64()) + .sum::() + / commit_latencies.len() as f64; + let avg_load: f64 = + load_latencies.iter().map(|d| d.as_secs_f64()).sum::() / load_latencies.len() as f64; + + let min_commit = commit_latencies.iter().min().unwrap(); + let max_commit = commit_latencies.iter().max().unwrap(); + let min_load = load_latencies.iter().min().unwrap(); + let max_load = load_latencies.iter().max().unwrap(); + + println!( + "Commit latency: avg={:.2}ms, min={:.2}ms, max={:.2}ms", + avg_commit * 1000.0, + min_commit.as_secs_f64() * 1000.0, + max_commit.as_secs_f64() * 1000.0 + ); + println!( + "Load latency: avg={:.2}ms, min={:.2}ms, max={:.2}ms", + avg_load * 1000.0, + min_load.as_secs_f64() * 1000.0, + max_load.as_secs_f64() * 1000.0 + ); + + let 
first_10_avg_commit = commit_latencies + .iter() + .take(10) + .map(|d| d.as_secs_f64()) + .sum::() + / 10.0; + let last_10_avg_commit = commit_latencies + .iter() + .rev() + .take(10) + .map(|d| d.as_secs_f64()) + .sum::() + / 10.0; + let first_10_avg_load = load_latencies + .iter() + .take(10) + .map(|d| d.as_secs_f64()) + .sum::() + / 10.0; + let last_10_avg_load = load_latencies + .iter() + .rev() + .take(10) + .map(|d| d.as_secs_f64()) + .sum::() + / 10.0; + + println!(); + println!( + "First 10 iterations avg: commit={:.2}ms, load={:.2}ms", + first_10_avg_commit * 1000.0, + first_10_avg_load * 1000.0 + ); + println!( + "Last 10 iterations avg: commit={:.2}ms, load={:.2}ms", + last_10_avg_commit * 1000.0, + last_10_avg_load * 1000.0 + ); + println!( + "Degradation ratio: commit={:.2}x, load={:.2}x", + last_10_avg_commit / first_10_avg_commit, + last_10_avg_load / first_10_avg_load + ); + + let mut group = c.benchmark_group("manifest_commit"); + + group.bench_function("avg_commit_latency", |b| { + b.iter(|| std::time::Duration::from_secs_f64(avg_commit)) + }); + + group.bench_function("avg_load_latency", |b| { + b.iter(|| std::time::Duration::from_secs_f64(avg_load)) + }); + + group.finish(); + + if delete_dataset { + std::fs::remove_dir_all(&uri).ok(); + println!("Dataset deleted: {}", uri); + } else { + println!("Dataset preserved: {}", uri); + } +} + +criterion_group!(benches, bench_manifest_commit); +criterion_main!(benches); diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 8abb5975fdd..30e43c5bb32 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -2760,16 +2760,15 @@ pub(crate) struct NewTransactionResult<'a> { } pub(crate) fn load_new_transactions(dataset: &Dataset) -> NewTransactionResult<'_> { - // Re-use the same list call for getting the latest manifest and the metadata - // for all manifests in between. 
+ // Resolve every manifest with version > our current version (the latest plus + // the ones in between). On non-lexically-ordered stores this uses the version + // hint to avoid an O(n) listing. let io_parallelism = dataset.object_store.as_ref().io_parallelism(); - let latest_version = dataset.manifest.version; - let locations = dataset - .commit_handler - .list_manifest_locations(&dataset.base, dataset.object_store.as_ref(), true) - .try_take_while(move |location| { - futures::future::ready(Ok(location.version > latest_version)) - }); + let locations = dataset.commit_handler.list_manifest_locations_since( + &dataset.base, + dataset.object_store.as_ref(), + dataset.manifest.version, + ); // Will send the latest manifest via a channel. let (latest_tx, latest_rx) = tokio::sync::oneshot::channel(); diff --git a/rust/lance/src/dataset/tests/dataset_versioning.rs b/rust/lance/src/dataset/tests/dataset_versioning.rs index 088ce85b9a5..5ac01c498b2 100644 --- a/rust/lance/src/dataset/tests/dataset_versioning.rs +++ b/rust/lance/src/dataset/tests/dataset_versioning.rs @@ -34,6 +34,8 @@ fn assert_all_manifests_use_scheme(test_dir: &TempStdDir, scheme: ManifestNaming .read_dir() .unwrap() .map(|entry| entry.unwrap().file_name().into_string().unwrap()) + // Ignore the version hint file, which is not a manifest. + .filter(|name| !name.starts_with("latest_version_hint")) .collect::>(); assert!( entries_names diff --git a/rust/lance/src/dataset/write/commit.rs b/rust/lance/src/dataset/write/commit.rs index efab12d249c..45b78b48bcb 100644 --- a/rust/lance/src/dataset/write/commit.rs +++ b/rust/lance/src/dataset/write/commit.rs @@ -561,6 +561,7 @@ mod tests { // Should see 2 IOPs: // 1. Write the transaction files // 2. 
Write (conditional put) the manifest + // (the version hint is only written on non-lexically-ordered stores) assert_io_eq!(io_stats, write_iops, 2, "write txn + manifest, i = {}", i); } @@ -629,7 +630,7 @@ mod tests { // Assert io requests let io_stats = new_ds.object_store.as_ref().io_stats_incremental(); // This could be zero, if we decided to be optimistic. However, that - // would mean two wasted write requests (txn + manifest) if there was + // would mean wasted write requests (txn + manifest) if there was // a conflict. We choose to be pessimistic for more consistent performance. assert_io_eq!(io_stats, read_iops, 1); assert_io_eq!(io_stats, write_iops, 2); @@ -786,4 +787,82 @@ mod tests { ); assert_eq!(transaction.read_version, 1); } + + /// On non-lexically-ordered stores (e.g. S3 Express) a commit should use the + /// version hint (a few HEAD probes, O(k)) instead of a full O(n) listing. + #[tokio::test] + async fn test_commit_uses_version_hint_on_non_lexical_store() { + // Make `list` artificially slow per entry so a full listing would be + // obvious; HEAD/GET/PUT stay fast. 
+ let throttled = Arc::new(ThrottledStoreWrapper { + config: ThrottleConfig { + wait_list_per_entry: Duration::from_millis(50), + wait_get_per_call: Duration::from_millis(1), + wait_put_per_call: Duration::from_millis(1), + ..Default::default() + }, + }); + let session = Arc::new(Session::default()); + let write_params = WriteParams { + store_params: Some(ObjectStoreParams { + object_store_wrapper: Some(throttled), + list_is_lexically_ordered: Some(false), + ..Default::default() + }), + session: Some(session.clone()), + enable_v2_manifest_paths: true, + ..Default::default() + }; + + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "i", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_iter_values(0..10_i32))], + ) + .unwrap(); + let mut dataset = Arc::new( + InsertBuilder::new("memory://test_version_hint") + .with_params(&write_params) + .execute(vec![batch]) + .await + .unwrap(), + ); + + // Build up many versions so a full listing would be expensive. + for _ in 0..50 { + dataset = Arc::new( + CommitBuilder::new(dataset.clone()) + .execute(sample_transaction(dataset.manifest().version)) + .await + .unwrap(), + ); + } + assert_eq!(dataset.manifest().version, 51); + + dataset.object_store.as_ref().io_stats_incremental(); + + let start = std::time::Instant::now(); + let new_ds = CommitBuilder::new(dataset.clone()) + .execute(sample_transaction(dataset.manifest().version)) + .await + .unwrap(); + let elapsed = start.elapsed(); + + // A full listing of ~52 entries at 50ms each would take ~2.6s. 
+ assert!( + elapsed < Duration::from_secs(1), + "commit took {elapsed:?}; the version hint path was likely not used" + ); + + let io_stats = new_ds.object_store.as_ref().io_stats_incremental(); + assert!( + io_stats.read_iops < 10, + "read_iops = {}; a full listing was likely used", + io_stats.read_iops + ); + } } diff --git a/rust/lance/src/io/commit/external_manifest.rs b/rust/lance/src/io/commit/external_manifest.rs index d68ef08fad4..df2b84a4878 100644 --- a/rust/lance/src/io/commit/external_manifest.rs +++ b/rust/lance/src/io/commit/external_manifest.rs @@ -267,6 +267,15 @@ mod test { .to_string_lossy() .contains(".manifest#") }) + // The version hint file is expected to be present. + .filter(|entry| { + let entry = entry.as_ref().unwrap(); + !entry + .file_name() + .as_os_str() + .to_string_lossy() + .starts_with("latest_version_hint") + }) .collect::>(); assert!(unexpected_entries.is_empty(), "{:?}", unexpected_entries); } @@ -373,6 +382,15 @@ mod test { .to_string_lossy() .ends_with(".manifest") }) + // The version hint file is expected to be present. 
+ .filter(|entry| { + let entry = entry.as_ref().unwrap(); + !entry + .file_name() + .as_os_str() + .to_string_lossy() + .starts_with("latest_version_hint") + }) .collect::>(); assert!(unexpected_entries.is_empty(), "{:?}", unexpected_entries); } From 98e2b8a248789b187bffe4574a42deabe19e66e9 Mon Sep 17 00:00:00 2001 From: Lance Release Bot Date: Tue, 19 May 2026 20:18:07 +0000 Subject: [PATCH 03/23] chore: release beta version 7.0.0-beta.16 --- .bumpversion.toml | 2 +- Cargo.lock | 44 +++++++++++++++++++-------------------- Cargo.toml | 40 +++++++++++++++++------------------ java/lance-jni/Cargo.lock | 36 ++++++++++++++++---------------- java/lance-jni/Cargo.toml | 2 +- java/pom.xml | 2 +- python/Cargo.lock | 36 ++++++++++++++++---------------- python/Cargo.toml | 2 +- 8 files changed, 82 insertions(+), 82 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 43ddd048a98..98be7158b5b 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "7.0.0-beta.15" +current_version = "7.0.0-beta.16" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)(-(?P(beta|rc))\\.(?P\\d+))?" 
serialize = [ "{major}.{minor}.{patch}-{prerelease}.{prerelease_num}", diff --git a/Cargo.lock b/Cargo.lock index 64655ee60d8..1b4f587ba2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3093,7 +3093,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "rand 0.9.4", @@ -4321,7 +4321,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "all_asserts", "approx", @@ -4421,7 +4421,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4469,7 +4469,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrayref", "paste", @@ -4478,7 +4478,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4515,7 +4515,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4548,7 +4548,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4568,7 +4568,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", "arrow-array", @@ -4613,7 +4613,7 @@ dependencies = [ [[package]] name = "lance-examples" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "all_asserts", "arrow", @@ -4639,7 +4639,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", 
"arrow-array", @@ -4679,7 +4679,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "datafusion", "geo-traits", @@ -4693,7 +4693,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "approx", "arc-swap", @@ -4770,7 +4770,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-arith", @@ -4819,7 +4819,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "approx", "arrow-array", @@ -4840,7 +4840,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "async-trait", @@ -4852,7 +4852,7 @@ dependencies = [ [[package]] name = "lance-namespace-datafusion" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-schema", @@ -4868,7 +4868,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-ipc", @@ -4924,7 +4924,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4970,7 +4970,7 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "proc-macro2", "quote", @@ -4979,7 +4979,7 @@ dependencies = [ [[package]] name = "lance-testing" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-schema", @@ -4990,7 +4990,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "jieba-rs", "lindera", @@ -5001,7 +5001,7 @@ dependencies = [ [[package]] name = "lance-tools" -version = 
"7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "clap", "lance-core", diff --git a/Cargo.toml b/Cargo.toml index 08c5f4e024a..3a6c61acbe9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ resolver = "3" [workspace.package] -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" @@ -55,25 +55,25 @@ rust-version = "1.91.0" [workspace.dependencies] arc-swap = "1.7" libc = "0.2.176" -lance = { version = "=7.0.0-beta.15", path = "./rust/lance", default-features = false } -lance-arrow = { version = "=7.0.0-beta.15", path = "./rust/lance-arrow" } -lance-core = { version = "=7.0.0-beta.15", path = "./rust/lance-core" } -lance-datafusion = { version = "=7.0.0-beta.15", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=7.0.0-beta.15", path = "./rust/lance-datagen" } -lance-encoding = { version = "=7.0.0-beta.15", path = "./rust/lance-encoding" } -lance-file = { version = "=7.0.0-beta.15", path = "./rust/lance-file" } -lance-geo = { version = "=7.0.0-beta.15", path = "./rust/lance-geo" } -lance-index = { version = "=7.0.0-beta.15", path = "./rust/lance-index" } -lance-io = { version = "=7.0.0-beta.15", path = "./rust/lance-io", default-features = false } -lance-linalg = { version = "=7.0.0-beta.15", path = "./rust/lance-linalg" } -lance-namespace = { version = "=7.0.0-beta.15", path = "./rust/lance-namespace" } -lance-namespace-impls = { version = "=7.0.0-beta.15", path = "./rust/lance-namespace-impls" } +lance = { version = "=7.0.0-beta.16", path = "./rust/lance", default-features = false } +lance-arrow = { version = "=7.0.0-beta.16", path = "./rust/lance-arrow" } +lance-core = { version = "=7.0.0-beta.16", path = "./rust/lance-core" } +lance-datafusion = { version = "=7.0.0-beta.16", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=7.0.0-beta.16", path = "./rust/lance-datagen" } +lance-encoding = { version = "=7.0.0-beta.16", path = "./rust/lance-encoding" } 
+lance-file = { version = "=7.0.0-beta.16", path = "./rust/lance-file" } +lance-geo = { version = "=7.0.0-beta.16", path = "./rust/lance-geo" } +lance-index = { version = "=7.0.0-beta.16", path = "./rust/lance-index" } +lance-io = { version = "=7.0.0-beta.16", path = "./rust/lance-io", default-features = false } +lance-linalg = { version = "=7.0.0-beta.16", path = "./rust/lance-linalg" } +lance-namespace = { version = "=7.0.0-beta.16", path = "./rust/lance-namespace" } +lance-namespace-impls = { version = "=7.0.0-beta.16", path = "./rust/lance-namespace-impls" } lance-namespace-datafusion = { version = "=7.0.0-beta.9", path = "./rust/lance-namespace-datafusion" } lance-namespace-reqwest-client = "0.7.5" -lance-tokenizer = { version = "=7.0.0-beta.15", path = "./rust/lance-tokenizer" } -lance-table = { version = "=7.0.0-beta.15", path = "./rust/lance-table" } -lance-test-macros = { version = "=7.0.0-beta.15", path = "./rust/lance-test-macros" } -lance-testing = { version = "=7.0.0-beta.15", path = "./rust/lance-testing" } +lance-tokenizer = { version = "=7.0.0-beta.16", path = "./rust/lance-tokenizer" } +lance-table = { version = "=7.0.0-beta.16", path = "./rust/lance-table" } +lance-test-macros = { version = "=7.0.0-beta.16", path = "./rust/lance-test-macros" } +lance-testing = { version = "=7.0.0-beta.16", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "58.0.0", optional = false, features = ["prettyprint"] } @@ -99,7 +99,7 @@ half = { "version" = "2.1", default-features = false, features = [ "num-traits", "std", ] } -lance-bitpacking = { version = "=7.0.0-beta.15", path = "./rust/compression/bitpacking" } +lance-bitpacking = { version = "=7.0.0-beta.16", path = "./rust/compression/bitpacking" } bitpacking = "0.9" bitvec = "1" bytes = "1.11.1" @@ -139,7 +139,7 @@ deepsize = "0.2.0" dirs = "6.0.0" either = "1.0" fst = { version = "0.4.7", features = ["levenshtein"] } -fsst = { version = 
"=7.0.0-beta.15", path = "./rust/compression/fsst" } +fsst = { version = "=7.0.0-beta.16", path = "./rust/compression/fsst" } futures = "0.3" geoarrow-array = "0.8" geoarrow-schema = "0.8" diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index 61d7f1d7987..0e2059bc749 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -2509,7 +2509,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3616,7 +3616,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arc-swap", "arrow", @@ -3686,7 +3686,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -3706,7 +3706,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrayref", "paste", @@ -3715,7 +3715,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -3750,7 +3750,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -3782,7 +3782,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -3800,7 +3800,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", "arrow-array", @@ -3835,7 +3835,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", 
"arrow-array", @@ -3866,7 +3866,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "datafusion", "geo-traits", @@ -3880,7 +3880,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arc-swap", "arrow", @@ -3947,7 +3947,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-arith", @@ -3989,7 +3989,7 @@ dependencies = [ [[package]] name = "lance-jni" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4025,7 +4025,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4041,7 +4041,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "async-trait", @@ -4053,7 +4053,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-ipc", @@ -4097,7 +4097,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4134,7 +4134,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "rust-stemmers", "serde", diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 1cbee8d10e1..5c246583a3e 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lance-jni" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" edition = "2024" authors = ["Lance Devs "] rust-version = "1.91" diff --git a/java/pom.xml b/java/pom.xml index 9b44c467605..db32abad4d7 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -7,7 
+7,7 @@ org.lance lance-core Lance Core - 7.0.0-beta.15 + 7.0.0-beta.16 jar Lance Format Java API diff --git a/python/Cargo.lock b/python/Cargo.lock index fa439bf442a..5e13dd8ac51 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -2853,7 +2853,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3975,7 +3975,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arc-swap", "arrow", @@ -4046,7 +4046,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4066,7 +4066,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrayref", "paste", @@ -4075,7 +4075,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4110,7 +4110,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4142,7 +4142,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4160,7 +4160,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", "arrow-array", @@ -4195,7 +4195,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-arith", "arrow-array", @@ -4226,7 +4226,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = 
"7.0.0-beta.16" dependencies = [ "datafusion", "geo-traits", @@ -4240,7 +4240,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arc-swap", "arrow", @@ -4308,7 +4308,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-arith", @@ -4350,7 +4350,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow-array", "arrow-buffer", @@ -4366,7 +4366,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "async-trait", @@ -4378,7 +4378,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-ipc", @@ -4422,7 +4422,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", @@ -4461,7 +4461,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "jieba-rs", "lindera", @@ -5881,7 +5881,7 @@ dependencies = [ [[package]] name = "pylance" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" dependencies = [ "arrow", "arrow-array", diff --git a/python/Cargo.toml b/python/Cargo.toml index c9b7d918a7c..078e0ad6057 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "7.0.0-beta.15" +version = "7.0.0-beta.16" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" From 64d7b78d2d3a54ad7b86f718a538edd9edac52b7 Mon Sep 17 00:00:00 2001 From: Vova Kolmakov Date: Wed, 20 May 2026 06:31:59 +0700 Subject: [PATCH 04/23] fix: make HNSW graph build deterministic to stabilize test_ann_prefilter (#6818) MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem `test_ann_prefilter` is flaky and failed on CI (`linux-arm`, Rust) — e.g. on the unrelated PR #6757 — with the HNSW+SQ parametrization returning a near-miss neighbor (row 10 instead of 6). ## Root cause HNSW node-level assignment uses an **unseeded** thread RNG (`rand::rng()`) in both the offline (`HnswBuilder`) and online (`OnlineHnswBuilder`) builders, so every index build produces a different random graph. On a tiny 300-vector dataset, an approximate HNSW+SQ search over a different graph each run can return a near neighbor instead of the exact one. `main` was green by luck of the RNG, not correctness. This is **not** caused by #6757 (the `String`→`Uuid` index-id refactor): index cache keys and on-disk index paths are byte-identical before/after that change; the test only surfaced the pre-existing flakiness. ## Fix - Seed both level-assignment sites with a shared fixed constant (`HNSW_LEVEL_RNG_SEED`) via `SmallRng`, making graph construction reproducible. Recall is statistically unaffected (identical level distribution; only the draws are fixed). A constant — rather than a new `HnswBuildParams` field — keeps the change contained (no serde/proto/binding changes). - Harden `test_ann_prefilter` to assert the property it actually validates (prefilter honored: `filterable > 5`) instead of an exact nearest-neighbor id, per the repo guideline that vector-index tests assert recall, not exact matches. 
Co-authored-by: Vova Kolmakov Co-authored-by: Claude Opus 4.7 (1M context) --- rust/lance-index/src/vector/hnsw/builder.rs | 13 +++++++++++-- rust/lance-index/src/vector/hnsw/online.rs | 6 +++--- rust/lance/src/dataset/scanner.rs | 8 +++++++- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index f09ec27be32..f605e15714d 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -24,7 +24,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use tracing::instrument; use lance_core::{Error, Result}; -use rand::{Rng, rng}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; use serde::{Deserialize, Serialize}; use super::super::graph::beam_search; @@ -44,6 +44,15 @@ use crate::vector::{DIST_COL, Query, VECTOR_RESULT_SCHEMA}; pub const HNSW_METADATA_KEY: &str = "lance:hnsw"; +/// Fixed seed for HNSW node-level assignment. +/// +/// A constant seed makes graph construction reproducible (same data + params => +/// same graph), which keeps index builds deterministic and tests stable. Recall +/// is statistically unaffected — the level distribution is identical, only the +/// random draws become fixed. Shared by the offline ([`HNSWBuilder`]) and online +/// ([`super::online::OnlineHnswBuilder`]) builders so both produce comparable graphs. 
+pub(crate) const HNSW_LEVEL_RNG_SEED: u64 = 42; + /// Parameters of building HNSW index #[derive(Debug, Clone, Serialize, Deserialize, DeepSizeOf)] pub struct HnswBuildParams { @@ -475,7 +484,7 @@ impl HnswBuilder { if len > 0 { nodes.push(RwLock::new(GraphBuilderNode::new(0, max_level as usize))); } - let mut level_rng = rng(); + let mut level_rng = SmallRng::seed_from_u64(HNSW_LEVEL_RNG_SEED); for i in 1..len { nodes.push(RwLock::new(GraphBuilderNode::new( i as u32, diff --git a/rust/lance-index/src/vector/hnsw/online.rs b/rust/lance-index/src/vector/hnsw/online.rs index 5d996342350..8fbdcbcb1c5 100644 --- a/rust/lance-index/src/vector/hnsw/online.rs +++ b/rust/lance-index/src/vector/hnsw/online.rs @@ -36,9 +36,9 @@ use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering}; use arc_swap::ArcSwap; use crossbeam_queue::ArrayQueue; -use rand::{Rng, rng}; +use rand::{Rng, SeedableRng, rngs::SmallRng}; -use super::builder::{HNSW, HnswBuildParams, HnswQueryParams}; +use super::builder::{HNSW, HNSW_LEVEL_RNG_SEED, HnswBuildParams, HnswQueryParams}; use super::select_neighbors_heuristic; use crate::vector::graph::builder::GraphBuilderNode; use crate::vector::graph::{ @@ -166,7 +166,7 @@ impl OnlineHnswBuilder { let max_level = params.max_level; let level_count = (0..max_level).map(|_| AtomicUsize::new(0)).collect(); - let mut level_rng = rng(); + let mut level_rng = SmallRng::seed_from_u64(HNSW_LEVEL_RNG_SEED); let nodes: Vec<_> = (0..capacity) .map(|i| { let target_level = if i == 0 { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 3814b5fc017..6916c9ca258 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -6570,7 +6570,13 @@ mod test { let first_match = batches[0][ROW_ID].as_primitive::().values()[0]; - assert_eq!(6, first_match); + // HNSW+SQ is an approximate index; this test validates *prefiltering*, so + // every row failing `filterable > 5` (row ids 0..=5) must be excluded. 
+ // HNSW recall is covered by dedicated vector-index tests elsewhere. + assert!( + first_match > 5, + "prefilter not honored: returned row id {first_match} should satisfy `filterable > 5`" + ); } #[rstest] From 29a8f92eb089117336ec2040997e25f901733b87 Mon Sep 17 00:00:00 2001 From: Jerry He Date: Tue, 19 May 2026 16:35:05 -0700 Subject: [PATCH 05/23] fix: set _row_created_at_version to new version for MERGE INTO INSERT rows (#6774) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary * Fixes lance-format/lance#6735. * `transaction.rs` — `resolve_update_version_metadata`: in the per-row `created_at` mapping, check `row_id_to_source.contains_key(&rid)` before calling `resolve_created_at_version`. Rows not in the map are INSERT branch rows (no source in existing fragments); they now receive `new_version` as `created_at` instead of the previous fallback of `UNKNOWN_CREATED_AT_VERSION` (1). * `resolve_created_at_version` doc comment updated to clarify it is only called for UPDATE branch rows (source confirmed present). The unmapped-row-ID branch inside the function is unused when called from `resolve_update_version_metadata`; `UNKNOWN` (1) still applies for UPDATE rows whose source fragment has missing or bad `created_at_version_meta` (cache miss, decode failure, or out-of-range offset). * Two existing tests updated to assert the corrected behavior; one new test added. ## Background MERGE INTO commits through `Operation::Update` and produces both UPDATE branch rows (rewritten into `new_fragments` with a source row in the previous manifest, stable row ID present in `row_id_to_source`) and INSERT branch rows (new rows also in `new_fragments`, stable row ID assigned fresh, not present in any existing fragment). Before this change, `resolve_update_version_metadata` built the per-row `created_at_versions` vector by calling `resolve_created_at_version` for every row ID. 
For UPDATE branch rows that function correctly copies `created_at` from the source fragment. For INSERT branch rows the map lookup fails and the function returns `UNKNOWN_CREATED_AT_VERSION = 1`, producing a wrong historical version for every newly inserted row. CDF consumers cannot distinguish merge-inserted rows from updated rows via `_row_created_at_version`, and the value 1 is meaningless for rows that first appeared in a recent commit. The fix is a single guard at the call site: only call `resolve_created_at_version` for rows confirmed to have a source (UPDATE branch); for all other rows use `new_version` directly. ## Implementation notes * The guard uses `row_id_to_source.contains_key(&rid)`, which is an O(1) hash lookup on the same map already built for the UPDATE branch path — no additional data structures or iteration. * No lance-spark changes are needed. The Spark commit path (SparkPositionDeltaWrite) already attaches `RowIdMeta` to new fragment rows. This change activates the correct behavior automatically for all callers of `Operation::Update`, including lance-spark MERGE INTO. * lance-spark test update https://github.com/lance-format/lance-spark/pull/530 ## Test plan * `test_update_version_tracking_insert_branch_gets_new_version` (renamed from `test_update_version_tracking_unknown_row_id_defaults_to_1`): new fragment with one UPDATE branch row (ID 10, source `created_at = 5`) and one INSERT branch row (ID 999); asserts `created_at = [5, 5]` — UPDATE branch copies from source, INSERT branch gets `new_version` (5). * `test_update_version_tracking_merge_into_distinguishes_insert_and_update_branch` (new): new fragment interleaves UPDATE branch rows (IDs 10, 11, source `created_at = 3`) and INSERT branch rows (IDs 500, 501); asserts `created_at = [3, 5, 3, 5]` to verify per-row correctness across both branches in the same fragment.
* `test_update_version_tracking_no_row_id_meta_fallback`: assertion updated from `[1, 1, 1]` to `[5, 5, 5]` — a fragment with no `row_id_meta` gets fresh stable IDs assigned by `assign_row_ids`; those IDs have no source and are INSERT branch rows, so `created_at` equals `new_version`. * `test_update_version_tracking_source_fragment_no_created_at_defaults_to_1` (unchanged): confirms that UPDATE branch rows whose source fragment has no `created_at_version_meta` still fall back to `UNKNOWN` (1) — the remaining reachable path through `resolve_created_at_version`. Co-authored-by: Jing chen He --- rust/lance/src/dataset/transaction.rs | 108 +++++++++++++++++++++++--- 1 file changed, 97 insertions(+), 11 deletions(-) diff --git a/rust/lance/src/dataset/transaction.rs b/rust/lance/src/dataset/transaction.rs index b82498589f4..3f96b9964d5 100644 --- a/rust/lance/src/dataset/transaction.rs +++ b/rust/lance/src/dataset/transaction.rs @@ -53,15 +53,20 @@ use uuid::Uuid; /// Version 1 is the initial dataset version in the Lance format. const UNKNOWN_CREATED_AT_VERSION: u64 = 1; -/// Look up the `created_at` version for a single row ID. +/// Look up the `created_at` version for a single UPDATE-branch row ID. +/// +/// Callers must only call this for row IDs that are confirmed to be present in +/// `row_id_to_source` (i.e. UPDATE branch rows whose source exists in an existing +/// fragment). INSERT branch rows (no source) must use `new_version` directly and +/// must not call this function. /// /// Uses `row_id_to_source` to find the originating fragment and row offset, then /// performs a O(K) random-access lookup via [`RowDatasetVersionSequence::version_at`] /// on the pre-decoded sequence in `version_cache` (keyed by fragment ID). /// -/// Returns [`UNKNOWN_CREATED_AT_VERSION`] for any failure: unmapped row ID, missing -/// cache entry (fragment had no `created_at_version_meta` or it failed to decode), -/// or an out-of-range offset. 
+/// Returns [`UNKNOWN_CREATED_AT_VERSION`] if the source fragment has no +/// `created_at_version_meta` (missing or failed to decode) or the offset is +/// out of range. fn resolve_created_at_version( row_id: u64, row_id_to_source: &HashMap, @@ -173,7 +178,19 @@ fn resolve_update_version_metadata( let physical_rows = fragment.physical_rows.unwrap_or(0); let created_at_versions: Vec = row_ids .iter() - .map(|rid| resolve_created_at_version(rid, &row_id_to_source, &version_cache)) + .map(|rid| { + if row_id_to_source.contains_key(&rid) { + // UPDATE branch: stable row ID resolves to a source row in an + // existing fragment. Copy created_at from the original row so + // the row's first-appearance version is preserved across rewrites. + resolve_created_at_version(rid, &row_id_to_source, &version_cache) + } else { + // INSERT branch: stable row ID has no source in existing fragments + // (e.g. NOT MATCHED arm of MERGE INTO). The row first appears in + // this commit, so created_at equals the new commit version. + new_version + } + }) .collect(); debug_assert_eq!(created_at_versions.len(), physical_rows); @@ -5569,7 +5586,15 @@ mod tests { } #[test] - fn test_update_version_tracking_unknown_row_id_defaults_to_1() { + fn test_update_version_tracking_insert_branch_gets_new_version() { + // Simulates the INSERT branch (NOT MATCHED) of a MERGE INTO commit: + // the new fragment contains a mix of rewritten rows (UPDATE branch, row ID + // present in existing fragments) and freshly inserted rows (INSERT branch, + // row ID not present in any existing fragment). + // + // UPDATE branch row (10): created_at must be copied from the source fragment. + // INSERT branch row (999): created_at must equal new_version (the merge commit + // version), because the row first appeared in this commit. 
let existing_seq = RowIdSequence::from([10u64, 11].as_slice()); let existing_created = RowDatasetVersionSequence { runs: vec![RowDatasetVersionRun { @@ -5589,7 +5614,7 @@ mod tests { last_updated_at_version_meta: None, }; - // New fragment has row 10 (known) and row 999 (unknown — freshly inserted) + // New fragment has row 10 (UPDATE branch) and row 999 (INSERT branch) let new_seq = RowIdSequence::from([10u64, 999].as_slice()); let new_fragment = Fragment { id: 10, @@ -5601,6 +5626,7 @@ mod tests { last_updated_at_version_meta: None, }; + // update_txn uses read_version 4 → new_version is 5 let manifest = make_stable_row_id_manifest(vec![existing_fragment]); let (result, _) = update_txn(vec![new_fragment]) .build_manifest( @@ -5611,11 +5637,70 @@ mod tests { ) .unwrap(); - // Row 10: offset 0 in frag 1 → version 5. Row 999: unknown → default 1 - assert_eq!(created_at_versions(&result, 10), vec![5, 1]); + // Row 10 (UPDATE branch): created_at copied from source (version 5). + // Row 999 (INSERT branch): created_at == new_version (5). + assert_eq!(created_at_versions(&result, 10), vec![5, 5]); assert_eq!(last_updated_at_versions(&result, 10), vec![5, 5]); } + #[test] + fn test_update_version_tracking_merge_into_distinguishes_insert_and_update_branch() { + // Verifies the MERGE INTO correctness contract when UPDATE branch rows and INSERT + // branch rows have *different* source created_at values, so we can distinguish + // which row got which value. + // + // Existing fragment (id=1): row IDs [10, 11], created_at = version 3. + // New fragment (id=20): row IDs [10, 500, 11, 501]. + // - Rows 10 and 11: UPDATE branch (present in existing fragment) → created_at = 3. + // - Rows 500 and 501: INSERT branch (no source) → created_at = new_version = 5. 
+ let existing_seq = RowIdSequence::from([10u64, 11].as_slice()); + let existing_created = RowDatasetVersionSequence { + runs: vec![RowDatasetVersionRun { + span: U64Segment::Range(0..2), + version: 3, + }], + }; + let existing_fragment = Fragment { + id: 1, + files: vec![], + deletion_file: None, + row_id_meta: Some(RowIdMeta::Inline(write_row_ids(&existing_seq))), + physical_rows: Some(2), + created_at_version_meta: Some( + RowDatasetVersionMeta::from_sequence(&existing_created).unwrap(), + ), + last_updated_at_version_meta: None, + }; + + let new_seq = RowIdSequence::from([10u64, 500, 11, 501].as_slice()); + let new_fragment = Fragment { + id: 20, + files: vec![], + deletion_file: None, + row_id_meta: Some(RowIdMeta::Inline(write_row_ids(&new_seq))), + physical_rows: Some(4), + created_at_version_meta: None, + last_updated_at_version_meta: None, + }; + + // update_txn uses read_version 4 → new_version is 5 + let manifest = make_stable_row_id_manifest(vec![existing_fragment]); + let (result, _) = update_txn(vec![new_fragment]) + .build_manifest( + Some(&manifest), + vec![], + "txn", + &ManifestWriteConfig::default(), + ) + .unwrap(); + + // UPDATE branch rows (10, 11): created_at preserved from source (version 3). + // INSERT branch rows (500, 501): created_at == new_version (5). + assert_eq!(created_at_versions(&result, 20), vec![3, 5, 3, 5]); + // All rows in the new fragment get last_updated == new_version. + assert_eq!(last_updated_at_versions(&result, 20), vec![5, 5, 5, 5]); + } + #[test] fn test_update_version_tracking_source_fragment_no_created_at_defaults_to_1() { // Source fragment has row_id_meta but no created_at_version_meta. 
@@ -5691,8 +5776,9 @@ mod tests { .unwrap(); // Fragment starts with no row_id_meta → assign_row_ids gives it fresh IDs → - // those IDs aren't found in existing fragments → created_at defaults to 1 - assert_eq!(created_at_versions(&result, 10), vec![1, 1, 1]); + // those IDs have no source in existing fragments (INSERT branch) → + // created_at == new_version (5) for each row. + assert_eq!(created_at_versions(&result, 10), vec![5, 5, 5]); assert_eq!(last_updated_at_versions(&result, 10), vec![5, 5, 5]); } From 742e6a317c95840ca18adf47646ea79e0070c57b Mon Sep 17 00:00:00 2001 From: "nathan.ma" Date: Wed, 20 May 2026 13:23:33 +0800 Subject: [PATCH 06/23] fix: branch_identfier unstable for legacy branches (#6390) ## Problem Legacy branches, i.e. branches whose `BranchContents` were written without a persisted `branch_identifier`, currently deserialize through `BranchIdentifier::none()`. That fallback generates a fresh random UUID on each read, so the same unchanged branch can surface a different `branch_identifier` across repeated loads. This makes branch identity unstable in both Python and Java for legacy datasets. On the Python side, `branches.list()` / `branches_ordered()` expose `branch_identifier` directly, so callers that diff, cache, or snapshot branch metadata can observe false changes even when the branch itself has not changed. On the Java side, the same legacy branch can also appear with a different identifier across refreshes, which makes equality-style comparisons unstable as well. 
## Summary - stabilize fallback branch identifiers for legacy branch metadata by replacing the missing-identifier sentinel with a deterministic synthetic UUID during branch metadata reads - keep the fallback logic localized to Rust branch metadata loading so Python and Java continue returning stable `branch_identifier` values without API shape changes - add a lightweight Rust regression test that exercises `BranchContents::from_path` on in-memory branch metadata and verifies stable repeated reads plus distinct identifiers for different branch names --- rust/lance/src/dataset/refs.rs | 127 ++++++++++++++++++++++++++++++--- 1 file changed, 116 insertions(+), 11 deletions(-) diff --git a/rust/lance/src/dataset/refs.rs b/rust/lance/src/dataset/refs.rs index 12f91a3950f..98b4f0cbc0a 100644 --- a/rust/lance/src/dataset/refs.rs +++ b/rust/lance/src/dataset/refs.rs @@ -398,6 +398,7 @@ impl Branches<'_> { let contents = BranchContents::from_path( &branch_contents_path(branch_path, &name), self.object_store(), + &name, ) .await?; Ok((name, contents)) @@ -425,7 +426,8 @@ impl Branches<'_> { }); } - let branch_contents = BranchContents::from_path(&branch_file, self.object_store()).await?; + let branch_contents = + BranchContents::from_path(&branch_file, self.object_store(), branch).await?; Ok(branch_contents) } @@ -481,7 +483,7 @@ impl Branches<'_> { let parent_branch_id = if let Some(ref parent_branch) = source_branch { let parent_file = branch_contents_path(&root_location.path, parent_branch); if self.object_store().exists(&parent_file).await? { - BranchContents::from_path(&parent_file, self.object_store()) + BranchContents::from_path(&parent_file, self.object_store(), parent_branch) .await?
.identifier } else { @@ -531,7 +533,7 @@ impl Branches<'_> { } let mut branch_contents = - BranchContents::from_path(&branch_file, self.object_store()).await?; + BranchContents::from_path(&branch_file, self.object_store(), branch).await?; branch_contents.metadata = metadata; self.object_store() @@ -747,7 +749,7 @@ pub struct TagContents { #[serde(rename_all = "camelCase")] pub struct BranchContents { pub parent_branch: Option, - #[serde(default = "BranchIdentifier::none")] + #[serde(default = "BranchIdentifier::missing_identifier_sentinel")] pub identifier: BranchIdentifier, pub parent_version: u64, pub create_at: u64, // unix timestamp @@ -771,14 +773,60 @@ impl BranchIdentifier { Self { version_mapping } } - /// Creates a branch identifier for legacy branches without explicit lineage. - /// Legacy branches have parent_version=0 and are skipped during cleanup. - pub fn none() -> Self { + /// Creates a sentinel identifier for legacy branch metadata that lacks an explicit + /// identifier. + /// + /// `BranchContents::from_path` replaces this value with a deterministic synthetic + /// identifier. Keeping this sentinel stable lets us distinguish missing identifiers from + /// persisted identifiers without changing this field to `Option`. 
+ pub fn missing_identifier_sentinel() -> Self { Self { - version_mapping: vec![(0, Uuid::new_v4().simple().to_string())], + version_mapping: vec![(0, Uuid::nil().simple().to_string())], } } + fn synthetic_identifier( + branch_name: &str, + parent_branch: Option<&str>, + parent_version: u64, + create_at: u64, + ) -> Self { + let identifier_input = format!( + "branch_name={branch_name}\nparent_branch={}\nparent_version={parent_version}\ncreate_at={create_at}", + parent_branch.unwrap_or("") + ); + Self { + version_mapping: vec![( + 0, + Uuid::from_bytes(Self::synthetic_identifier_bytes( + identifier_input.as_bytes(), + )) + .simple() + .to_string(), + )], + } + } + + fn synthetic_identifier_bytes(input: &[u8]) -> [u8; 16] { + // Use fixed, local hashing so legacy fallback identifiers stay deterministic without + // enabling extra UUID generation features. + const FNV_OFFSET: u64 = 0xcbf29ce484222325; + const FNV_PRIME: u64 = 0x100000001b3; + + fn hash_with_seed(input: &[u8], seed: u64) -> u64 { + input.iter().fold(seed, |hash, byte| { + (hash ^ u64::from(*byte)).wrapping_mul(FNV_PRIME) + }) + } + + let first = hash_with_seed(input, FNV_OFFSET); + let second = hash_with_seed(input, FNV_OFFSET ^ 0x9e3779b97f4a7c15); + let mut bytes = [0; 16]; + bytes[..8].copy_from_slice(&first.to_be_bytes()); + bytes[8..].copy_from_slice(&second.to_be_bytes()); + bytes + } + pub fn main() -> Self { Self { version_mapping: vec![], @@ -896,8 +944,24 @@ impl TagContents { } impl BranchContents { - pub async fn from_path(path: &Path, object_store: &ObjectStore) -> Result { - from_path(path, object_store).await + pub async fn from_path( + path: &Path, + object_store: &ObjectStore, + branch_name: &str, + ) -> Result { + let mut contents: Self = from_path(path, object_store).await?; + if contents.identifier == BranchIdentifier::missing_identifier_sentinel() { + // Legacy branch files do not store an identifier. 
Derive a deterministic fallback + // from stable branch metadata so repeated reads expose the same public + // branch_identifier. + contents.identifier = BranchIdentifier::synthetic_identifier( + branch_name, + contents.parent_branch.as_deref(), + contents.parent_version, + contents.create_at, + ); + } + Ok(contents) } } @@ -1160,7 +1224,7 @@ mod tests { async fn test_branch_contents_serialization() { let branch_contents = BranchContents { parent_branch: Some("main".to_string()), - identifier: BranchIdentifier::none(), + identifier: BranchIdentifier::missing_identifier_sentinel(), parent_version: 42, create_at: 1234567890, manifest_size: 1024, @@ -1189,6 +1253,47 @@ mod tests { assert!(legacy_deserialized.metadata.is_empty()); } + #[tokio::test] + async fn test_branch_synthetic_uuid_is_stable() { + let legacy_json = r#"{"parentBranch":"main","parentVersion":42,"createAt":1234567890,"manifestSize":1024}"#; + let store = ObjectStore::memory(); + let base_path = Path::from("dataset"); + let first_path = branch_contents_path(&base_path, "legacy_branch"); + store + .put(&first_path, legacy_json.as_bytes()) + .await + .unwrap(); + let second_path = branch_contents_path(&base_path, "legacy_branch_other"); + store + .put(&second_path, legacy_json.as_bytes()) + .await + .unwrap(); + + let first = BranchContents::from_path(&first_path, &store, "legacy_branch") + .await + .unwrap(); + let second = BranchContents::from_path(&first_path, &store, "legacy_branch") + .await + .unwrap(); + assert_eq!(first.identifier, second.identifier); + assert_ne!( + first.identifier, + BranchIdentifier::missing_identifier_sentinel() + ); + assert_eq!(first.identifier.version_mapping[0].1.len(), 32); + assert!( + first.identifier.version_mapping[0] + .1 + .chars() + .all(|ch| ch.is_ascii_hexdigit() && !ch.is_ascii_uppercase()) + ); + + let other = BranchContents::from_path(&second_path, &store, "legacy_branch_other") + .await + .unwrap(); + assert_ne!(first.identifier, other.identifier); + } + 
#[tokio::test] async fn test_tag_contents_serialization() { let tag_contents = TagContents { From 79b363af5e869d7a3a2913ac3c8d5e61d3af6d9d Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Wed, 20 May 2026 18:50:05 +0800 Subject: [PATCH 07/23] docs: document PR publishing requirements (#6870) Document the PR publishing requirements in `AGENTS.md` so agent-created PRs match the checks enforced by CI before they are opened or updated. This records two required gates: - PR titles must follow Conventional Commits because `.github/workflows/pr-title.yml` validates the title and body with commitlint. - PRs must run lint checks for every touched language surface before creation or update. Rust changes require `cargo fmt --all` and `cargo clippy --all --tests --benches -- -D warnings`; Python changes require the `python/AGENTS.md` environment workflow and `uv run make lint` from `python/`. If a required lint check cannot be run, the blocker must be stated explicitly in the PR summary. --- AGENTS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index df823a8ac5e..2003d6dba10 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,6 +130,11 @@ AWS_DEFAULT_REGION=us-east-1 pytest --run-integration python/tests/test_s3_ddb.p - Indent content under MkDocs admonition directives (`!!! note`, etc.) with 4 spaces. - Proofread comments and docs for typos before committing. +## Pull Requests + +- PR titles must follow the Conventional Commits specification because `.github/workflows/pr-title.yml` validates the PR title and body with commitlint. Use prefixes like `feat:`, `fix:`, `docs:`, `perf:`, `ci:`, `test:`, `build:`, `style:`, or `chore:`; add a scope when useful. +- Before creating or updating a PR, run the lint checks for every touched language surface, even when they are expensive. For Rust changes, run `cargo fmt --all` and `cargo clippy --all --tests --benches -- -D warnings`. 
For Python changes, follow the environment workflow in `python/AGENTS.md` and run `uv run make lint` from `python/`. If a required lint check cannot be run, state the blocker explicitly in the PR summary. + ## Review Guidelines Contributor and maintainer attention is the most valuable resource. Less is more. From 16070102b20bca82b7a8c257d1e64899d7063a90 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 20 May 2026 09:08:18 -0700 Subject: [PATCH 08/23] feat(index): serializable cache for the BTree scalar index (#6793) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Makes BTree scalar index cache entries serializable, so a persistent cache backend can store and reload them without re-reading from storage. Previously the whole BTree index was cached as `Arc` under an `UnsizedCacheKey`, which can never carry a codec, and each `FlatIndex` page was cached in-memory only. Changes: - `CacheCodecImpl for FlatIndex` (one BTree page) and `BTreePageKey::codec()`. - Top-level scalar index caching becomes a plugin implementation detail via the existing `ScalarIndexPlugin::get_from_cache`/`put_in_cache` hooks. The default impl preserves today's in-memory unsized caching (backwards compatible); the BTree plugin overrides it with a sized, codec-backed `BTreeIndexState` (the lookup `RecordBatch` + `batch_size` + `ranges_to_files`, from which `try_from_serialized` rebuilds the index with no IO). - Caching moves into `scalar::open_scalar_index` (get → miss → load → put); the dataset-level `ScalarIndexCacheKey` logic is removed from `Dataset::open_scalar_index`. This keeps index-type-specific knowledge in `lance-index` rather than leaking a state trait + dispatch into `lance/src/index.rs`. Adds an integration test asserting that after prewarming with a serializing cache backend, an indexed-filter query does 0 read IOPS. Bitmap index will follow the same pattern in a separate PR. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- rust/lance-index/src/scalar/btree.rs | 516 +++++++++++++++++- rust/lance-index/src/scalar/btree/flat.rs | 69 +++ rust/lance-index/src/scalar/registry.rs | 62 ++- rust/lance/src/dataset/tests/dataset_index.rs | 103 ++++ rust/lance/src/index.rs | 59 +- rust/lance/src/index/scalar.rs | 18 +- 6 files changed, 767 insertions(+), 60 deletions(-) diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index f2be3241b85..eba6f1c6205 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -44,9 +44,10 @@ use futures::{ future::BoxFuture, stream::{self}, }; +use lance_arrow::ipc::{read_ipc_stream_single_at, write_ipc_stream}; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheKey, LanceCache, WeakLanceCache}, + cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ mask::NullableRowAddrSet, @@ -998,6 +999,173 @@ impl CacheKey for BTreePageKey { fn type_name() -> &'static str { "BTreePage" } + + fn codec() -> Option { + // Pages are cached as `FlatIndex` values (see `ValueType` above). + Some(CacheCodec::from_impl::()) + } +} + +/// The serializable state of a [`BTreeIndex`]. +/// +/// A `BTreeIndex` holds non-serializable infrastructure (an `IndexStore`, a +/// cache handle, a fragment-reuse index). `BTreeIndexState` captures just the +/// data needed to rebuild it: the `page_lookup.lance` batch (from which +/// `BTreeIndex::try_from_serialized` reconstructs the in-memory lookup with +/// no IO) plus the page batch size and range-partition map. 
+#[derive(Debug, Clone)] +pub struct BTreeIndexState { + lookup_batch: RecordBatch, + batch_size: u64, + ranges_to_files: Option>>, +} + +impl DeepSizeOf for BTreeIndexState { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + // `ranges_to_files` is tiny and `RangeInclusiveMap` is not `DeepSizeOf`; + // the lookup batch dominates, matching how `BTreeIndex` accounts for itself. + self.lookup_batch.get_array_memory_size() + } +} + +impl BTreeIndexState { + fn reconstruct( + &self, + store: Arc, + index_cache: &LanceCache, + frag_reuse_index: Option>, + ) -> Result> { + let index = BTreeIndex::try_from_serialized( + self.lookup_batch.clone(), + store, + index_cache, + self.batch_size, + self.ranges_to_files.clone(), + frag_reuse_index, + )?; + Ok(Arc::new(index) as Arc) + } +} + +impl CacheCodecImpl for BTreeIndexState { + /// Wire format (no stability guarantees yet — the cache is rebuilt from + /// source on any version mismatch): + /// ```text + /// u64 batch_size (LE) + /// u8 has_ranges (0 = None, 1 = Some) + /// if has_ranges: + /// u32 entry_count (LE) + /// per entry: u32 start | u32 end | u32 offset | u32 path_len | path bytes + /// lookup batch (Arrow IPC stream) + /// ``` + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + writer.write_all(&self.batch_size.to_le_bytes())?; + match &self.ranges_to_files { + None => writer.write_all(&[0u8])?, + Some(ranges) => { + writer.write_all(&[1u8])?; + let count = u32::try_from(ranges.len()).map_err(|_| { + Error::io("BTreeIndexState: ranges_to_files exceeds u32::MAX entries") + })?; + writer.write_all(&count.to_le_bytes())?; + for (range, (path, page_offset)) in ranges.iter() { + writer.write_all(&range.start().to_le_bytes())?; + writer.write_all(&range.end().to_le_bytes())?; + writer.write_all(&page_offset.to_le_bytes())?; + let path_len = u32::try_from(path.len()).map_err(|_| { + Error::io("BTreeIndexState: ranges_to_files path exceeds u32::MAX bytes") + })?; + 
writer.write_all(&path_len.to_le_bytes())?; + writer.write_all(path.as_bytes())?; + } + } + } + write_ipc_stream(&self.lookup_batch, writer)?; + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + let mut offset = 0; + let batch_size = read_u64_le(data, &mut offset)?; + let has_ranges = read_u8(data, &mut offset)?; + let ranges_to_files = match has_ranges { + 0 => None, + 1 => { + let count = read_u32_le(data, &mut offset)? as usize; + let mut entries = Vec::with_capacity(count); + for _ in 0..count { + let start = read_u32_le(data, &mut offset)?; + let end = read_u32_le(data, &mut offset)?; + let page_offset = read_u32_le(data, &mut offset)?; + let path_len = read_u32_le(data, &mut offset)? as usize; + let path = read_bytes(data, &mut offset, path_len)?; + let path = std::str::from_utf8(&path) + .map_err(|e| Error::io(format!("BTreeIndexState path: {e}")))? + .to_string(); + entries.push((start..=end, (path, page_offset))); + } + Some(Arc::new(entries.into_iter().collect())) + } + other => { + return Err(Error::io(format!( + "BTreeIndexState: invalid has_ranges tag {other}" + ))); + } + }; + let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + Ok(Self { + lookup_batch, + batch_size, + ranges_to_files, + }) + } +} + +fn read_bytes(data: &bytes::Bytes, offset: &mut usize, len: usize) -> Result { + if data.len() < *offset + len { + return Err(Error::io(format!( + "BTreeIndexState: short read of {len} bytes at offset {offset} (have {})", + data.len() + ))); + } + let slice = data.slice(*offset..*offset + len); + *offset += len; + Ok(slice) +} + +fn read_u8(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = read_bytes(data, offset, 1)?; + Ok(bytes[0]) +} + +fn read_u32_le(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = read_bytes(data, offset, 4)?; + Ok(u32::from_le_bytes(bytes.as_ref().try_into().unwrap())) +} + +fn read_u64_le(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = 
read_bytes(data, offset, 8)?; + Ok(u64::from_le_bytes(bytes.as_ref().try_into().unwrap())) +} + +/// Cache key for a [`BTreeIndexState`]. The cache it is used with is already +/// namespaced per-index, so the key string is a constant. +struct BTreeIndexStateKey; + +impl CacheKey for BTreeIndexStateKey { + type ValueType = BTreeIndexState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "BTreeIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } } /// Note: this is very similar to the IVF index except we store the IVF part in a btree @@ -1040,13 +1208,26 @@ pub struct BTreeIndex { /// - The system now knows to read page `42` from the file `part_2_page_file.lance`. ranges_to_files: Option>>, frag_reuse_index: Option>, + + /// The raw lookup batch this index was built from (the contents of + /// `page_lookup.lance`). Retained so the index can be serialized into a + /// cache as a [`BTreeIndexState`] without re-reading it from storage. + /// + /// TODO: this duplicates the min/max values already held in `page_lookup`. + /// A follow-up could rewrite `BTreeLookup` to query this batch directly + /// (binary search on the sorted `min` column + linear scan, type-dispatched + /// per column type), eliminating the duplication and making this batch the + /// single source of truth. + lookup_batch: RecordBatch, } impl DeepSizeOf for BTreeIndex { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { // We don't include the index cache, or anything stored in it. For example: // sub_index and fri. 
- self.page_lookup.deep_size_of_children(context) + self.store.deep_size_of_children(context) + self.page_lookup.deep_size_of_children(context) + + self.store.deep_size_of_children(context) + + self.lookup_batch.get_array_memory_size() } } @@ -1060,6 +1241,7 @@ impl BTreeIndex { batch_size: u64, ranges_to_files: Option>>, frag_reuse_index: Option>, + lookup_batch: RecordBatch, ) -> Self { Self { page_lookup, @@ -1069,6 +1251,7 @@ impl BTreeIndex { batch_size, ranges_to_files, frag_reuse_index, + lookup_batch, } } @@ -1158,6 +1341,7 @@ impl BTreeIndex { batch_size, ranges_to_files, frag_reuse_index, + data, )); } @@ -1199,18 +1383,19 @@ impl BTreeIndex { let last_max = ScalarValue::try_from_array(&maxs, data.num_rows() - 1)?; map.entry(OrderableScalarValue(last_max)).or_default(); - let data_type = mins.data_type(); + let data_type = mins.data_type().clone(); let page_lookup = Arc::new(BTreeLookup::new(map, null_pages, all_null_pages)); Ok(Self::new( page_lookup, store, - data_type.clone(), + data_type, WeakLanceCache::from(index_cache), batch_size, ranges_to_files, frag_reuse_index, + data, )) } @@ -2753,6 +2938,37 @@ impl ScalarIndexPlugin for BTreeIndexPlugin { ) -> Result> { Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? 
as Arc) } + + async fn get_from_cache( + &self, + index_store: Arc, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + let Some(state) = cache.get_with_key(&BTreeIndexStateKey).await else { + return Ok(None); + }; + Ok(Some(state.reconstruct( + index_store, + cache, + frag_reuse_index, + )?)) + } + + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + let btree = index.as_any().downcast_ref::().ok_or_else(|| { + Error::internal("BTreeIndexPlugin::put_in_cache called with a non-BTree index") + })?; + let state = BTreeIndexState { + lookup_batch: btree.lookup_batch.clone(), + batch_size: btree.batch_size, + ranges_to_files: btree.ranges_to_files.clone(), + }; + cache + .insert_with_key(&BTreeIndexStateKey, Arc::new(state)) + .await; + Ok(()) + } } #[cfg(test)] @@ -2791,9 +3007,13 @@ mod tests { }; use super::{ - DEFAULT_BTREE_BATCH_SIZE, OrderableScalarValue, part_lookup_file_path, - part_page_data_file_path, train_btree_index, + BTreeIndexPlugin, BTreeIndexState, BTreePageKey, DEFAULT_BTREE_BATCH_SIZE, + OrderableScalarValue, part_lookup_file_path, part_page_data_file_path, train_btree_index, }; + use crate::scalar::registry::ScalarIndexPlugin; + use arrow_array::RecordBatch; + use lance_core::cache::{CacheCodecImpl, CacheKey}; + use rangemap::RangeInclusiveMap; lance_testing::define_stage_event_progress!( RecordingProgress, @@ -4769,4 +4989,288 @@ mod tests { _ => panic!("BTree search should return Exact"), } } + + fn sample_lookup_batch() -> RecordBatch { + record_batch!( + ("min", Int32, [Some(0), Some(10), Some(20)]), + ("max", Int32, [Some(9), Some(19), Some(29)]), + ("null_count", UInt32, [0, 2, 0]), + ("page_idx", UInt32, [0, 1, 2]) + ) + .unwrap() + } + + fn assert_state_roundtrips(state: &BTreeIndexState) { + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + assert_eq!(restored.lookup_batch, state.lookup_batch); + 
assert_eq!(restored.batch_size, state.batch_size); + assert_eq!(restored.ranges_to_files, state.ranges_to_files); + } + + #[test] + fn test_btree_page_key_codec() { + // FlatIndex pages can be serialized by a persistent cache backend. + assert!(BTreePageKey::codec().is_some()); + } + + #[test] + fn test_btree_index_state_roundtrip() { + // Not range-partitioned. + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: sample_lookup_batch(), + batch_size: DEFAULT_BTREE_BATCH_SIZE, + ranges_to_files: None, + }); + + // Range-partitioned across multiple files. + let ranges: RangeInclusiveMap = [ + (0..=99, ("part_0_page_file.lance".to_string(), 0)), + (100..=199, ("part_1_page_file.lance".to_string(), 100)), + ] + .into_iter() + .collect(); + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: sample_lookup_batch(), + batch_size: 8192, + ranges_to_files: Some(Arc::new(ranges)), + }); + + // Empty index. + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: RecordBatch::new_empty(sample_lookup_batch().schema()), + batch_size: DEFAULT_BTREE_BATCH_SIZE, + ranges_to_files: None, + }); + } + + #[tokio::test] + async fn test_btree_index_state_reconstruct_and_plugin_cache() { + let tmpdir = TempObjDir::default(); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let stream = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(1000), BatchCount::from(5)); + train_btree_index(stream, test_store.as_ref(), 1000, None, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Round-trip the state through the codec and reconstruct an index from it. 
+ let state = BTreeIndexState { + lookup_batch: index.lookup_batch.clone(), + batch_size: index.batch_size, + ranges_to_files: index.ranges_to_files.clone(), + }; + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + let reconstructed = restored + .reconstruct(test_store.clone(), &LanceCache::no_cache(), None) + .unwrap(); + assert_eq!( + reconstructed + .as_any() + .downcast_ref::() + .unwrap() + .page_lookup, + index.page_lookup + ); + + // The plugin's put/get hooks round-trip through a real cache + the codec. + let cache = LanceCache::with_capacity(64 * 1024 * 1024); + let plugin = BTreeIndexPlugin; + plugin.put_in_cache(&cache, index.clone()).await.unwrap(); + let from_cache = plugin + .get_from_cache(test_store.clone(), None, &cache) + .await + .unwrap() + .expect("index should be served from the cache"); + + // Searches against the cached index match the original. + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int32(Some(100))), + std::ops::Bound::Excluded(ScalarValue::Int32(Some(200))), + ); + let expected = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + let actual = from_cache + .search(&query, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(expected, actual); + } + + #[test] + fn test_btree_index_state_rejects_invalid_has_ranges_tag() { + // u64 batch_size (any) then a bad has_ranges tag. 
+ let mut buf = Vec::new(); + buf.extend_from_slice(&1000u64.to_le_bytes()); + buf.push(7u8); + let err = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("has_ranges") && msg.contains("7"), + "expected error to mention the bad has_ranges tag, got: {msg}" + ); + } + + #[tokio::test] + async fn test_btree_index_state_reconstruct_applies_frag_reuse_index() { + use crate::frag_reuse::{FragReuseIndex, FragReuseIndexDetails}; + use std::collections::HashMap; + use uuid::Uuid; + + let tmpdir = TempObjDir::default(); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // value == _rowid for all rows in [0, 1000). + let stream = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(1000), BatchCount::from(1)); + train_btree_index(stream, test_store.as_ref(), 1000, None, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + let state = BTreeIndexState { + lookup_batch: index.lookup_batch.clone(), + batch_size: index.batch_size, + ranges_to_files: index.ranges_to_files.clone(), + }; + + // Remap row 0 -> row 5000 (outside the original [0, 1000) range so no collision). + // Querying for value == 0 should now return row 5000, confirming reconstruct threaded + // the FragReuseIndex through to the rebuilt BTreeIndex. 
+ let frag_reuse_index = Arc::new(FragReuseIndex::new( + Uuid::new_v4(), + vec![HashMap::from([(0u64, Some(5000u64))])], + FragReuseIndexDetails { versions: vec![] }, + )); + let reconstructed = state + .reconstruct( + test_store.clone(), + &LanceCache::no_cache(), + Some(frag_reuse_index), + ) + .unwrap(); + + let result = reconstructed + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(0))), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let row_ids: Vec = match &result { + SearchResult::Exact(set) => set + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(), + other => panic!("expected Exact, got {other:?}"), + }; + assert_eq!( + row_ids, + vec![5000], + "frag_reuse_index remap was not applied" + ); + } + + #[tokio::test] + async fn test_btree_index_state_range_partitioned_plugin_cache_roundtrip() { + // Build a range-partitioned BTree (two range partitions merged into one index) and + // round-trip it through the plugin's cache hooks. This exercises the + // `ranges_to_files = Some` path end-to-end through serialize/deserialize/reconstruct. + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let half = DEFAULT_BTREE_BATCH_SIZE; + let total = (2 * half) as i32; + + // Partition 0: values/rowids [0, half). + let part0 = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(half), BatchCount::from(1)); + train_btree_index(part0, store.as_ref(), half, None, Some(0u32)) + .await + .unwrap(); + + // Partition 1: values/rowids [half, 2*half). 
+ let values: Vec = (half as i32..total).collect(); + let row_ids: Vec = (half..total as u64).collect(); + let part1 = gen_batch() + .col("value", array::cycle::(values)) + .col("_rowid", array::cycle::(row_ids)) + .into_df_stream(RowCount::from(half), BatchCount::from(1)); + train_btree_index(part1, store.as_ref(), half, None, Some(1u32)) + .await + .unwrap(); + + super::merge_metadata_files( + store.as_ref(), + &[ + part_page_data_file_path(0 << 32), + part_page_data_file_path(1 << 32), + ], + &[ + part_lookup_file_path(0 << 32), + part_lookup_file_path(1 << 32), + ], + Some(1usize), + noop_progress(), + ) + .await + .unwrap(); + + let index = BTreeIndex::load(store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + assert!( + index.ranges_to_files.is_some(), + "test setup should produce a range-partitioned index", + ); + + let cache = LanceCache::with_capacity(64 * 1024 * 1024); + let plugin = BTreeIndexPlugin; + plugin.put_in_cache(&cache, index.clone()).await.unwrap(); + let from_cache = plugin + .get_from_cache(store.clone(), None, &cache) + .await + .unwrap() + .expect("index should be served from the cache"); + + // Search a value from each range partition and confirm both paths agree. 
+ for value in [0i32, total - 1] { + let query = SargableQuery::Equals(ScalarValue::Int32(Some(value))); + let expected = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + let actual = from_cache + .search(&query, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(expected, actual, "value {value}"); + } + } } diff --git a/rust/lance-index/src/scalar/btree/flat.rs b/rust/lance-index/src/scalar/btree/flat.rs index 113a850315b..10f0b1ad339 100644 --- a/rust/lance-index/src/scalar/btree/flat.rs +++ b/rust/lance-index/src/scalar/btree/flat.rs @@ -14,7 +14,9 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_physical_expr::create_physical_expr; use deepsize::DeepSizeOf; use lance_arrow::RecordBatchExt; +use lance_arrow::ipc::{read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream}; use lance_core::Result; +use lance_core::cache::CacheCodecImpl; use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; use roaring::RoaringBitmap; @@ -233,6 +235,45 @@ impl FlatIndex { } } +impl CacheCodecImpl for FlatIndex { + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + // Format: + // [len-prefixed all_addrs_map][len-prefixed null_addrs_map][batch IPC stream] + writer.write_all(&(self.all_addrs_map.serialized_size() as u64).to_le_bytes())?; + self.all_addrs_map.serialize_into(&mut *writer)?; + + writer.write_all(&(self.null_addrs_map.serialized_size() as u64).to_le_bytes())?; + self.null_addrs_map.serialize_into(&mut *writer)?; + + write_ipc_stream(self.data.as_ref(), writer)?; + + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result + where + Self: Sized, + { + let mut offset = 0; + let all_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let all_addrs_map = RowAddrTreeMap::deserialize_from(all_addrs_bytes.as_ref())?; + + let null_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let null_addrs_map = 
RowAddrTreeMap::deserialize_from(null_addrs_bytes.as_ref())?; + + let batch = read_ipc_stream_single_at(data, &mut offset)?; + + let df_schema = DFSchema::try_from(batch.schema())?; + + Ok(Self { + data: Arc::new(batch), + all_addrs_map, + null_addrs_map, + df_schema, + }) + } +} + #[cfg(test)] mod tests { use crate::{ @@ -266,6 +307,34 @@ mod tests { assert_eq!(actual, expected); } + fn assert_roundtrips(index: &FlatIndex) { + let mut buf = Vec::new(); + index.serialize(&mut buf).unwrap(); + let restored = FlatIndex::deserialize(&bytes::Bytes::from(buf)).unwrap(); + + assert_eq!(restored.data, index.data); + assert_eq!(restored.all_addrs_map, index.all_addrs_map); + assert_eq!(restored.null_addrs_map, index.null_addrs_map); + } + + #[test] + fn test_cache_codec_roundtrip() { + // No nulls + assert_roundtrips(&example_index()); + + // With nulls in the values column + let batch = record_batch!( + (BTREE_VALUES_COLUMN, Int32, [None, Some(0), Some(5)]), + (BTREE_IDS_COLUMN, UInt64, [0, 1, 2]) + ) + .unwrap(); + assert_roundtrips(&FlatIndex::try_new(batch).unwrap()); + + // Empty index + let empty = RecordBatch::new_empty(example_index().data.schema()); + assert_roundtrips(&FlatIndex::try_new(empty).unwrap()); + } + #[tokio::test] async fn test_equality() { check_index(&SargableQuery::Equals(ScalarValue::from(100)), &[0]).await; diff --git a/rust/lance-index/src/scalar/registry.rs b/rust/lance-index/src/scalar/registry.rs index 4e44c207041..0add98d8ab3 100644 --- a/rust/lance-index/src/scalar/registry.rs +++ b/rust/lance-index/src/scalar/registry.rs @@ -1,12 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::borrow::Cow; use std::sync::Arc; use arrow_schema::Field; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; -use lance_core::{Result, cache::LanceCache}; +use lance_core::{ + Result, + cache::{LanceCache, UnsizedCacheKey}, +}; use crate::progress::IndexBuildProgress; 
use crate::registry::IndexPluginRegistry; @@ -158,6 +162,40 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { cache: &LanceCache, ) -> Result>; + /// Look up a previously-opened index in the cache. + /// + /// `cache` is already per-index namespaced by the caller, so a plugin's key + /// only needs to disambiguate entries within a single index. + /// + /// The default implementation reads an in-memory `Arc` entry. + /// Plugins whose index has a serializable representation should override this + /// (together with [`put_in_cache`](Self::put_in_cache)) to store that + /// representation under a sized [`CacheKey`](lance_core::cache::CacheKey) with + /// a codec, and reconstruct the index here. `index_store` and + /// `frag_reuse_index` are provided so the override can rebuild the index + /// without re-reading metadata. + async fn get_from_cache( + &self, + _index_store: Arc, + _frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + Ok(cache.get_unsized_with_key(&ScalarIndexCacheKey).await) + } + + /// Store a freshly-opened index in the cache. + /// + /// `cache` is already per-index namespaced; see + /// [`get_from_cache`](Self::get_from_cache). + /// + /// The default implementation stores the `Arc` in-memory. + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + cache + .insert_unsized_with_key(&ScalarIndexCacheKey, index) + .await; + Ok(()) + } + /// Optional hook allowing a plugin to provide statistics without loading the index. async fn load_statistics( &self, @@ -180,3 +218,25 @@ pub trait ScalarIndexPlugin: Send + Sync + std::fmt::Debug { Ok(serde_json::json!({})) } } + +/// In-memory cache key for a whole `Arc`. +/// +/// Used by the default [`ScalarIndexPlugin::get_from_cache`] / +/// [`ScalarIndexPlugin::put_in_cache`] implementations. The cache is already +/// per-index namespaced by the caller, so a constant key suffices. 
Trait objects +/// cannot be serialized, so this is an [`UnsizedCacheKey`] with no codec — +/// plugins that want a persistable cache entry override those methods with a +/// sized key. +pub struct ScalarIndexCacheKey; + +impl UnsizedCacheKey for ScalarIndexCacheKey { + type ValueType = dyn ScalarIndex; + + fn key(&self) -> Cow<'_, str> { + Cow::Borrowed("scalar_index") + } + + fn type_name() -> &'static str { + "ScalarIndex" + } +} diff --git a/rust/lance/src/dataset/tests/dataset_index.rs b/rust/lance/src/dataset/tests/dataset_index.rs index 3d127a43b3a..4fd1e2fcfd1 100644 --- a/rust/lance/src/dataset/tests/dataset_index.rs +++ b/rust/lance/src/dataset/tests/dataset_index.rs @@ -2194,6 +2194,109 @@ async fn test_fts_prewarm_with_serializing_backend_serves_query_with_no_io() { ); } +/// BTree analogue of `test_fts_prewarm_with_serializing_backend_serves_query_with_no_io`: +/// after prewarming a BTree scalar index through a serializing cache backend, +/// an indexed-filter query serves results without any further IO. The +/// serializing backend forces every cache hit through the `BTreeIndexState` +/// and `FlatIndex` `CacheCodec` impls, so this also smoke-tests those +/// round-trip paths on a multi-page index. +#[tokio::test] +async fn test_btree_prewarm_with_serializing_backend_serves_query_with_no_io() { + use lance_io::assert_io_eq; + + use fts_serializing_backend::SerializingBackend; + + let tmpdir = TempStrDir::default(); + let uri = tmpdir.to_owned(); + drop(tmpdir); + + // Enough rows to span several BTree pages (default page size is 4096) so + // the query has to consult more than one cached `FlatIndex`. 
+ let num_rows = 16_384; + let values = Int32Array::from_iter_values(0..num_rows); + let ids = UInt64Array::from_iter_values(0..num_rows as u64); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("value", DataType::Int32, false), + arrow_schema::Field::new("id", DataType::UInt64, false), + ]) + .into(), + vec![Arc::new(values) as ArrayRef, Arc::new(ids) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, &uri, None).await.unwrap(); + dataset + .create_index( + &["value"], + IndexType::BTree, + Some("value_idx".to_owned()), + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + // Re-open on a session whose cache backend serializes every entry through + // its codec, with a generous capacity so nothing is evicted before we query. + let backend = Arc::new(SerializingBackend::new()); + let session = Arc::new(Session::with_index_cache_backend( + backend.clone(), + 128 * 1024 * 1024, + Arc::new(lance_io::object_store::ObjectStoreRegistry::default()), + )); + let dataset = DatasetBuilder::from_uri(&uri) + .with_session(session) + .load() + .await + .unwrap(); + + // Reset IO counters to isolate prewarm + query traffic from open/load. + dataset.object_store.as_ref().io_stats_incremental(); + + dataset.prewarm_index("value_idx").await.unwrap(); + + // Prewarm opens the index (serializing `BTreeIndexState`) and loads every + // page (serializing each `FlatIndex`), so the serialized store must be + // non-empty. The unsized fallback keys cannot have a codec by design. 
+ let serialized_after_prewarm = backend.serialized_entry_count().await; + assert!( + serialized_after_prewarm > 0, + "prewarm should have routed the BTree state and pages through CacheCodec, \ + but the serializing store was empty" + ); + + // After prewarm, an indexed-filter query must reconstruct the index and + // every page it touches from the cache, deserializing via the codec, with + // no disk IO. Project only `_rowid` so the scan does not read a data column. + dataset.object_store.as_ref().io_stats_incremental(); + + let result = dataset + .scan() + .project(&[ROW_ID]) + .unwrap() + .filter("value >= 100 AND value < 200") + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!( + result.num_rows(), + 100, + "indexed filter should still return correct results after deserialization" + ); + + let stats = dataset.object_store.as_ref().io_stats_incremental(); + assert_io_eq!( + stats, + read_iops, + 0, + "BTree filter query should not perform IO after prewarm; the serializing \ + cache backend must serve the index state and every page from memory" + ); +} + #[tokio::test] async fn test_fts_phrase_query_with_removed_stop_words() { let tmpdir = TempStrDir::default(); diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index d9516bb6d0a..a721c600d3d 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -12,14 +12,13 @@ use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; use futures::FutureExt; use itertools::Itertools; -use lance_core::cache::{CacheKey, UnsizedCacheKey}; +use lance_core::cache::CacheKey; use lance_core::datatypes::Field; use lance_core::datatypes::Schema as LanceSchema; use lance_core::utils::address::RowAddress; use lance_core::utils::parse::parse_env_as_bool; use lance_core::utils::tracing::{ - IO_TYPE_OPEN_FRAG_REUSE, IO_TYPE_OPEN_MEM_WAL, IO_TYPE_OPEN_SCALAR, IO_TYPE_OPEN_VECTOR, - TRACE_IO_EVENTS, + IO_TYPE_OPEN_FRAG_REUSE, IO_TYPE_OPEN_MEM_WAL, IO_TYPE_OPEN_VECTOR, 
TRACE_IO_EVENTS, }; use lance_file::previous::reader::FileReader as PreviousFileReader; use lance_file::reader::FileReaderOptions; @@ -232,34 +231,6 @@ fn segment_has_inverted_details(segment: &IndexMetadata) -> bool { } // Cache keys for different index types -#[derive(Debug, Clone)] -pub struct ScalarIndexCacheKey<'a> { - pub uuid: &'a str, - pub fri_uuid: Option<&'a Uuid>, -} - -impl<'a> ScalarIndexCacheKey<'a> { - pub fn new(uuid: &'a str, fri_uuid: Option<&'a Uuid>) -> Self { - Self { uuid, fri_uuid } - } -} - -impl UnsizedCacheKey for ScalarIndexCacheKey<'_> { - type ValueType = dyn ScalarIndex; - - fn key(&self) -> std::borrow::Cow<'_, str> { - if let Some(fri_uuid) = self.fri_uuid { - format!("{}-{}", self.uuid, fri_uuid).into() - } else { - self.uuid.into() - } - } - - fn type_name() -> &'static str { - "ScalarIndex" - } -} - #[derive(Debug, Clone)] pub(crate) struct LegacyVectorIndexCacheKey<'a> { uuid: &'a str, @@ -1701,12 +1672,10 @@ impl DatasetIndexInternalExt for Dataset { uuid: &str, metrics: &dyn MetricsCollector, ) -> Result> { - // Checking for cache existence is cheap so we just check both scalar and vector caches + // Checking for cache existence is cheap so we just check the vector caches. + // Scalar indices cache themselves inside `open_scalar_index` (the cache + // key is a plugin detail), so there is no cheap scalar check here. let frag_reuse_uuid = self.frag_reuse_index_uuid().await; - let cache_key = ScalarIndexCacheKey::new(uuid, frag_reuse_uuid.as_ref()); - if let Some(index) = self.index_cache.get_unsized_with_key(&cache_key).await { - return Ok(index.as_index()); - } // Check sized cache for IvfIndexState (v2+ indices). 
let state_key = IvfIndexStateCacheKey::new(uuid, frag_reuse_uuid.as_ref()); @@ -1766,26 +1735,14 @@ impl DatasetIndexInternalExt for Dataset { uuid: &str, metrics: &dyn MetricsCollector, ) -> Result> { - let frag_reuse_uuid = self.frag_reuse_index_uuid().await; - let cache_key = ScalarIndexCacheKey::new(uuid, frag_reuse_uuid.as_ref()); - if let Some(index) = self.index_cache.get_unsized_with_key(&cache_key).await { - return Ok(index); - } - + // Caching (including the choice of in-memory vs. serializable state) is + // a plugin implementation detail handled inside `scalar::open_scalar_index`. let index_meta = self .load_index(uuid) .await? .ok_or_else(|| Error::index(format!("Index with id {} does not exist", uuid)))?; - let index = scalar::open_scalar_index(self, column, &index_meta, metrics).await?; - - info!(target: TRACE_IO_EVENTS, index_uuid=uuid, r#type=IO_TYPE_OPEN_SCALAR, index_type=index.index_type().to_string()); - metrics.record_index_load(); - - self.index_cache - .insert_unsized_with_key(&cache_key, index.clone()) - .await; - Ok(index) + scalar::open_scalar_index(self, column, &index_meta, metrics).await } async fn open_vector_index( diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 50d4f095c69..35a42b53828 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -23,6 +23,7 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use futures::TryStreamExt; use itertools::Itertools; use lance_core::datatypes::Field; +use lance_core::utils::tracing::{IO_TYPE_OPEN_SCALAR, TRACE_IO_EVENTS}; use lance_core::{Error, ROW_ADDR, ROW_ID, Result}; use lance_datafusion::exec::LanceExecutionOptions; use lance_index::metrics::{MetricsCollector, NoOpMetricsCollector}; @@ -398,9 +399,22 @@ pub async fn open_scalar_index( .index_cache .for_index(&uuid_str, frag_reuse_index.as_ref().map(|f| &f.uuid)); - plugin + if let Some(index) = plugin + .get_from_cache(index_store.clone(), 
frag_reuse_index.clone(), &index_cache) + .await? + { + return Ok(index); + } + + let index = plugin .load_index(index_store, &index_details, frag_reuse_index, &index_cache) - .await + .await?; + + tracing::info!(target: TRACE_IO_EVENTS, index_uuid = uuid_str, r#type = IO_TYPE_OPEN_SCALAR, index_type = index.index_type().to_string()); + metrics.record_index_load(); + + plugin.put_in_cache(&index_cache, index.clone()).await?; + Ok(index) } pub(crate) async fn infer_scalar_index_details( From 6ddd7e28de7675777f68ccea5571f98f5e23dcc5 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 20 May 2026 10:11:33 -0700 Subject: [PATCH 09/23] feat: implement vector index details (#6099) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cache vector index configuration within the index metadata, such as the distance type and build parameters. Previously, to determine things like the distance type or index type of a vector index, the index file itself had to be opened. This PR stores that information in `VectorIndexDetails` within the manifest's `index_details` field, which is fetched and cached eagerly when loading the manifest. Old indexes have this field left blank. When blank, the details are extracted from the index files and cached. This migration happens on the first write with a new library version. ## What's stored in VectorIndexDetails **Core build parameters** (typed fields — required for any runtime to build the index): - `metric_type` - `target_partition_size` (IVF) - `hnsw_index_config` — `max_connections`, `construction_ef`, `max_level` (HNSW) - `compression` — PQ/SQ/RQ/flat, including `num_bits`, `num_sub_vectors`, `rotation_type` **Runtime hints** (`map<string, string> runtime_hints`): Optional build preferences that don't affect index structure. Stored so a background rebuild process can reproduce the original configuration. Runtimes that don't recognize a key must silently ignore it. Only non-default values are written. 
Keys use reverse-DNS namespacing: `lance.*` for core Lance hints, other prefixes for runtime-specific hints (e.g., `lancedb.accelerator` for GPU acceleration in LanceDB Enterprise). Current `lance.*` hints: `lance.ivf.max_iters`, `lance.ivf.sample_rate`, `lance.ivf.shuffle_partition_batches`, `lance.ivf.shuffle_partition_concurrency`, `lance.pq.max_iters`, `lance.pq.sample_rate`, `lance.pq.kmeans_redos`, `lance.sq.sample_rate`, `lance.hnsw.prefetch_distance`, `lance.skip_transpose`. Also adds `apply_runtime_hints()` to read hints back into build params for future rebuild logic. Closes #5963 --------- Co-authored-by: Claude Opus 4.6 --- java/lance-jni/src/utils.rs | 1 + protos/index.proto | 56 + protos/table.proto | 3 +- .../tests/compat/test_vector_indices.py | 51 +- python/python/tests/test_vector_index.py | 27 +- python/src/dataset.rs | 13 + python/src/indices.rs | 3 +- rust/lance-index/src/vector/bq.rs | 14 + rust/lance-index/src/vector/hnsw/builder.rs | 10 + rust/lance-index/src/vector/pq/builder.rs | 9 + rust/lance-index/src/vector/sq/builder.rs | 8 + rust/lance/src/dataset/index.rs | 2 +- rust/lance/src/dataset/optimize.rs | 1 + rust/lance/src/dataset/scanner.rs | 30 +- rust/lance/src/index.rs | 204 ++- rust/lance/src/index/append.rs | 19 +- rust/lance/src/index/create.rs | 8 +- rust/lance/src/index/scalar.rs | 2 +- rust/lance/src/index/vector.rs | 20 +- rust/lance/src/index/vector/details.rs | 1456 +++++++++++++++++ rust/lance/src/index/vector/ivf.rs | 26 +- rust/lance/src/io/commit.rs | 2 + 22 files changed, 1878 insertions(+), 87 deletions(-) create mode 100644 rust/lance/src/index/vector/details.rs diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index c8a63c8672b..1321e8e71e3 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -477,6 +477,7 @@ pub fn get_vector_index_params( stages, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: Default::default(), }) }, )?; diff --git 
a/protos/index.proto b/protos/index.proto index 1fb51f3291c..ea21c70387d 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -184,6 +184,62 @@ message VectorIndex { VectorMetricType metric_type = 4; } +// Details for vector indexes, stored in the manifest's index_details field. +message VectorIndexDetails { + VectorMetricType metric_type = 1; + + // The target number of vectors per partition. + // 0 means unset. + uint64 target_partition_size = 2; + + // Optional HNSW index configuration. If set, the index has an HNSW layer. + optional HnswParameters hnsw_index_config = 3; + + message ProductQuantization { + uint32 num_bits = 1; + uint32 num_sub_vectors = 2; + } + message ScalarQuantization { + uint32 num_bits = 1; + } + message RabitQuantization { + enum RotationType { + FAST = 0; + MATRIX = 1; + } + uint32 num_bits = 1; + RotationType rotation_type = 2; + } + + // No quantization; vectors are stored as-is. + message FlatCompression {} + + oneof compression { + ProductQuantization pq = 4; + ScalarQuantization sq = 5; + RabitQuantization rq = 6; + FlatCompression flat = 8; + } + + // Runtime hints: optional build preferences that don't affect index structure. + // Keys use reverse-DNS namespacing (e.g., "lance.ivf.max_iters", "lancedb.accelerator"). + // Unrecognized keys must be silently ignored by all runtimes. + map runtime_hints = 9; +} + +// Hierarchical Navigable Small World (HNSW) parameters, used as an optional configuration for IVF indexes. +message HnswParameters { + // The maximum number of outgoing edges per node in the HNSW graph. Higher values + // means more connections, better recall, but more memory and slower builds. + // Referred to as "M" in the HNSW literature. + uint32 max_connections = 1; + // "construction exploration factor": The size of the dynamic list used during + // index construction. + uint32 construction_ef = 2; + // The maximum number of levels in the HNSW graph. 
+ uint32 max_level = 3; +} + message JsonIndexDetails { string path = 1; google.protobuf.Any target_details = 2; diff --git a/protos/table.proto b/protos/table.proto index 9fe687bb03e..d298809d5d8 100644 --- a/protos/table.proto +++ b/protos/table.proto @@ -474,8 +474,7 @@ message ExternalFile { uint64 size = 3; } -// Empty details messages for older indexes that don't take advantage of the details field. -message VectorIndexDetails {} +// VectorIndexDetails and HnswParameters (formerly HnswIndexDetails) moved to index.proto message FragmentReuseIndexDetails { diff --git a/python/python/tests/compat/test_vector_indices.py b/python/python/tests/compat/test_vector_indices.py index 194435c095a..b98ffdf63e3 100644 --- a/python/python/tests/compat/test_vector_indices.py +++ b/python/python/tests/compat/test_vector_indices.py @@ -71,6 +71,17 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + elif self.compat_version >= "0.39.0": + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -140,6 +151,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -209,6 +232,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = 
indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -226,9 +261,9 @@ def check_write(self): ds.optimize.compact_files() -@compat_test(min_version="0.39.0") +@compat_test(min_version="4.0.0-beta.8") class IvfRqVectorIndex(UpgradeDowngradeTest): - """Test IVF_RQ vector index compatibility.""" + """Test IVF_RQ vector index compatibility. V2 was introduced in v4.0.0-beta.8""" def __init__(self, path: Path): self.path = path @@ -273,6 +308,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0].name + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and run optimize workflows.""" ds = lance.dataset(self.path) diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index b04693ee85c..356f72a5e66 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1678,7 +1678,7 @@ def test_describe_vector_index(indexed_dataset: LanceDataset): info = indexed_dataset.describe_indices()[0] assert info.name == "vector_idx" - assert info.type_url == "/lance.table.VectorIndexDetails" + assert info.type_url == "/lance.index.pb.VectorIndexDetails" assert info.index_type == "IVF_PQ" assert info.num_rows_indexed == 1000 assert info.fields == [0] @@ -1689,6 +1689,31 @@ def test_describe_vector_index(indexed_dataset: LanceDataset): assert info.segments[0].index_version == 1 assert info.segments[0].created_at is not None + details = info.details + assert details["metric_type"] == 
"L2" + assert details["compression"]["type"] == "pq" + assert details["compression"]["num_bits"] == 8 + assert details["compression"]["num_sub_vectors"] == 16 + + +def test_describe_index_runtime_hints_stored(tmp_path): + tbl = create_table(nvec=300, ndim=16) + dataset = lance.write_dataset(tbl, tmp_path) + dataset = dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=4, + max_iters=100, + sample_rate=512, + ) + details = dataset.describe_indices()[0].details + hints = details.get("runtime_hints", {}) + assert hints.get("lance.ivf.max_iters") == "100" + assert hints.get("lance.ivf.sample_rate") == "512" + assert hints.get("lance.pq.max_iters") == "100" + assert hints.get("lance.pq.sample_rate") == "512" + def test_optimize_indices(indexed_dataset): data = create_table() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index ec3214f011b..4c7164e2ce8 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -4286,6 +4286,12 @@ fn prepare_vector_index_params( sq_params.sample_rate = sample_rate; } + if let Some(max_iters) = kwargs.get_item("max_iters")? { + let max_iters: usize = max_iters.extract()?; + ivf_params.max_iters = max_iters; + pq_params.max_iters = max_iters; + } + // Parse IVF params if let Some(n) = kwargs.get_item("num_partitions")? { ivf_params.num_partitions = Some(n.extract()?) @@ -4443,6 +4449,13 @@ fn prepare_vector_index_params( }?; params.version(index_file_version); params.skip_transpose(skip_transpose); + if let Some(kwargs) = kwargs + && let Some(acc) = kwargs.get_item("accelerator")? 
+ { + params + .runtime_hints + .insert("lancedb.accelerator".to_string(), acc.to_string()); + } Ok(params) } diff --git a/python/src/indices.rs b/python/src/indices.rs index 5f00e9f5726..cf93579b867 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -537,8 +537,7 @@ async fn do_load_shuffled_vectors( dataset_version: ds.manifest.version, fragment_bitmap: Some(ds.fragments().iter().map(|f| f.id as u32).collect()), index_details: Some(Arc::new( - prost_types::Any::from_msg(&lance_table::format::pb::VectorIndexDetails::default()) - .unwrap(), + prost_types::Any::from_msg(&lance_index::pb::VectorIndexDetails::default()).unwrap(), )), index_version: IndexType::IvfPq.version(), created_at: Some(Utc::now()), diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index d56bfdcafc6..a0a16b22169 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -7,6 +7,7 @@ use std::iter::once; use std::str::FromStr; use std::sync::Arc; +use crate::pb::vector_index_details::RabitQuantization; use arrow_array::types::Float32Type; use arrow_array::{Array, ArrayRef, UInt8Array, cast::AsArray}; use lance_core::{Error, Result}; @@ -121,6 +122,19 @@ impl RQBuildParams { } } +impl From<&RQBuildParams> for RabitQuantization { + fn from(value: &RQBuildParams) -> Self { + use crate::pb::vector_index_details::rabit_quantization::RotationType; + Self { + num_bits: value.num_bits as u32, + rotation_type: match value.rotation_type { + RQRotationType::Fast => RotationType::Fast as i32, + RQRotationType::Matrix => RotationType::Matrix as i32, + }, + } + } +} + impl QuantizerBuildParams for RQBuildParams { fn sample_size(&self) -> usize { 0 diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index f605e15714d..a3ee32bb33f 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -69,6 +69,16 @@ pub struct HnswBuildParams { 
pub prefetch_distance: Option, } +impl From<&HnswBuildParams> for crate::pb::HnswParameters { + fn from(params: &HnswBuildParams) -> Self { + Self { + max_connections: params.m as u32, + construction_ef: params.ef_construction as u32, + max_level: params.max_level as u32, + } + } +} + impl Default for HnswBuildParams { fn default() -> Self { Self { diff --git a/rust/lance-index/src/vector/pq/builder.rs b/rust/lance-index/src/vector/pq/builder.rs index 1768e9fe8f0..c4dad4a6a3e 100644 --- a/rust/lance-index/src/vector/pq/builder.rs +++ b/rust/lance-index/src/vector/pq/builder.rs @@ -44,6 +44,15 @@ pub struct PQBuildParams { pub sample_rate: usize, } +impl From<&PQBuildParams> for crate::pb::vector_index_details::ProductQuantization { + fn from(params: &PQBuildParams) -> Self { + Self { + num_bits: params.num_bits as u32, + num_sub_vectors: params.num_sub_vectors as u32, + } + } +} + impl Default for PQBuildParams { fn default() -> Self { Self { diff --git a/rust/lance-index/src/vector/sq/builder.rs b/rust/lance-index/src/vector/sq/builder.rs index 913751062cf..359765040dd 100644 --- a/rust/lance-index/src/vector/sq/builder.rs +++ b/rust/lance-index/src/vector/sq/builder.rs @@ -12,6 +12,14 @@ pub struct SQBuildParams { pub sample_rate: usize, } +impl From<&SQBuildParams> for crate::pb::vector_index_details::ScalarQuantization { + fn from(params: &SQBuildParams) -> Self { + Self { + num_bits: params.num_bits as u32, + } + } +} + impl Default for SQBuildParams { fn default() -> Self { Self { diff --git a/rust/lance/src/dataset/index.rs b/rust/lance/src/dataset/index.rs index 6290323ede6..354cdaf7f86 100644 --- a/rust/lance/src/dataset/index.rs +++ b/rust/lance/src/dataset/index.rs @@ -17,9 +17,9 @@ use async_trait::async_trait; use lance_core::{Error, Result}; use lance_encoding::version::LanceFileVersion; use lance_index::frag_reuse::FRAG_REUSE_INDEX_NAME; +use lance_index::pb::VectorIndexDetails; use lance_index::scalar::lance_format::LanceIndexStore; use 
lance_table::format::IndexMetadata; -use lance_table::format::pb::VectorIndexDetails; use serde::{Deserialize, Serialize}; use super::optimize::{IndexRemapper, IndexRemapperOptions}; diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index 790ce9a04f5..a042cf568ce 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -4120,6 +4120,7 @@ mod tests { ], version: crate::index::vector::IndexFileVersion::V3, skip_transpose: false, + runtime_hints: Default::default(), }, false, ) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 6916c9ca258..f2c16b87d0b 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -3594,18 +3594,23 @@ impl Scanner { } } } else if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { - // TODO: Once we do https://github.com/lance-format/lance/issues/5231, we - // should be able to get the metric type directly from the index metadata, - // at least for newer indexes. 
- let idx = self - .dataset - .open_vector_index( - q.column.as_str(), - &index.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await?; - let index_metric = idx.metric_type(); + // Try to get metric type from index metadata first (fast path for newer indices) + let index_metric = if let Some(metric) = + crate::index::vector::details::metric_type_from_index_metadata(index) + { + metric + } else { + // Fall back to opening the index for legacy indices without details + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + idx.metric_type() + }; let use_this_index = match q.metric_type { Some(user_metric) => { @@ -9812,6 +9817,7 @@ full_filter=name LIKE Utf8(\"test%2\"), refine_filter=name LIKE Utf8(\"test%2\") ], version: crate::index::vector::IndexFileVersion::Legacy, skip_transpose: false, + runtime_hints: Default::default(), }, false, ) diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index a721c600d3d..c65fe13a23e 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -45,7 +45,7 @@ use lance_index::vector::sq::ScalarQuantizer; use lance_index::vector::v3::subindex::IvfSubIndex; use lance_index::{INDEX_FILE_NAME, Index, IndexType, PrewarmOptions, pb, vector::VectorIndex}; use lance_index::{ - IndexCriteria, infer_system_index_type, is_system_index, + IndexCriteria, is_system_index, metrics::{MetricsCollector, NoOpMetricsCollector}, }; use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; @@ -62,6 +62,10 @@ use scalar::index_matches_criteria; use serde_json::json; use tracing::{info, instrument}; use uuid::Uuid; +use vector::details::{ + derive_vector_index_type, infer_missing_vector_details, vector_details_as_json, +}; +pub(crate) use vector::details::{vector_index_details, vector_index_details_default}; use vector::ivf::v2::{IVFIndex, IvfStateEntryBox}; use vector::utils::get_vector_type; @@ -541,7 +545,7 @@ pub(crate) async fn remap_index( 
CreatedIndex { index_details: prost_types::Any::from_msg( - &lance_table::format::pb::VectorIndexDetails::default(), + &lance_index::pb::VectorIndexDetails::default(), ) .unwrap(), index_version, @@ -592,11 +596,6 @@ async fn open_index_proto(reader: &dyn Reader) -> Result { Ok(proto) } -fn vector_index_details() -> prost_types::Any { - let details = lance_table::format::pb::VectorIndexDetails::default(); - prost_types::Any::from_msg(&details).unwrap() -} - struct IndexDescriptionImpl { name: String, field_ids: Vec, @@ -667,41 +666,16 @@ impl IndexDescriptionImpl { let details = IndexDetails(index_details.clone()); let mut rows_indexed = 0; - // System indices (e.g. __lance_frag_reuse, __lance_mem_wal) are - // identified by name and have no entry in the scalar plugin registry, - // so resolve them up front. This mirrors `load_indices` in - // python/src/dataset.rs, keeping the two listing methods in agreement. - let index_type = if let Some(system_type) = infer_system_index_type(example_metadata) { + let index_type = if details.is_vector() { + derive_vector_index_type(index_details) + } else if let Some(system_type) = lance_index::infer_system_index_type(example_metadata) { + // System indices (frag-reuse, mem-wal) are identified by name, not + // by a plugin entry, so the plugin lookup below would return + // "Unknown" otherwise. 
system_type.to_string() - } else if details.is_vector() { - // Vector indices need to be opened to get the correct type - let column = field_ids - .first() - .and_then(|id| dataset.schema().field_by_id(*id)) - .map(|f| f.name.clone()) - .ok_or_else(|| { - Error::index("Cannot determine column name for vector index".to_string()) - })?; - - match dataset - .open_generic_index( - &column, - &example_metadata.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await - { - Ok(idx) => idx.index_type().to_string(), - Err(e) => { - log::warn!( - "Failed to open vector index {} to determine type: {}", - name, - e - ); - "Unknown".to_string() - } - } } else { + // We attempted to infer the index type when we loaded the indices, + // so if we hit this branch the index type is truly unknown. details .get_plugin() .map(|p| p.name().to_string()) @@ -758,10 +732,14 @@ impl IndexDescription for IndexDescriptionImpl { } fn details(&self) -> Result { - let plugin = self.details.get_plugin()?; - plugin - .details_as_json(&self.details.0) - .map(|v| v.to_string()) + if self.details.is_vector() { + vector_details_as_json(&self.details.0) + } else { + let plugin = self.details.get_plugin()?; + plugin + .details_as_json(&self.details.0) + .map(|v| v.to_string()) + } } fn total_size_bytes(&self) -> Option { @@ -971,7 +949,7 @@ impl DatasetIndexExt for Dataset { let metadata_key = IndexMetadataKey { version: self.version().version, }; - let indices = match self.index_cache.get_with_key(&metadata_key).await { + let mut indices = match self.index_cache.get_with_key(&metadata_key).await { Some(indices) => indices, None => { let mut loaded_indices = read_manifest_indexes( @@ -989,6 +967,20 @@ impl DatasetIndexExt for Dataset { } }; + // Infer details for legacy vector indices (once per index name, concurrently). + // This may run on indices that were opportunistically cached during Dataset::open + // before the full Dataset was available for inference. 
+ { + let mut updated = indices.as_ref().clone(); + infer_missing_vector_details(self, &mut updated).await; + if updated != *indices { + indices = Arc::new(updated); + self.index_cache + .insert_with_key(&metadata_key, indices.clone()) + .await; + } + } + if let Some(frag_reuse_index_meta) = indices.iter().find(|idx| idx.name == FRAG_REUSE_INDEX_NAME) { @@ -2483,7 +2475,7 @@ mod tests { fields: vec![field_id], dataset_version: dataset.manifest.version, fragment_bitmap: Some(fragment_bitmap.into_iter().collect()), - index_details: Some(Arc::new(vector_index_details())), + index_details: Some(Arc::new(vector_index_details_default())), index_version: IndexType::Vector.version(), created_at: Some(chrono::Utc::now()), base_id: None, @@ -4225,6 +4217,126 @@ mod tests { "updated_at_timestamp_ms should be null when no indices have created_at timestamps" ); } + + #[tokio::test] + async fn test_legacy_vector_index_details_inferred_on_load_and_migration() { + use lance_linalg::distance::DistanceType; + + // Create a fresh dataset with IVF_HNSW_SQ so inference produces non-default + // details (HNSW config + SQ compression) that survive proto serialization. + let test_dir = lance_core::utils::tempfile::TempDir::default(); + let test_uri = test_dir.path_str(); + let data = gen_batch() + .col("i", array::step::()) + .col("vec", array::rand_vec::(16.into())) + .into_reader_rows(RowCount::from(1024), BatchCount::from(1)); + let mut dataset = Dataset::write(data, &test_uri, None).await.unwrap(); + + let params = VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::Cosine, + IvfBuildParams { + num_partitions: Some(2), + ..Default::default() + }, + HnswBuildParams::default(), + SQBuildParams::default(), + ); + dataset + .create_index(&["vec"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + // Verify the index has populated details. 
+ let descriptions = dataset.describe_indices(None).await.unwrap(); + assert_eq!(descriptions.len(), 1); + assert_eq!(descriptions[0].index_type(), "IVF_HNSW_SQ"); + + // Simulate a legacy dataset by clearing details from the manifest. + // Write a new manifest with empty VectorIndexDetails value bytes. + let mut indices = dataset.load_indices().await.unwrap().as_ref().clone(); + for idx in &mut indices { + if let Some(details) = idx.index_details.as_ref() + && details.type_url.ends_with("VectorIndexDetails") + { + idx.index_details = Some(Arc::new(vector_index_details_default())); + } + } + // Write back via a no-op commit that carries the cleared indices. + // We commit by doing a delete("false") after replacing the cached indices. + let metadata_key = crate::session::index_caches::IndexMetadataKey { + version: dataset.version().version, + }; + dataset + .index_cache + .insert_with_key(&metadata_key, Arc::new(indices)) + .await; + dataset.delete("false").await.unwrap(); + + // -- Part 1: Inference on load -- + // Open with a fresh session so nothing is cached. + let dataset = DatasetBuilder::from_uri(&test_uri) + .with_session(Arc::new(Session::default())) + .load() + .await + .unwrap(); + + // load_indices should detect empty details and infer from index files. + let indices = dataset.load_indices().await.unwrap(); + assert_eq!(indices.len(), 1); + let details = indices[0].index_details.as_ref().unwrap(); + assert!( + !details.value.is_empty(), + "Details should have been inferred from index files" + ); + + // describe_indices should return a real type (not generic "Vector"). 
+ let descriptions = dataset.describe_indices(None).await.unwrap(); + assert_eq!(descriptions.len(), 1); + assert_ne!( + descriptions[0].index_type(), + "Vector", + "Should have inferred a specific index type" + ); + assert_eq!( + descriptions[0].index_type(), + "IVF_HNSW_SQ", + "Inferred type should match the originally-built index" + ); + let inferred_type = descriptions[0].index_type().to_string(); + let details_json: serde_json::Value = + serde_json::from_str(&descriptions[0].details().unwrap()).unwrap(); + assert_eq!(details_json["metric_type"], "COSINE"); + assert_eq!(details_json["compression"]["type"], "sq"); + assert!( + details_json["hnsw"]["max_connections"].is_number(), + "Inferred HNSW config should have max_connections; got {details_json}" + ); + assert!(details_json["hnsw"]["construction_ef"].is_number()); + assert!(details_json["hnsw"]["max_level"].is_number()); + + // -- Part 2: Migration persists inferred details -- + let mut dataset = dataset; + dataset.delete("false").await.unwrap(); + + // Open with yet another fresh session. + let dataset = DatasetBuilder::from_uri(&test_uri) + .with_session(Arc::new(Session::default())) + .load() + .await + .unwrap(); + + // The migrated manifest should have non-empty details without + // needing to read index files again. 
+ let indices = dataset.load_indices().await.unwrap(); + assert_eq!(indices.len(), 1); + assert!( + !indices[0].index_details.as_ref().unwrap().value.is_empty(), + "Migrated manifest should have non-empty details" + ); + let descriptions = dataset.describe_indices(None).await.unwrap(); + assert_eq!(descriptions[0].index_type(), inferred_type); + } + #[rstest] #[case::btree("i", IndexType::BTree, Box::new(ScalarIndexParams::default()))] #[case::bitmap("i", IndexType::Bitmap, Box::new(ScalarIndexParams::default()))] diff --git a/rust/lance/src/index/append.rs b/rust/lance/src/index/append.rs index 9259a723acb..66c0dc84337 100644 --- a/rust/lance/src/index/append.rs +++ b/rust/lance/src/index/append.rs @@ -28,7 +28,7 @@ use crate::dataset::Dataset; use crate::dataset::index::LanceIndexStoreExt; use crate::dataset::rowids::load_row_id_sequences; use crate::index::scalar::load_training_data; -use crate::index::vector_index_details; +use crate::index::vector_index_details_default; #[derive(Debug, Clone)] pub struct IndexMergeResults<'a> { @@ -268,7 +268,7 @@ pub async fn merge_indices_with_unindexed_frags<'a>( vec![removed_segment], new_fragment_bitmap, CreatedIndex { - index_details: vector_index_details(), + index_details: vector_index_details_default(), index_version: lance_index::IndexType::Vector.version() as u32, files: Some(files), }, @@ -313,6 +313,16 @@ pub async fn merge_indices_with_unindexed_frags<'a>( } } + // Carry forward existing index details: scan `old_indices` in reverse and + // take the first segment found with populated (non-empty) details. 
+ let index_details = old_indices + .iter() + .rev() + .filter_map(|idx| idx.index_details.as_ref()) + .find(|d| !d.value.is_empty()) + .map(|d| d.as_ref().clone()) + .unwrap_or_else(vector_index_details_default); + let index_dir = dataset.indices_dir().join(new_uuid.to_string()); let files = list_index_files_with_sizes(&dataset.object_store, &index_dir).await?; @@ -321,7 +331,10 @@ pub async fn merge_indices_with_unindexed_frags<'a>( removed_indices, frag_bitmap, CreatedIndex { - index_details: vector_index_details(), + index_details, + // retain_supported_indices guarantees all old_indices have + // index_version <= our max supported version, so we can safely + // write the current library's version for this index type. index_version: lance_index::IndexType::Vector.version() as u32, files: Some(files), }, diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index 68783d1c673..64df23e0fc9 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -16,7 +16,7 @@ use crate::{ LANCE_VECTOR_INDEX, VectorIndexParams, build_distributed_vector_index, build_empty_vector_index, build_vector_index, }, - vector_index_details, + vector_index_details, vector_index_details_default, }, }; use futures::future::{BoxFuture, try_join_all}; @@ -386,7 +386,7 @@ impl<'a> CreateIndexBuilder<'a> { let files = list_index_files_with_sizes(&self.dataset.object_store, &index_dir).await?; CreatedIndex { - index_details: vector_index_details(), + index_details: vector_index_details(vec_params), index_version, files: Some(files), } @@ -425,7 +425,7 @@ impl<'a> CreateIndexBuilder<'a> { let files = list_index_files_with_sizes(&self.dataset.object_store, &index_dir).await?; CreatedIndex { - index_details: vector_index_details(), + index_details: vector_index_details_default(), index_version: self.index_type.version() as u32, files: Some(files), } @@ -2106,7 +2106,7 @@ mod tests { vec![IndexSegment::new( uuid, 
dataset.fragment_bitmap.as_ref().clone(), - Arc::new(vector_index_details()), + Arc::new(vector_index_details(¶ms)), IndexType::IvfHnswFlat.version(), )], ) diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 35a42b53828..a7856f2306f 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -621,12 +621,12 @@ mod tests { use lance_core::utils::tempfile::TempStrDir; use lance_core::{datatypes::Field, utils::address::RowAddress}; use lance_datagen::array; + use lance_index::pb::VectorIndexDetails; use lance_index::{IndexType, optimize::OptimizeOptions}; use lance_index::{ pbold::NGramIndexDetails, scalar::{BuiltinIndexType, ScalarIndexParams}, }; - use lance_table::format::pb::VectorIndexDetails; fn make_index_metadata( name: &str, diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 7e85955d090..87b32344ec6 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use std::{any::Any, collections::HashMap}; pub mod builder; +pub(crate) mod details; pub mod ivf; pub mod pq; pub mod utils; @@ -59,7 +60,7 @@ use tracing::instrument; use utils::get_vector_type; use uuid::Uuid; -use super::{DatasetIndexExt, DatasetIndexInternalExt, IndexParams, pb, vector_index_details}; +use super::{DatasetIndexExt, DatasetIndexInternalExt, IndexParams, pb}; use crate::dataset::index::dataset_format_version; use crate::dataset::transaction::{Operation, Transaction}; use crate::{Error, Result, dataset::Dataset, index::pb::vector_index_stage::Stage}; @@ -265,6 +266,11 @@ pub struct VectorIndexParams { /// Skip transpose / packing for PQ and RQ storage. pub skip_transpose: bool, + + /// Runtime hints: optional build preferences stored in the index manifest. + /// Keys use reverse-DNS namespacing (e.g., "lance.ivf.max_iters"). + /// Populated by the build path and merged into VectorIndexDetails at creation time. 
+ pub runtime_hints: HashMap, } impl VectorIndexParams { @@ -286,6 +292,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -296,6 +303,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -330,6 +338,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -356,6 +365,7 @@ impl VectorIndexParams { metric_type: distance_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -371,6 +381,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -385,6 +396,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -399,6 +411,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -413,6 +426,7 @@ impl VectorIndexParams { metric_type: distance_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -434,6 +448,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -455,6 +470,7 @@ impl VectorIndexParams { metric_type, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: HashMap::new(), } } @@ -1791,7 +1807,7 @@ pub async fn initialize_vector_index( fields: vec![field.id], dataset_version: target_dataset.manifest.version, fragment_bitmap, - index_details: Some(Arc::new(vector_index_details())), + index_details: source_index.index_details.clone(), index_version: source_index.index_version, created_at: Some(chrono::Utc::now()), base_id: None, diff --git a/rust/lance/src/index/vector/details.rs 
b/rust/lance/src/index/vector/details.rs new file mode 100644 index 00000000000..83e9b92c209 --- /dev/null +++ b/rust/lance/src/index/vector/details.rs @@ -0,0 +1,1456 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Serialization and deserialization of [`VectorIndexDetails`] proto messages. +//! +//! This module handles: +//! - Populating `VectorIndexDetails` from build params at index creation time +//! - Deriving a human-readable index type string (e.g., "IVF_PQ") from details +//! - Serializing details as JSON for `describe_indices()` +//! - Inferring details from index files on disk (fallback for legacy indices) + +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; + +use lance_file::reader::FileReaderOptions; +use lance_index::pb::VectorIndexDetails; +use lance_index::pb::VectorMetricType; +use lance_index::pb::index::Implementation; +use lance_index::pb::vector_index_details::{Compression, FlatCompression, rabit_quantization}; +use lance_index::{INDEX_FILE_NAME, INDEX_METADATA_SCHEMA_KEY, pb}; +use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; +use lance_io::traits::Reader; +use lance_io::utils::{CachedFileSize, read_last_block, read_version}; +use lance_linalg::distance::DistanceType; +use lance_table::format::IndexMetadata; +use serde::Serialize; + +use lance_index::vector::bq::{RQBuildParams, RQRotationType}; +use lance_index::vector::hnsw::builder::HnswBuildParams; +use lance_index::vector::ivf::IvfBuildParams; +use lance_index::vector::pq::PQBuildParams; +use lance_index::vector::sq::builder::SQBuildParams; + +use super::{StageParams, VectorIndexParams}; +use crate::dataset::Dataset; +use crate::index::open_index_proto; +use crate::{Error, Result}; + +// Private structs for JSON serialization of VectorIndexDetails. +// Changes to field names or structure are backwards-incompatible for users +// parsing the JSON output of describe_indices(). 
See snapshot tests below.
+
+/// Top-level JSON object emitted for a vector index.
+///
+/// Optional sections are omitted from the output when absent/empty so the
+/// JSON stays compact.
+#[derive(Serialize)]
+struct VectorDetailsJson {
+    metric_type: &'static str,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    target_partition_size: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    hnsw: Option<HnswDetailsJson>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    compression: Option<CompressionDetailsJson>,
+    #[serde(skip_serializing_if = "HashMap::is_empty")]
+    runtime_hints: HashMap<String, String>,
+}
+
+/// HNSW section of the JSON output; `max_level` is omitted when it is zero.
+#[derive(Serialize)]
+struct HnswDetailsJson {
+    max_connections: u32,
+    construction_ef: u32,
+    #[serde(skip_serializing_if = "is_zero")]
+    max_level: u32,
+}
+
+/// Serde `skip_serializing_if` predicate: true when the value is `0`.
+fn is_zero(v: &u32) -> bool {
+    *v == 0
+}
+
+/// Compression section of the JSON output, tagged by a lowercase `"type"` field
+/// (`"pq"`, `"sq"`, `"rq"`).
+#[derive(Serialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+enum CompressionDetailsJson {
+    Pq {
+        num_bits: u32,
+        num_sub_vectors: u32,
+    },
+    Sq {
+        num_bits: u32,
+    },
+    Rq {
+        num_bits: u32,
+        rotation_type: &'static str,
+    },
+}
+
+/// Build a `VectorIndexDetails` proto from build params at index creation time.
+///
+/// The metric, target partition size, HNSW config and compression go into
+/// dedicated proto fields; per-stage tuning knobs are recorded as string-valued
+/// `lance.*` runtime hints. The hint map starts as a copy of
+/// `params.runtime_hints`, and the stage-derived hints are inserted afterwards,
+/// so they overwrite a caller-supplied entry that uses the same key.
+/// When no quantizer stage is present the compression is recorded as `Flat`.
+pub fn vector_index_details(params: &VectorIndexParams) -> prost_types::Any {
+    let metric_type = match params.metric_type {
+        DistanceType::L2 => VectorMetricType::L2,
+        DistanceType::Cosine => VectorMetricType::Cosine,
+        DistanceType::Dot => VectorMetricType::Dot,
+        DistanceType::Hamming => VectorMetricType::Hamming,
+    };
+
+    let mut target_partition_size = 0u64;
+    let mut hnsw_index_config = None;
+    let mut compression = None;
+    let mut runtime_hints: HashMap<String, String> = params.runtime_hints.clone();
+
+    for stage in &params.stages {
+        match stage {
+            StageParams::Ivf(ivf) => {
+                if let Some(tps) = ivf.target_partition_size {
+                    target_partition_size = tps as u64;
+                }
+                runtime_hints.insert("lance.ivf.max_iters".to_string(), ivf.max_iters.to_string());
+                runtime_hints.insert(
+                    "lance.ivf.sample_rate".to_string(),
+                    ivf.sample_rate.to_string(),
+                );
+                runtime_hints.insert(
+                    "lance.ivf.shuffle_partition_batches".to_string(),
+                    ivf.shuffle_partition_batches.to_string(),
+                );
+                runtime_hints.insert(
+                    "lance.ivf.shuffle_partition_concurrency".to_string(),
+                    ivf.shuffle_partition_concurrency.to_string(),
+                );
+            }
+            StageParams::Hnsw(hnsw) => {
+                hnsw_index_config = Some(hnsw.into());
+                // An unset prefetch distance is recorded as the literal "none".
+                let val = match hnsw.prefetch_distance {
+                    Some(v) => v.to_string(),
+                    None => "none".to_string(),
+                };
+                runtime_hints.insert("lance.hnsw.prefetch_distance".to_string(), val);
+            }
+            StageParams::PQ(pq) => {
+                compression = Some(Compression::Pq(pq.into()));
+                runtime_hints.insert("lance.pq.max_iters".to_string(), pq.max_iters.to_string());
+                runtime_hints.insert(
+                    "lance.pq.sample_rate".to_string(),
+                    pq.sample_rate.to_string(),
+                );
+                runtime_hints.insert(
+                    "lance.pq.kmeans_redos".to_string(),
+                    pq.kmeans_redos.to_string(),
+                );
+            }
+            StageParams::SQ(sq) => {
+                compression = Some(Compression::Sq(sq.into()));
+                runtime_hints.insert(
+                    "lance.sq.sample_rate".to_string(),
+                    sq.sample_rate.to_string(),
+                );
+            }
+            StageParams::RQ(rq) => {
+                compression = Some(Compression::Rq(rq.into()));
+            }
+        }
+    }
+
+    // Recorded unconditionally, i.e. also when `skip_transpose` is false.
+    runtime_hints.insert(
+        "lance.skip_transpose".to_string(),
+        params.skip_transpose.to_string(),
+    );
+
+    // No quantizer stage means the vectors are stored uncompressed.
+    let compression = compression.or(Some(Compression::Flat(FlatCompression {})));
+
+    let details = VectorIndexDetails {
+        metric_type: metric_type.into(),
+        target_partition_size,
+        hnsw_index_config,
+        compression,
+        runtime_hints,
+    };
+    prost_types::Any::from_msg(&details).unwrap()
+}
+
+/// Build an all-default `VectorIndexDetails` wrapped in `Any`.
+///
+/// Used where no build params are available to describe the index.
+pub fn vector_index_details_default() -> prost_types::Any {
+    let details = lance_index::pb::VectorIndexDetails::default();
+    prost_types::Any::from_msg(&details).unwrap()
+}
+
+/// Apply stored runtime hints from `VectorIndexDetails` back into build params.
+///
+/// Known `lance.*` keys are parsed and applied to the appropriate stage. Unknown
+/// keys (e.g., from other runtimes) are silently ignored. Malformed values are
+/// also silently ignored — the stage keeps its existing default.
+// TODO: wire into a general `Dataset::rebuild_index` method so users can
+// regenerate an index from its stored details (e.g. after file corruption).
+#[allow(dead_code)]
+pub fn apply_runtime_hints(hints: &HashMap<String, String>, params: &mut VectorIndexParams) {
+    // Look up `key` and parse it; a missing key and an unparsable value both
+    // yield `None`, so the caller leaves the current setting untouched.
+    fn parse<T: FromStr>(hints: &HashMap<String, String>, key: &str) -> Option<T> {
+        hints.get(key)?.parse().ok()
+    }
+
+    if let Some(v) = parse::<bool>(hints, "lance.skip_transpose") {
+        params.skip_transpose = v;
+    }
+
+    for stage in &mut params.stages {
+        match stage {
+            StageParams::Ivf(ivf) => {
+                if let Some(v) = parse(hints, "lance.ivf.max_iters") {
+                    ivf.max_iters = v;
+                }
+                if let Some(v) = parse(hints, "lance.ivf.sample_rate") {
+                    ivf.sample_rate = v;
+                }
+                if let Some(v) = parse(hints, "lance.ivf.shuffle_partition_batches") {
+                    ivf.shuffle_partition_batches = v;
+                }
+                if let Some(v) = parse(hints, "lance.ivf.shuffle_partition_concurrency") {
+                    ivf.shuffle_partition_concurrency = v;
+                }
+            }
+            StageParams::Hnsw(hnsw) => {
+                match hints
+                    .get("lance.hnsw.prefetch_distance")
+                    .map(String::as_str)
+                {
+                    // The literal "none" explicitly clears the prefetch distance.
+                    Some("none") => hnsw.prefetch_distance = None,
+                    // Any other value must parse; a malformed value is ignored
+                    // and the existing setting is kept (it previously reset the
+                    // field to `None`).
+                    Some(raw) => {
+                        if let Ok(v) = raw.parse() {
+                            hnsw.prefetch_distance = Some(v);
+                        }
+                    }
+                    None => {}
+                }
+            }
+            StageParams::PQ(pq) => {
+                if let Some(v) = parse(hints, "lance.pq.max_iters") {
+                    pq.max_iters = v;
+                }
+                if let Some(v) = parse(hints, "lance.pq.sample_rate") {
+                    pq.sample_rate = v;
+                }
+                if let Some(v) = parse(hints, "lance.pq.kmeans_redos") {
+                    pq.kmeans_redos = v;
+                }
+            }
+            StageParams::SQ(sq) => {
+                if let Some(v) = parse(hints, "lance.sq.sample_rate") {
+                    sq.sample_rate = v;
+                }
+            }
+            // No runtime hints are defined for the RQ stage.
+            StageParams::RQ(_) => {}
+        }
+    }
+}
+
+/// Reconstruct `VectorIndexParams` from a stored `VectorIndexDetails` proto.
+///
+/// Returns `None` for legacy indices (empty details) or if the proto is malformed.
+/// Runtime hints are applied on top of the reconstructed spec.
+// TODO: wire into a general `Dataset::rebuild_index` method so users can
+// regenerate an index from its stored details (e.g. after file corruption).
+#[allow(dead_code)]
+pub fn vector_params_from_details(details: &prost_types::Any) -> Option<VectorIndexParams> {
+    if details.value.is_empty() {
+        return None;
+    }
+    let d = details.to_msg::<VectorIndexDetails>().ok()?;
+
+    let metric = DistanceType::from(VectorMetricType::try_from(d.metric_type).ok()?);
+
+    let mut ivf = IvfBuildParams::default();
+    if d.target_partition_size > 0 {
+        // Checked conversion: a stored value that does not fit in `usize` is
+        // treated as a malformed proto instead of being truncated.
+        ivf.target_partition_size = Some(usize::try_from(d.target_partition_size).ok()?);
+    }
+
+    // Narrowing conversions below are checked so that out-of-range stored
+    // values yield `None` rather than a silently wrapped number.
+    let hnsw = match d.hnsw_index_config {
+        Some(h) => Some(HnswBuildParams {
+            m: h.max_connections as usize,
+            ef_construction: h.construction_ef as usize,
+            max_level: u16::try_from(h.max_level).ok()?,
+            ..Default::default()
+        }),
+        None => None,
+    };
+
+    let mut params = match (hnsw, d.compression) {
+        (None, Some(Compression::Pq(pq))) => VectorIndexParams::with_ivf_pq_params(
+            metric,
+            ivf,
+            PQBuildParams {
+                num_bits: pq.num_bits as usize,
+                num_sub_vectors: pq.num_sub_vectors as usize,
+                ..Default::default()
+            },
+        ),
+        (None, Some(Compression::Sq(sq))) => VectorIndexParams::with_ivf_sq_params(
+            metric,
+            ivf,
+            SQBuildParams {
+                num_bits: u16::try_from(sq.num_bits).ok()?,
+                ..Default::default()
+            },
+        ),
+        (None, Some(Compression::Rq(rq))) => {
+            let rotation_type =
+                match rabit_quantization::RotationType::try_from(rq.rotation_type).ok()? {
+                    rabit_quantization::RotationType::Matrix => RQRotationType::Matrix,
+                    rabit_quantization::RotationType::Fast => RQRotationType::Fast,
+                };
+            VectorIndexParams::with_ivf_rq_params(
+                metric,
+                ivf,
+                RQBuildParams::with_rotation_type(u8::try_from(rq.num_bits).ok()?, rotation_type),
+            )
+        }
+        (Some(hnsw), Some(Compression::Pq(pq))) => VectorIndexParams::with_ivf_hnsw_pq_params(
+            metric,
+            ivf,
+            hnsw,
+            PQBuildParams {
+                num_bits: pq.num_bits as usize,
+                num_sub_vectors: pq.num_sub_vectors as usize,
+                ..Default::default()
+            },
+        ),
+        (Some(hnsw), Some(Compression::Sq(sq))) => VectorIndexParams::with_ivf_hnsw_sq_params(
+            metric,
+            ivf,
+            hnsw,
+            SQBuildParams {
+                num_bits: u16::try_from(sq.num_bits).ok()?,
+                ..Default::default()
+            },
+        ),
+        // HNSW with any other (or no) compression is rebuilt as plain IVF_HNSW.
+        (Some(hnsw), _) => VectorIndexParams::ivf_hnsw(metric, ivf, hnsw),
+        // No HNSW and no quantizer (flat or unset compression): IVF_FLAT.
+        _ => VectorIndexParams::with_ivf_flat_params(metric, ivf),
+    };
+
+    apply_runtime_hints(&d.runtime_hints, &mut params);
+    Some(params)
+}
+
+/// Extract metric type from index metadata without opening the index file.
+///
+/// For newer indices with populated `VectorIndexDetails`, returns the metric type directly.
+/// For legacy indices without details, returns `None` and caller should fall back to opening the index.
+///
+/// # Arguments
+/// * `index` - The index metadata containing serialized VectorIndexDetails
+///
+/// # Returns
+/// * `Some(DistanceType)` if details are present and valid
+/// * `None` if details are absent or empty (legacy index without details), if
+///   the payload cannot be decoded, or if the stored metric value is unknown
+pub fn metric_type_from_index_metadata(index: &IndexMetadata) -> Option<DistanceType> {
+    let index_details = index.index_details.as_ref()?;
+
+    // Empty value bytes indicates legacy index that needs to be opened for details
+    if index_details.value.is_empty() {
+        return None;
+    }
+
+    let details = index_details.to_msg::<VectorIndexDetails>().ok()?;
+
+    // Try to convert the metric_type field. This works even if metric_type is 0 (L2),
+    // since L2 is a valid metric type.
+    let metric_enum = VectorMetricType::try_from(details.metric_type).ok()?;
+    Some(DistanceType::from(metric_enum))
+}
+
+/// Returns true if the proto value represents a "truly empty" VectorIndexDetails
+/// (i.e., a legacy index that was created before we populated this field).
+fn is_empty_vector_details(details: &prost_types::Any) -> bool {
+    details.value.is_empty()
+}
+
+/// Returns true if this is a vector index whose details need to be inferred from disk.
+///
+/// This covers two legacy cases:
+/// - Very old indices (<=0.19.2) where `index_details` is `None` but the indexed
+///   field is a vector type
+/// - Newer pre-details indices where `index_details` has a VectorIndexDetails
+///   type_url but empty value bytes
+pub fn needs_vector_details_inference(
+    index: &IndexMetadata,
+    schema: &lance_core::datatypes::Schema,
+) -> bool {
+    match &index.index_details {
+        Some(d) => d.type_url.ends_with("VectorIndexDetails") && d.value.is_empty(),
+        // Without details, the only signal is the field type: treat any index
+        // over a `FixedSizeList` field as a vector index.
+        None => index.fields.iter().any(|&field_id| {
+            schema
+                .field_by_id(field_id)
+                .map(|f| matches!(f.data_type(), arrow_schema::DataType::FixedSizeList(_, _)))
+                .unwrap_or(false)
+        }),
+    }
+}
+
+/// Infer missing vector index details for all indices that need it.
+///
+/// Runs inference once per unique index name, concurrently across names.
+/// Applies the inferred details back to all matching indices in the slice.
+pub async fn infer_missing_vector_details(dataset: &Dataset, indices: &mut [IndexMetadata]) {
+    let schema = dataset.schema();
+    // One representative per index name; a later entry with the same name
+    // replaces an earlier one in the map.
+    let needs_inference: HashMap<&str, &IndexMetadata> = indices
+        .iter()
+        .filter(|idx| needs_vector_details_inference(idx, schema))
+        .map(|idx| (idx.name.as_str(), idx))
+        .collect();
+    if needs_inference.is_empty() {
+        return;
+    }
+    // Inference failures are logged and dropped, leaving those indices as-is.
+    let inferred: HashMap<String, Arc<prost_types::Any>> =
+        futures::future::join_all(needs_inference.into_iter().map(
+            |(name, representative)| async move {
+                let result = infer_vector_index_details(dataset, representative).await;
+                (name.to_string(), result)
+            },
+        ))
+        .await
+        .into_iter()
+        .filter_map(|(name, result)| match result {
+            Ok(details) => Some((name, Arc::new(details))),
+            Err(err) => {
+                tracing::warn!("Could not infer vector index details for {}: {}", name, err);
+                None
+            }
+        })
+        .collect();
+    for index in indices.iter_mut() {
+        // Only fill in entries that actually lack details, so a same-named
+        // segment that already carries populated details is not overwritten
+        // by the inferred ones.
+        if !needs_vector_details_inference(index, schema) {
+            continue;
+        }
+        if let Some(details) = inferred.get(&index.name) {
+            index.index_details = Some(details.clone());
+        }
+    }
+}
+
+/// Derive a human-readable index type string from VectorIndexDetails.
+///
+/// Returns `"Vector"` when the details are empty or cannot be decoded;
+/// otherwise `IVF_[HNSW_]{FLAT,PQ,SQ,RQ}`, where unset compression counts as
+/// `FLAT`.
+pub fn derive_vector_index_type(details: &prost_types::Any) -> String {
+    if is_empty_vector_details(details) {
+        return "Vector".to_string();
+    }
+
+    let Ok(d) = details.to_msg::<VectorIndexDetails>() else {
+        return "Vector".to_string();
+    };
+    let mut index_type = "IVF_".to_string();
+    if d.hnsw_index_config.is_some() {
+        index_type.push_str("HNSW_");
+    }
+    let suffix = match d.compression {
+        None | Some(Compression::Flat(_)) => "FLAT",
+        Some(Compression::Pq(_)) => "PQ",
+        Some(Compression::Sq(_)) => "SQ",
+        Some(Compression::Rq(_)) => "RQ",
+    };
+    index_type.push_str(suffix);
+    index_type
+}
+
+/// Serialize VectorIndexDetails as a JSON string.
+pub fn vector_details_as_json(details: &prost_types::Any) -> Result<String> {
+    // Empty details (legacy index) are reported as an empty JSON object.
+    if is_empty_vector_details(details) {
+        return Ok("{}".to_string());
+    }
+
+    let d = details
+        .to_msg::<VectorIndexDetails>()
+        .map_err(|e| Error::index(format!("Failed to deserialize VectorIndexDetails: {}", e)))?;
+
+    // An unrecognized metric value is reported as "UNKNOWN" rather than failing.
+    let metric_type = match VectorMetricType::try_from(d.metric_type) {
+        Ok(VectorMetricType::L2) => "L2",
+        Ok(VectorMetricType::Cosine) => "COSINE",
+        Ok(VectorMetricType::Dot) => "DOT",
+        Ok(VectorMetricType::Hamming) => "HAMMING",
+        Err(_) => "UNKNOWN",
+    };
+
+    let hnsw = d.hnsw_index_config.map(|h| HnswDetailsJson {
+        max_connections: h.max_connections,
+        construction_ef: h.construction_ef,
+        max_level: h.max_level,
+    });
+
+    // Flat compression carries no parameters, so it is left out of the JSON.
+    let compression = d.compression.and_then(|c| match c {
+        Compression::Flat(_) => None,
+        Compression::Pq(pq) => Some(CompressionDetailsJson::Pq {
+            num_bits: pq.num_bits,
+            num_sub_vectors: pq.num_sub_vectors,
+        }),
+        Compression::Sq(sq) => Some(CompressionDetailsJson::Sq {
+            num_bits: sq.num_bits,
+        }),
+        Compression::Rq(rq) => {
+            // Anything other than a recognized `Matrix` value, including an
+            // unrecognized enum number, is reported as "fast".
+            let rotation_type = match rabit_quantization::RotationType::try_from(rq.rotation_type) {
+                Ok(rabit_quantization::RotationType::Matrix) => "matrix",
+                _ => "fast",
+            };
+            Some(CompressionDetailsJson::Rq {
+                num_bits: rq.num_bits,
+                rotation_type,
+            })
+        }
+    });
+
+    let json = VectorDetailsJson {
+        metric_type,
+        // Zero means "not set" and is omitted from the output.
+        target_partition_size: if d.target_partition_size > 0 {
+            Some(d.target_partition_size)
+        } else {
+            None
+        },
+        hnsw,
+        compression,
+        runtime_hints: d.runtime_hints,
+    };
+
+    serde_json::to_string(&json).map_err(|e| Error::index(format!("Failed to serialize: {}", e)))
+}
+
+/// Infer VectorIndexDetails from index files on disk.
+/// Used as a fallback for legacy indices where the manifest doesn't have populated details.
+pub async fn infer_vector_index_details(
+    dataset: &Dataset,
+    index: &IndexMetadata,
+) -> Result {
+    // NOTE(review): type arguments (on `Result`, `Arc`, `Option`, `Vec` and
+    // turbofish calls) are missing throughout this copy of the patch —
+    // presumably lost in extraction; restore from upstream before compiling.
+    let uuid = index.uuid.to_string();
+    let index_dir = dataset.indice_files_dir(index)?;
+    let file_dir = index_dir.clone().join(uuid.as_str());
+    let index_file = file_dir.clone().join(INDEX_FILE_NAME);
+    let reader: Arc = dataset.object_store.open(&index_file).await?.into();
+
+    // The file-format version is read from the last block of the index file
+    // and selects which on-disk layout to decode.
+    let tailing_bytes = read_last_block(reader.as_ref()).await?;
+    let (major_version, minor_version) = read_version(&tailing_bytes)?;
+
+    match (major_version, minor_version) {
+        (0, 1) | (0, 0) => {
+            // Legacy v0.1: read pb::Index, extract VectorIndex stages
+            let proto = open_index_proto(reader.as_ref()).await?;
+            convert_legacy_proto_to_details(&proto)
+        }
+        _ => {
+            // v0.2+/v0.3: read lance file schema metadata
+            convert_v3_metadata_to_details(dataset, &file_dir).await
+        }
+    }
+}
+
+/// Build `VectorIndexDetails` from a legacy `pb::Index` proto.
+///
+/// Returns the all-default details when the proto is not a vector index.
+/// Only the metric type and a PQ stage (if any) are carried over; HNSW config,
+/// target partition size and runtime hints are left unset. An unrecognized
+/// metric value falls back to L2.
+fn convert_legacy_proto_to_details(proto: &pb::Index) -> Result {
+    use lance_index::pb::VectorIndexDetails;
+    use lance_index::pb::vector_index_details::*;
+    use pb::vector_index_stage::Stage;
+
+    let Some(Implementation::VectorIndex(vector_index)) = &proto.implementation else {
+        return Ok(vector_index_details_default());
+    };
+
+    let metric_type = pb::VectorMetricType::try_from(vector_index.metric_type)
+        .unwrap_or(pb::VectorMetricType::L2);
+
+    // If several PQ stages are present, the last one wins.
+    let mut compression: Option = None;
+    for stage in &vector_index.stages {
+        if let Some(Stage::Pq(pq)) = &stage.stage {
+            compression = Some(Compression::Pq(ProductQuantization {
+                num_bits: pq.num_bits,
+                num_sub_vectors: pq.num_sub_vectors,
+            }));
+        }
+    }
+    // No PQ stage found: record the index as flat (uncompressed).
+    let compression = compression.or(Some(Compression::Flat(FlatCompression {})));
+
+    let details = VectorIndexDetails {
+        metric_type: metric_type.into(),
+        target_partition_size: 0,
+        hnsw_index_config: None,
+        compression,
+        runtime_hints: Default::default(),
+    };
+    Ok(prost_types::Any::from_msg(&details).unwrap())
+}
+
+/// Build `VectorIndexDetails` from the schema metadata of the index files
+/// under `file_dir`.
+///
+/// Reads the main index file for the index type, distance type and (for HNSW
+/// variants) the HNSW parameters, and the auxiliary file for quantizer
+/// parameters. Target partition size and runtime hints are not recoverable
+/// here and are left unset.
+async fn convert_v3_metadata_to_details(
+    dataset: &Dataset,
+    file_dir: &object_store::path::Path,
+) -> Result {
+    use lance_index::INDEX_AUXILIARY_FILE_NAME;
+    use lance_index::pb::vector_index_details::*;
+    use lance_index::pb::{HnswParameters, VectorIndexDetails};
+    use lance_index::vector::bq::storage::RabitQuantizationMetadata;
+    use lance_index::vector::hnsw::HnswMetadata;
+    use lance_index::vector::hnsw::builder::HNSW_METADATA_KEY;
+    use lance_index::vector::pq::storage::ProductQuantizationMetadata;
+    use lance_index::vector::shared::partition_merger::SupportedIvfIndexType;
+    use lance_index::vector::sq::storage::ScalarQuantizationMetadata;
+    use lance_index::vector::storage::STORAGE_METADATA_KEY;
+
+    let index_file = file_dir.clone().join(INDEX_FILE_NAME);
+    let main_reader = open_lance_file(dataset, &index_file).await?;
+    let main_meta = &main_reader.schema().metadata;
+
+    // Index type and distance live in the main file's INDEX_METADATA_SCHEMA_KEY.
+    // NOTE(review): the deserialization target type of `idx_meta` is missing in
+    // this copy; it needs `distance_type` and `index_type` string fields —
+    // confirm the exact type against upstream.
+    let idx_meta: Option = main_meta
+        .get(INDEX_METADATA_SCHEMA_KEY)
+        .map(|s| serde_json::from_str(s))
+        .transpose()?;
+
+    // Missing metadata or an unrecognized distance string falls back to L2.
+    let metric_type = idx_meta
+        .as_ref()
+        .map(|m| match m.distance_type.to_uppercase().as_str() {
+            "L2" | "EUCLIDEAN" => VectorMetricType::L2,
+            "COSINE" => VectorMetricType::Cosine,
+            "DOT" => VectorMetricType::Dot,
+            "HAMMING" => VectorMetricType::Hamming,
+            _ => VectorMetricType::L2,
+        })
+        .unwrap_or(VectorMetricType::L2);
+
+    // The index_type string drives both whether HNSW is present and which
+    // compression to expect. Falls back to IvfFlat if the metadata is missing
+    // or unrecognized.
+    let supported_type = idx_meta
+        .as_ref()
+        .and_then(|m| SupportedIvfIndexType::from_index_type_str(&m.index_type))
+        .unwrap_or(SupportedIvfIndexType::IvfFlat);
+    let (has_hnsw, compression_kind) = match supported_type {
+        SupportedIvfIndexType::IvfFlat => (false, CompressionKind::Flat),
+        SupportedIvfIndexType::IvfPq => (false, CompressionKind::Pq),
+        SupportedIvfIndexType::IvfSq => (false, CompressionKind::Sq),
+        SupportedIvfIndexType::IvfRq => (false, CompressionKind::Rq),
+        SupportedIvfIndexType::IvfHnswFlat => (true, CompressionKind::Flat),
+        SupportedIvfIndexType::IvfHnswPq => (true, CompressionKind::Pq),
+        SupportedIvfIndexType::IvfHnswSq => (true, CompressionKind::Sq),
+    };
+
+    let hnsw_index_config = if has_hnsw {
+        // HNSW partition metadata is stored as a JSON array of JSON-encoded
+        // strings (one per partition), matching how the builder writes
+        // `partition_index_metadata: Vec<String>`.
+        // Only the first partition's entry is decoded; a missing key or an
+        // empty array leaves the HNSW config unset.
+        main_meta
+            .get(HNSW_METADATA_KEY)
+            .map(|s| serde_json::from_str::>(s))
+            .transpose()?
+            .and_then(|entries| entries.into_iter().next())
+            .map(|s| serde_json::from_str::(&s))
+            .transpose()?
+            .map(|hnsw| HnswParameters {
+                max_connections: hnsw.params.m as u32,
+                construction_ef: hnsw.params.ef_construction as u32,
+                max_level: hnsw.params.max_level as u32,
+            })
+    } else {
+        None
+    };
+
+    // For quantized indices, the per-quantizer metadata is in the auxiliary
+    // file under STORAGE_METADATA_KEY (a JSON-encoded Vec<String>, one entry
+    // per partition; all entries currently share the same metadata so we read
+    // the first).
+    let compression = match compression_kind {
+        CompressionKind::Flat => Some(Compression::Flat(FlatCompression {})),
+        CompressionKind::Pq | CompressionKind::Sq | CompressionKind::Rq => {
+            let aux_file = file_dir.clone().join(INDEX_AUXILIARY_FILE_NAME);
+            let aux_reader = open_lance_file(dataset, &aux_file).await?;
+            // Unlike the HNSW metadata above, a missing key or an empty list
+            // here is an error.
+            let raw = aux_reader
+                .schema()
+                .metadata
+                .get(STORAGE_METADATA_KEY)
+                .ok_or_else(|| {
+                    Error::index(format!(
+                        "auxiliary file missing {STORAGE_METADATA_KEY} metadata"
+                    ))
+                })?;
+            let entries: Vec = serde_json::from_str(raw)?;
+            let first = entries.first().ok_or_else(|| {
+                Error::index("auxiliary STORAGE_METADATA_KEY was empty".to_string())
+            })?;
+            match compression_kind {
+                CompressionKind::Pq => {
+                    let pq: ProductQuantizationMetadata = serde_json::from_str(first)?;
+                    Some(Compression::Pq(ProductQuantization {
+                        num_bits: pq.nbits,
+                        num_sub_vectors: pq.num_sub_vectors as u32,
+                    }))
+                }
+                CompressionKind::Sq => {
+                    let sq: ScalarQuantizationMetadata = serde_json::from_str(first)?;
+                    Some(Compression::Sq(ScalarQuantization {
+                        num_bits: sq.num_bits as u32,
+                    }))
+                }
+                CompressionKind::Rq => {
+                    let rq: RabitQuantizationMetadata = serde_json::from_str(first)?;
+                    let rotation_type = match rq.rotation_type {
+                        lance_index::vector::bq::RQRotationType::Fast => {
+                            rabit_quantization::RotationType::Fast
+                        }
+                        lance_index::vector::bq::RQRotationType::Matrix => {
+                            rabit_quantization::RotationType::Matrix
+                        }
+                    };
+                    Some(Compression::Rq(RabitQuantization {
+                        num_bits: rq.num_bits as u32,
+                        rotation_type: rotation_type.into(),
+                    }))
+                }
+                // `Flat` is handled by the outer match arm above.
+                CompressionKind::Flat => unreachable!(),
+            }
+        }
+    };
+
+    let details = VectorIndexDetails {
+        metric_type: metric_type.into(),
+        target_partition_size: 0,
+        hnsw_index_config,
+        compression,
+        runtime_hints: Default::default(),
+    };
+    Ok(prost_types::Any::from_msg(&details).unwrap())
+}
+
+/// Which quantizer (if any) an index type implies; used by
+/// `convert_v3_metadata_to_details` to decide which metadata to decode.
+enum CompressionKind {
+    Flat,
+    Pq,
+    Sq,
+    Rq,
+}
+
+/// Open the Lance file at `path` through a fresh `ScanScheduler` on the
+/// dataset's object store, using the dataset's file metadata cache.
+async fn open_lance_file(
+    dataset: &Dataset,
+    path: &object_store::path::Path,
+) -> Result {
+    let scheduler = ScanScheduler::new(
+        dataset.object_store.clone(),
+        SchedulerConfig::max_bandwidth(&dataset.object_store),
+    );
+    let file = scheduler
+        .open_file(path, &CachedFileSize::unknown())
+        .await?;
+    lance_file::reader::FileReader::try_open(
+        file,
+        None,
+        Default::default(),
+        &dataset.metadata_cache.file_metadata_cache(path),
+        FileReaderOptions::default(),
+    )
+    .await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use lance_index::pb::vector_index_details::*;
+    use lance_index::pb::{HnswParameters, VectorIndexDetails};
+
+    // Test helper: encode details with the given metric, HNSW config and
+    // compression; partition size and runtime hints are left at their defaults.
+    fn make_details(
+        metric: VectorMetricType,
+        hnsw: Option,
+        compression: Option,
+    ) -> prost_types::Any {
+        let details = VectorIndexDetails {
+            metric_type: metric.into(),
+            target_partition_size: 0,
+            hnsw_index_config: hnsw,
+            compression,
+            runtime_hints: Default::default(),
+        };
+        prost_types::Any::from_msg(&details).unwrap()
+    }
+
+    #[test]
+    fn test_derive_index_type_without_hnsw() {
+        // Note: (None, "IVF_FLAT") is not testable here because a proto with
+        // all defaults serializes to empty bytes, which is treated as a legacy index.
+        let cases: [(Option, &str); 3] = [
+            (
+                Some(Compression::Pq(ProductQuantization {
+                    num_bits: 8,
+                    num_sub_vectors: 16,
+                })),
+                "IVF_PQ",
+            ),
+            (
+                Some(Compression::Sq(ScalarQuantization { num_bits: 8 })),
+                "IVF_SQ",
+            ),
+            (
+                Some(Compression::Rq(RabitQuantization {
+                    num_bits: 1,
+                    rotation_type: 0,
+                })),
+                "IVF_RQ",
+            ),
+        ];
+        for (compression, expected) in cases {
+            let details = make_details(VectorMetricType::L2, None, compression);
+            assert_eq!(derive_vector_index_type(&details), expected);
+        }
+    }
+
+    #[test]
+    fn test_derive_index_type_with_hnsw() {
+        // A populated HNSW config makes the details non-empty, so the
+        // no-compression (FLAT) case is testable here.
+        let hnsw = Some(HnswParameters {
+            max_connections: 20,
+            construction_ef: 150,
+            max_level: 7,
+        });
+        assert_eq!(
+            derive_vector_index_type(&make_details(VectorMetricType::L2, hnsw, None)),
+            "IVF_HNSW_FLAT"
+        );
+        assert_eq!(
+            derive_vector_index_type(&make_details(
+                VectorMetricType::L2,
+                hnsw,
+                Some(Compression::Pq(ProductQuantization {
+                    num_bits: 8,
+                    num_sub_vectors: 16,
+                }))
+            )),
+            "IVF_HNSW_PQ"
+        );
+        assert_eq!(
+            derive_vector_index_type(&make_details(
+                VectorMetricType::L2,
+                hnsw,
+                Some(Compression::Sq(ScalarQuantization { num_bits: 8 }))
+            )),
+            "IVF_HNSW_SQ"
+        );
+    }
+
+    #[test]
+    fn test_derive_index_type_empty_details() {
+        // Default (empty) details yield the generic "Vector" label.
+        let details = vector_index_details_default();
+        assert_eq!(derive_vector_index_type(&details), "Vector");
+    }
+
+    // Snapshot tests for JSON serialization. These guard backwards compatibility
+    // of the JSON format returned by describe_indices().
+ + #[test] + fn test_json_ivf_pq() { + let details = make_details( + VectorMetricType::L2, + None, + Some(Compression::Pq(ProductQuantization { + num_bits: 8, + num_sub_vectors: 16, + })), + ); + assert_eq!( + vector_details_as_json(&details).unwrap(), + r#"{"metric_type":"L2","compression":{"type":"pq","num_bits":8,"num_sub_vectors":16}}"# + ); + } + + #[test] + fn test_json_ivf_hnsw_sq() { + let details = make_details( + VectorMetricType::Cosine, + Some(HnswParameters { + max_connections: 30, + construction_ef: 200, + max_level: 8, + }), + Some(Compression::Sq(ScalarQuantization { num_bits: 4 })), + ); + assert_eq!( + vector_details_as_json(&details).unwrap(), + r#"{"metric_type":"COSINE","hnsw":{"max_connections":30,"construction_ef":200,"max_level":8},"compression":{"type":"sq","num_bits":4}}"# + ); + } + + #[test] + fn test_json_ivf_rq_with_rotation() { + let details = make_details( + VectorMetricType::Dot, + None, + Some(Compression::Rq(RabitQuantization { + num_bits: 1, + rotation_type: rabit_quantization::RotationType::Matrix as i32, + })), + ); + assert_eq!( + vector_details_as_json(&details).unwrap(), + r#"{"metric_type":"DOT","compression":{"type":"rq","num_bits":1,"rotation_type":"matrix"}}"# + ); + } + + #[test] + fn test_json_ivf_rq_fast_rotation() { + let details = make_details( + VectorMetricType::L2, + None, + Some(Compression::Rq(RabitQuantization { + num_bits: 1, + rotation_type: rabit_quantization::RotationType::Fast as i32, + })), + ); + assert_eq!( + vector_details_as_json(&details).unwrap(), + r#"{"metric_type":"L2","compression":{"type":"rq","num_bits":1,"rotation_type":"fast"}}"# + ); + } + + #[test] + fn test_json_with_target_partition_size() { + let details = { + let d = VectorIndexDetails { + metric_type: VectorMetricType::L2.into(), + target_partition_size: 5000, + hnsw_index_config: None, + compression: None, + runtime_hints: Default::default(), + }; + prost_types::Any::from_msg(&d).unwrap() + }; + assert_eq!( + 
vector_details_as_json(&details).unwrap(), + r#"{"metric_type":"L2","target_partition_size":5000}"# + ); + } + + #[test] + fn test_json_empty_details() { + let details = vector_index_details_default(); + assert_eq!(vector_details_as_json(&details).unwrap(), "{}"); + } + + #[test] + fn test_metric_type_from_index_metadata_populated() { + // Test that populated details return the metric type. + // Note: We add a non-default compression field so the proto doesn't serialize to empty bytes. + let details = make_details( + VectorMetricType::L2, + None, + Some(Compression::Pq(ProductQuantization { + num_bits: 8, + num_sub_vectors: 16, + })), + ); + let index_details = Some(std::sync::Arc::new(details)); + let index = IndexMetadata { + uuid: uuid::Uuid::new_v4(), + fields: vec![0], + name: "test_index".to_string(), + dataset_version: 1, + fragment_bitmap: None, + index_details, + index_version: 1, + created_at: None, + base_id: None, + files: None, + }; + + let metric = metric_type_from_index_metadata(&index); + assert_eq!(metric, Some(DistanceType::L2)); + } + + #[test] + fn test_metric_type_from_index_metadata_empty() { + // Test that empty details return None (legacy index) + let details = vector_index_details_default(); + let index_details = Some(std::sync::Arc::new(details)); + let index = IndexMetadata { + uuid: uuid::Uuid::new_v4(), + fields: vec![0], + name: "test_index".to_string(), + dataset_version: 1, + fragment_bitmap: None, + index_details, + index_version: 1, + created_at: None, + base_id: None, + files: None, + }; + + let metric = metric_type_from_index_metadata(&index); + assert_eq!(metric, None); + } + + #[test] + fn test_metric_type_from_index_metadata_none() { + // Test that missing details return None + let index = IndexMetadata { + uuid: uuid::Uuid::new_v4(), + fields: vec![0], + name: "test_index".to_string(), + dataset_version: 1, + fragment_bitmap: None, + index_details: None, + index_version: 1, + created_at: None, + base_id: None, + files: None, 
+ }; + + let metric = metric_type_from_index_metadata(&index); + assert_eq!(metric, None); + } + + #[test] + fn test_metric_type_from_index_metadata_all_metrics() { + // Test all supported metric types. + // Note: We add a non-default compression field so the proto doesn't serialize to empty bytes. + let metrics = [ + VectorMetricType::L2, + VectorMetricType::Cosine, + VectorMetricType::Dot, + VectorMetricType::Hamming, + ]; + let expected = [ + DistanceType::L2, + DistanceType::Cosine, + DistanceType::Dot, + DistanceType::Hamming, + ]; + + for (metric_enum, expected_distance) in metrics.iter().zip(expected.iter()) { + let details = make_details( + *metric_enum, + None, + Some(Compression::Sq(ScalarQuantization { num_bits: 8 })), + ); + let index_details = Some(std::sync::Arc::new(details)); + let index = IndexMetadata { + uuid: uuid::Uuid::new_v4(), + fields: vec![0], + name: "test_index".to_string(), + dataset_version: 1, + fragment_bitmap: None, + index_details, + index_version: 1, + created_at: None, + base_id: None, + files: None, + }; + + let metric = metric_type_from_index_metadata(&index); + assert_eq!(metric, Some(*expected_distance)); + } + } + + #[test] + fn test_runtime_hints_roundtrip() { + use crate::index::vector::{StageParams, VectorIndexParams}; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_index::vector::pq::builder::PQBuildParams; + use lance_linalg::distance::DistanceType; + + // Non-default values for IVF and PQ hints + let params = VectorIndexParams::with_ivf_pq_params( + DistanceType::L2, + IvfBuildParams { + max_iters: 100, + sample_rate: 512, + shuffle_partition_batches: 2048, + shuffle_partition_concurrency: 4, + ..Default::default() + }, + PQBuildParams { + num_sub_vectors: 8, + num_bits: 8, + max_iters: 75, + kmeans_redos: 3, + sample_rate: 128, + ..Default::default() + }, + ); + + let any = vector_index_details(¶ms); + let details = any.to_msg::().unwrap(); + assert_eq!( + details + .runtime_hints + 
.get("lance.ivf.max_iters") + .map(|s| s.as_str()), + Some("100") + ); + assert_eq!( + details + .runtime_hints + .get("lance.ivf.sample_rate") + .map(|s| s.as_str()), + Some("512") + ); + assert_eq!( + details + .runtime_hints + .get("lance.ivf.shuffle_partition_batches") + .map(|s| s.as_str()), + Some("2048") + ); + assert_eq!( + details + .runtime_hints + .get("lance.ivf.shuffle_partition_concurrency") + .map(|s| s.as_str()), + Some("4") + ); + assert_eq!( + details + .runtime_hints + .get("lance.pq.max_iters") + .map(|s| s.as_str()), + Some("75") + ); + assert_eq!( + details + .runtime_hints + .get("lance.pq.sample_rate") + .map(|s| s.as_str()), + Some("128") + ); + assert_eq!( + details + .runtime_hints + .get("lance.pq.kmeans_redos") + .map(|s| s.as_str()), + Some("3") + ); + // No HNSW stage in this IVF+PQ params, so no prefetch_distance hint. + assert!( + !details + .runtime_hints + .contains_key("lance.hnsw.prefetch_distance") + ); + // skip_transpose is recorded even when false. 
+ assert_eq!( + details.runtime_hints.get("lance.skip_transpose"), + Some(&"false".to_string()) + ); + + // Roundtrip: apply hints back to a fresh params struct + let mut restored = VectorIndexParams::with_ivf_pq_params( + DistanceType::L2, + IvfBuildParams::default(), + PQBuildParams { + num_sub_vectors: 8, + num_bits: 8, + ..Default::default() + }, + ); + apply_runtime_hints(&details.runtime_hints, &mut restored); + let StageParams::Ivf(ivf) = &restored.stages[0] else { + panic!() + }; + assert_eq!(ivf.max_iters, 100); + assert_eq!(ivf.sample_rate, 512); + assert_eq!(ivf.shuffle_partition_batches, 2048); + assert_eq!(ivf.shuffle_partition_concurrency, 4); + let StageParams::PQ(pq) = &restored.stages[1] else { + panic!() + }; + assert_eq!(pq.max_iters, 75); + assert_eq!(pq.sample_rate, 128); + assert_eq!(pq.kmeans_redos, 3); + } + + #[test] + fn test_runtime_hints_roundtrip_hnsw_sq_skip_transpose() { + use crate::index::vector::{StageParams, VectorIndexParams}; + use lance_index::vector::hnsw::builder::HnswBuildParams; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_index::vector::sq::builder::SQBuildParams; + use lance_linalg::distance::DistanceType; + + // Non-default values for hints that aren't covered by the IVF+PQ test: + // hnsw.prefetch_distance, sq.sample_rate, and the top-level skip_transpose. 
+ let hnsw = HnswBuildParams { + m: 20, + ef_construction: 150, + max_level: 6, + prefetch_distance: Some(4), + }; + let mut params = VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::L2, + IvfBuildParams::default(), + hnsw, + SQBuildParams { + num_bits: 8, + sample_rate: 128, + }, + ); + params.skip_transpose = true; + + let any = vector_index_details(¶ms); + let details = any.to_msg::().unwrap(); + assert_eq!( + details.runtime_hints.get("lance.hnsw.prefetch_distance"), + Some(&"4".to_string()) + ); + assert_eq!( + details.runtime_hints.get("lance.sq.sample_rate"), + Some(&"128".to_string()) + ); + assert_eq!( + details.runtime_hints.get("lance.skip_transpose"), + Some(&"true".to_string()) + ); + + // Roundtrip back into a fresh params struct + let mut restored = VectorIndexParams::with_ivf_hnsw_sq_params( + DistanceType::L2, + IvfBuildParams::default(), + HnswBuildParams::default(), + SQBuildParams::default(), + ); + assert!(!restored.skip_transpose); + apply_runtime_hints(&details.runtime_hints, &mut restored); + + assert!(restored.skip_transpose); + let StageParams::Hnsw(restored_hnsw) = &restored.stages[1] else { + panic!("expected HNSW stage"); + }; + assert_eq!(restored_hnsw.prefetch_distance, Some(4)); + let StageParams::SQ(restored_sq) = &restored.stages[2] else { + panic!("expected SQ stage"); + }; + assert_eq!(restored_sq.sample_rate, 128); + } + + #[test] + fn test_runtime_hints_prefetch_distance_none_roundtrip() { + use crate::index::vector::{StageParams, VectorIndexParams}; + use lance_index::vector::hnsw::builder::HnswBuildParams; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_linalg::distance::DistanceType; + + // prefetch_distance = None is distinguishable from "default Some(2)" + // via the "none" sentinel — verify it round-trips. 
+ let hnsw = HnswBuildParams { + m: 16, + ef_construction: 100, + max_level: 5, + prefetch_distance: None, + }; + let params = VectorIndexParams::ivf_hnsw(DistanceType::L2, IvfBuildParams::default(), hnsw); + + let any = vector_index_details(¶ms); + let details = any.to_msg::().unwrap(); + assert_eq!( + details.runtime_hints.get("lance.hnsw.prefetch_distance"), + Some(&"none".to_string()) + ); + + let mut restored = VectorIndexParams::ivf_hnsw( + DistanceType::L2, + IvfBuildParams::default(), + HnswBuildParams::default(), + ); + apply_runtime_hints(&details.runtime_hints, &mut restored); + let StageParams::Hnsw(restored_hnsw) = &restored.stages[1] else { + panic!("expected HNSW stage"); + }; + assert_eq!(restored_hnsw.prefetch_distance, None); + } + + #[test] + fn test_runtime_hints_in_json() { + use crate::index::vector::VectorIndexParams; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_index::vector::pq::builder::PQBuildParams; + use lance_linalg::distance::DistanceType; + + let params = VectorIndexParams::with_ivf_pq_params( + DistanceType::L2, + IvfBuildParams { + max_iters: 100, + ..Default::default() + }, + PQBuildParams { + num_sub_vectors: 8, + num_bits: 8, + ..Default::default() + }, + ); + let any = vector_index_details(¶ms); + let json = vector_details_as_json(&any).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["runtime_hints"]["lance.ivf.max_iters"], "100"); + } + + /// Matrix of subindex/quantizer combinations we want to round-trip. 
+ #[derive(Debug, Clone, Copy)] + #[allow(clippy::enum_variant_names)] + enum Combo { + IvfFlat, + IvfPq, + IvfSq, + IvfRqMatrix, + IvfRqFast, + IvfHnswFlat, + IvfHnswPq, + IvfHnswSq, + } + + fn build_roundtrip_params(combo: Combo, metric: DistanceType) -> VectorIndexParams { + use crate::index::vector::VectorIndexParams; + use lance_index::vector::bq::{RQBuildParams, RQRotationType}; + use lance_index::vector::hnsw::builder::HnswBuildParams; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_index::vector::pq::builder::PQBuildParams; + use lance_index::vector::sq::builder::SQBuildParams; + + // Non-default values so the round-trip actually checks preservation + // rather than coincidentally matching defaults. + let ivf = IvfBuildParams { + max_iters: 100, + sample_rate: 512, + target_partition_size: Some(2048), + shuffle_partition_batches: 4096, + shuffle_partition_concurrency: 4, + ..Default::default() + }; + let hnsw = HnswBuildParams { + m: 30, + ef_construction: 200, + max_level: 5, + prefetch_distance: Some(2), + }; + let pq = PQBuildParams { + num_sub_vectors: 8, + num_bits: 8, + max_iters: 75, + sample_rate: 128, + kmeans_redos: 3, + ..Default::default() + }; + let sq = SQBuildParams { + num_bits: 8, + sample_rate: 128, + }; + + match combo { + Combo::IvfFlat => VectorIndexParams::with_ivf_flat_params(metric, ivf), + Combo::IvfPq => VectorIndexParams::with_ivf_pq_params(metric, ivf, pq), + Combo::IvfSq => VectorIndexParams::with_ivf_sq_params(metric, ivf, sq), + Combo::IvfRqMatrix => VectorIndexParams::with_ivf_rq_params( + metric, + ivf, + RQBuildParams::with_rotation_type(1, RQRotationType::Matrix), + ), + Combo::IvfRqFast => VectorIndexParams::with_ivf_rq_params( + metric, + ivf, + RQBuildParams::with_rotation_type(1, RQRotationType::Fast), + ), + Combo::IvfHnswFlat => VectorIndexParams::ivf_hnsw(metric, ivf, hnsw), + Combo::IvfHnswPq => VectorIndexParams::with_ivf_hnsw_pq_params(metric, ivf, hnsw, pq), + Combo::IvfHnswSq => 
VectorIndexParams::with_ivf_hnsw_sq_params(metric, ivf, hnsw, sq), + } + } + + #[rstest::rstest] + #[case::ivf_flat(Combo::IvfFlat)] + #[case::ivf_pq(Combo::IvfPq)] + #[case::ivf_sq(Combo::IvfSq)] + #[case::ivf_rq_matrix(Combo::IvfRqMatrix)] + #[case::ivf_rq_fast(Combo::IvfRqFast)] + #[case::ivf_hnsw_flat(Combo::IvfHnswFlat)] + #[case::ivf_hnsw_pq(Combo::IvfHnswPq)] + #[case::ivf_hnsw_sq(Combo::IvfHnswSq)] + fn test_vector_index_details_roundtrip( + #[case] combo: Combo, + #[values(DistanceType::L2, DistanceType::Cosine)] metric: DistanceType, + ) { + use crate::index::vector::StageParams; + use lance_index::vector::bq::RQRotationType; + + let params = build_roundtrip_params(combo, metric); + + let any = vector_index_details(¶ms); + let restored = vector_params_from_details(&any) + .expect("non-empty details should round-trip to params"); + + assert_eq!(restored.metric_type, metric); + assert_eq!(restored.index_type(), params.index_type()); + + let StageParams::Ivf(ivf) = &restored.stages[0] else { + panic!("first stage should be IVF for combo {:?}", combo); + }; + assert_eq!(ivf.max_iters, 100); + assert_eq!(ivf.sample_rate, 512); + assert_eq!(ivf.target_partition_size, Some(2048)); + assert_eq!(ivf.shuffle_partition_batches, 4096); + assert_eq!(ivf.shuffle_partition_concurrency, 4); + + match combo { + Combo::IvfFlat => { + assert_eq!(restored.stages.len(), 1); + } + Combo::IvfPq => { + let StageParams::PQ(pq) = &restored.stages[1] else { + panic!("expected PQ stage"); + }; + assert_eq!(pq.num_sub_vectors, 8); + assert_eq!(pq.num_bits, 8); + assert_eq!(pq.max_iters, 75); + assert_eq!(pq.sample_rate, 128); + assert_eq!(pq.kmeans_redos, 3); + } + Combo::IvfSq => { + let StageParams::SQ(sq) = &restored.stages[1] else { + panic!("expected SQ stage"); + }; + assert_eq!(sq.num_bits, 8); + assert_eq!(sq.sample_rate, 128); + } + Combo::IvfRqMatrix | Combo::IvfRqFast => { + let StageParams::RQ(rq) = &restored.stages[1] else { + panic!("expected RQ stage"); + }; + 
assert_eq!(rq.num_bits, 1); + let expected = match combo { + Combo::IvfRqMatrix => RQRotationType::Matrix, + Combo::IvfRqFast => RQRotationType::Fast, + _ => unreachable!(), + }; + assert_eq!(rq.rotation_type, expected); + } + Combo::IvfHnswFlat => { + let StageParams::Hnsw(hnsw) = &restored.stages[1] else { + panic!("expected HNSW stage"); + }; + assert_eq!(hnsw.m, 30); + assert_eq!(hnsw.ef_construction, 200); + assert_eq!(hnsw.max_level, 5); + } + Combo::IvfHnswPq => { + let StageParams::Hnsw(hnsw) = &restored.stages[1] else { + panic!("expected HNSW stage"); + }; + assert_eq!(hnsw.m, 30); + assert_eq!(hnsw.ef_construction, 200); + assert_eq!(hnsw.max_level, 5); + let StageParams::PQ(pq) = &restored.stages[2] else { + panic!("expected PQ stage"); + }; + assert_eq!(pq.num_sub_vectors, 8); + assert_eq!(pq.num_bits, 8); + } + Combo::IvfHnswSq => { + let StageParams::Hnsw(hnsw) = &restored.stages[1] else { + panic!("expected HNSW stage"); + }; + assert_eq!(hnsw.m, 30); + assert_eq!(hnsw.ef_construction, 200); + assert_eq!(hnsw.max_level, 5); + let StageParams::SQ(sq) = &restored.stages[2] else { + panic!("expected SQ stage"); + }; + assert_eq!(sq.num_bits, 8); + } + } + } + + #[test] + fn test_apply_runtime_hints_ignores_unknown_keys() { + use crate::index::vector::VectorIndexParams; + use lance_index::vector::ivf::builder::IvfBuildParams; + use lance_linalg::distance::DistanceType; + + let hints: HashMap = [ + ("lancedb.accelerator".to_string(), "cuda".to_string()), + ("unknown.vendor.key".to_string(), "value".to_string()), + ("lance.ivf.max_iters".to_string(), "99".to_string()), + ] + .into(); + + let mut params = + VectorIndexParams::with_ivf_flat_params(DistanceType::L2, IvfBuildParams::default()); + apply_runtime_hints(&hints, &mut params); + + let StageParams::Ivf(ivf) = ¶ms.stages[0] else { + panic!() + }; + assert_eq!(ivf.max_iters, 99); + // Unknown keys silently ignored — no panic + } +} diff --git a/rust/lance/src/index/vector/ivf.rs 
b/rust/lance/src/index/vector/ivf.rs index aeeb65173b3..179ec96df4c 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2251,7 +2251,7 @@ pub(crate) async fn merge_segments_with_progress( merged_segment = TableIndexMetadata { uuid: segment_uuid, fragment_bitmap: Some(fragment_bitmap), - index_details: Some(Arc::new(crate::index::vector_index_details())), + index_details: Some(Arc::new(crate::index::vector_index_details_default())), index_version, created_at: Some(chrono::Utc::now()), base_id: None, @@ -2350,12 +2350,16 @@ fn build_segment_plan( Some(index_type) => index_type.version(), None => infer_source_index_version(&group)?, }; - let segment = IndexSegment::new( - segment_uuid, - fragment_bitmap, - Arc::new(crate::index::vector_index_details()), - index_version, - ); + + // Legacy source segments may not carry index_details. Fall back to an empty + // placeholder; `needs_vector_details_inference` will pick this up on the + // next manifest load and populate the real details from the index files. 
+ let index_details = match first.index_details.as_ref() { + Some(d) => d.clone(), + None => Arc::new(crate::index::vector::details::vector_index_details_default()), + }; + + let segment = IndexSegment::new(segment_uuid, fragment_bitmap, index_details, index_version); Ok(IndexSegmentPlan::new( segment, @@ -2722,7 +2726,7 @@ mod tests { use crate::dataset::{InsertBuilder, WriteMode, WriteParams}; use crate::index::prefilter::DatasetPreFilter; use crate::index::vector::IndexFileVersion; - use crate::index::vector_index_details; + use crate::index::vector_index_details_default; use crate::index::{DatasetIndexExt, DatasetIndexInternalExt, vector::VectorIndexParams}; use crate::utils::test::copy_test_data_to_tmp; @@ -3195,7 +3199,7 @@ mod tests { fields: vec![field.id], name: INDEX_NAME.to_string(), fragment_bitmap: Some(dataset.fragment_bitmap.as_ref().clone()), - index_details: Some(Arc::new(vector_index_details())), + index_details: Some(Arc::new(vector_index_details_default())), index_version: VECTOR_INDEX_VERSION as i32, created_at: Some(chrono::Utc::now()), base_id: None, @@ -3234,7 +3238,7 @@ mod tests { fields: Vec::new(), name: INDEX_NAME.to_string(), fragment_bitmap: None, - index_details: Some(Arc::new(vector_index_details())), + index_details: Some(Arc::new(vector_index_details_default())), index_version: VECTOR_INDEX_VERSION as i32, created_at: None, // Test index, not setting timestamp base_id: None, @@ -3294,7 +3298,7 @@ mod tests { fields: vec![field.id], name: format!("{}_remapped", INDEX_NAME), fragment_bitmap: Some(dataset_mut.fragment_bitmap.as_ref().clone()), - index_details: Some(Arc::new(vector_index_details())), + index_details: Some(Arc::new(vector_index_details_default())), index_version: VECTOR_INDEX_VERSION as i32, created_at: Some(chrono::Utc::now()), base_id: None, diff --git a/rust/lance/src/io/commit.rs b/rust/lance/src/io/commit.rs index 700da33a034..67e57ed3320 100644 --- a/rust/lance/src/io/commit.rs +++ b/rust/lance/src/io/commit.rs 
@@ -50,6 +50,7 @@ use crate::dataset::{ }; use crate::index::DatasetIndexExt; use crate::index::DatasetIndexInternalExt; +use crate::index::vector::details::infer_missing_vector_details; use crate::io::deletion::read_dataset_deletion_file; use crate::session::Session; use crate::session::caches::DSMetadataCache; @@ -671,6 +672,7 @@ fn must_recalculate_fragment_bitmap( /// Indices might be missing `fragment_bitmap`, so this function will add it. /// Indices might also be missing `files` (file sizes), so this function will collect them. async fn migrate_indices(dataset: &Dataset, indices: &mut [IndexMetadata]) -> Result<()> { + infer_missing_vector_details(dataset, indices).await; let needs_recalculating = match detect_overlapping_fragments(indices) { Ok(()) => vec![], Err(BadFragmentBitmapError { bad_indices }) => { From 4de5ce67dfdcc49f25a43f3fded178d9458b10fd Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 20 May 2026 12:23:53 -0700 Subject: [PATCH 10/23] feat(index): serializable cache for Bitmap and LabelList scalar indices (#6874) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `CacheCodec` impls so Bitmap and LabelList index cache entries survive through a persistent cache backend, mirroring the BTree work in #6793. - `CacheCodecImpl for RowAddrTreeMap` (delegates to existing `serialize_into`/`deserialize_from`), so per-value bitmap entries cached under `BitmapKey` are codec-backed. - `BitmapIndexState` captures the value→offset map (Arrow IPC), the null bitmap, and the value type. `BitmapIndexPlugin` overrides `get_from_cache`/`put_in_cache` to store this sized state. - `LabelListIndexState` wraps an inner `BitmapIndexState` plus `list_nulls` and gets the same plugin-level codec treatment. - `open_scalar_index` skips the LabelList compatibility check on cache hits, so a fully-cached LabelList query no longer pays an extra `bitmap_page_lookup.lance` open per call. 
## Tests - Unit codec round-trip for `BitmapIndexState` (empty + populated). - Integration tests `test_{bitmap,label_list}_prewarm_with_serializing_backend_serves_query_with_no_io` asserting zero IOPS after prewarm through a serializing cache backend. Closes #6744 --- rust/lance-core/src/utils/mask.rs | 11 + rust/lance-index/src/scalar/bitmap.rs | 214 +++++++++++++++++- rust/lance-index/src/scalar/label_list.rs | 122 +++++++++- rust/lance/src/dataset/tests/dataset_index.rs | 193 ++++++++++++++++ rust/lance/src/index/scalar.rs | 11 +- 5 files changed, 544 insertions(+), 7 deletions(-) diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 0ee1b5d17fa..3836832ae29 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -13,6 +13,7 @@ use deepsize::DeepSizeOf; use itertools::Itertools; use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; +use crate::cache::CacheCodecImpl; use crate::{Error, Result}; use super::address::RowAddress; @@ -661,6 +662,16 @@ impl RowAddrTreeMap { } } +impl CacheCodecImpl for RowAddrTreeMap { + fn serialize(&self, writer: &mut dyn Write) -> Result<()> { + self.serialize_into(writer) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + Self::deserialize_from(data.as_ref()) + } +} + impl std::ops::BitOr for RowAddrTreeMap { type Output = Self; diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 45027cc7b63..76d387f92b7 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -19,10 +19,14 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{StreamExt, TryStreamExt, stream}; +use lance_arrow::ipc::{ + read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream, + write_len_prefixed_bytes, +}; use lance_core::utils::mask::RowSetOps; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheKey, 
LanceCache, WeakLanceCache}, + cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ mask::{NullableRowAddrSet, RowAddrTreeMap}, @@ -145,6 +149,151 @@ impl CacheKey for BitmapKey { fn type_name() -> &'static str { "Bitmap" } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } +} + +/// The serializable state of a [`BitmapIndex`]. +/// +/// `BitmapIndex` holds non-serializable infrastructure (an `IndexStore`, a +/// cache handle, a lazy reader, a fragment-reuse index). `BitmapIndexState` +/// captures just the data needed to rebuild it: the value→file-offset map, +/// the null bitmap, and the value type. +#[derive(Debug, Clone)] +pub struct BitmapIndexState { + /// Value-to-row-offset lookup, encoded as an Arrow `RecordBatch` so we can + /// reuse the existing IPC utilities for zero-copy round trips. + /// + /// Schema: `keys: `, `offsets: UInt64`. Iteration order of + /// `index_map` is preserved on serialize and the `BTreeMap` resorts the + /// entries on deserialize, so the wire form does not need to be sorted. + lookup_batch: RecordBatch, + /// Already-remapped null bitmap (remapping is applied during load, so the + /// cached state matches the in-memory representation). + null_map: Arc, + /// Cached separately from the schema for the empty-index case where the + /// `lookup_batch` is empty but we still need to remember the column type. 
+ value_type: DataType, +} + +impl DeepSizeOf for BitmapIndexState { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.lookup_batch.get_array_memory_size() + self.null_map.deep_size_of_children(context) + } +} + +impl BitmapIndexState { + pub(crate) fn from_index(index: &BitmapIndex) -> Result { + Ok(Self { + lookup_batch: build_lookup_batch(&index.index_map, &index.value_type)?, + null_map: index.null_map.clone(), + value_type: index.value_type.clone(), + }) + } + + pub(crate) fn into_bitmap_index( + self, + store: Arc, + index_cache: &LanceCache, + frag_reuse_index: Option>, + ) -> Result> { + let index_map = parse_lookup_batch(&self.lookup_batch)?; + Ok(Arc::new(BitmapIndex::new( + index_map, + self.null_map, + self.value_type, + store, + WeakLanceCache::from(index_cache), + frag_reuse_index, + ))) + } +} + +fn build_lookup_batch( + index_map: &BTreeMap, + value_type: &DataType, +) -> Result { + let keys = if index_map.is_empty() { + arrow_array::new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(index_map.keys().map(|k| k.0.clone()))? + }; + let offsets = Arc::new(UInt64Array::from_iter_values( + index_map.values().map(|v| *v as u64), + )); + let schema = Arc::new(Schema::new(vec![ + Field::new("keys", value_type.clone(), true), + Field::new("offsets", DataType::UInt64, false), + ])); + Ok(RecordBatch::try_new(schema, vec![keys, offsets])?) 
+} + +fn parse_lookup_batch(batch: &RecordBatch) -> Result> { + let keys = batch.column(0); + let offsets = batch + .column(1) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal("BitmapIndexState: expected UInt64 offsets column".to_string()) + })?; + let mut index_map = BTreeMap::new(); + for idx in 0..batch.num_rows() { + let value = OrderableScalarValue(ScalarValue::try_from_array(keys, idx)?); + index_map.insert(value, offsets.value(idx) as usize); + } + Ok(index_map) +} + +impl CacheCodecImpl for BitmapIndexState { + /// Wire format: + /// ```text + /// [u64 null_map_len][null_map bytes] + /// [arrow IPC stream: (keys: , offsets: UInt64)] + /// ``` + /// The value type is recovered from the IPC stream schema. + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + let mut null_bytes = Vec::with_capacity(self.null_map.serialized_size()); + self.null_map.serialize_into(&mut null_bytes)?; + write_len_prefixed_bytes(writer, &null_bytes)?; + write_ipc_stream(&self.lookup_batch, writer)?; + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + let mut offset = 0; + let null_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let null_map = Arc::new(RowAddrTreeMap::deserialize_from(null_bytes.as_ref())?); + let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + let value_type = lookup_batch.schema().field(0).data_type().clone(); + Ok(Self { + lookup_batch, + null_map, + value_type, + }) + } +} + +/// Cache key for a [`BitmapIndexState`]. The cache is already namespaced +/// per-index by the caller, so a constant key suffices. 
+struct BitmapIndexStateKey; + +impl CacheKey for BitmapIndexStateKey { + type ValueType = BitmapIndexState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "BitmapIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } } impl BitmapIndex { @@ -1542,6 +1691,34 @@ impl ScalarIndexPlugin for BitmapIndexPlugin { Ok(BitmapIndex::load(index_store, frag_reuse_index, cache).await? as Arc) } + async fn get_from_cache( + &self, + index_store: Arc, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + let Some(state) = cache.get_with_key(&BitmapIndexStateKey).await else { + return Ok(None); + }; + let state = (*state).clone(); + let index = state.into_bitmap_index(index_store, cache, frag_reuse_index)?; + Ok(Some(index as Arc)) + } + + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + let bitmap = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal("BitmapIndexPlugin::put_in_cache called with a non-bitmap index") + })?; + let state = BitmapIndexState::from_index(bitmap)?; + cache + .insert_with_key(&BitmapIndexStateKey, Arc::new(state)) + .await; + Ok(()) + } + async fn load_statistics( &self, index_store: Arc, @@ -1589,6 +1766,41 @@ mod tests { use lance_io::object_store::ObjectStore; use std::collections::HashMap; + fn assert_state_roundtrips(state: &BitmapIndexState) { + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BitmapIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + assert_eq!(restored.lookup_batch, state.lookup_batch); + assert_eq!(&*restored.null_map, &*state.null_map); + assert_eq!(restored.value_type, state.value_type); + } + + #[test] + fn test_bitmap_index_state_codec_roundtrip() { + // Non-empty state with a few keys and a populated null map. 
+ let mut index_map = BTreeMap::new(); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(1))), 0); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(7))), 1); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(42))), 2); + let mut null_map = RowAddrTreeMap::new(); + null_map.insert(RowAddress::new_from_parts(0, 3).into()); + null_map.insert(RowAddress::new_from_parts(0, 5).into()); + let state = BitmapIndexState { + lookup_batch: build_lookup_batch(&index_map, &DataType::Int32).unwrap(), + null_map: Arc::new(null_map), + value_type: DataType::Int32, + }; + assert_state_roundtrips(&state); + + // Empty state: no keys, empty null map. Schema still carries the type. + let empty_state = BitmapIndexState { + lookup_batch: build_lookup_batch(&BTreeMap::new(), &DataType::Utf8).unwrap(), + null_map: Arc::new(RowAddrTreeMap::new()), + value_type: DataType::Utf8, + }; + assert_state_roundtrips(&empty_state); + } + #[tokio::test] async fn test_bitmap_lazy_loading_and_cache() { // Create a temporary directory for the index diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index eae0f8d6054..1efb62bd566 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -19,7 +19,8 @@ use datafusion::physical_plan::{SendableRecordBatchStream, stream::RecordBatchSt use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{StreamExt, TryStream, TryStreamExt, stream::BoxStream}; -use lance_core::cache::LanceCache; +use lance_arrow::ipc::{read_len_prefixed_bytes_at, write_len_prefixed_bytes}; +use lance_core::cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache}; use lance_core::error::LanceOptionExt; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; use lance_core::{Error, ROW_ID, Result}; @@ -31,7 +32,7 @@ use super::{BuiltinIndexType, SargableQuery, ScalarIndexParams}; use super::{MetricsCollector, 
SearchResult}; use crate::frag_reuse::FragReuseIndex; use crate::pbold; -use crate::scalar::bitmap::BitmapIndexPlugin; +use crate::scalar::bitmap::{BitmapIndexPlugin, BitmapIndexState}; use crate::scalar::expression::{LabelListQueryParser, ScalarQueryParser}; use crate::scalar::registry::{ DefaultTrainingRequest, ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest, @@ -488,6 +489,93 @@ async fn write_label_list_bitmap_index( .await } +/// The serializable state of a [`LabelListIndex`]. +/// +/// `LabelListIndex` is a thin wrapper around a [`BitmapIndex`] plus a separate +/// row bitmap tracking which list values were `NULL` (lost by unnest at build +/// time). Its cache state is the corresponding [`BitmapIndexState`] plus the +/// already-loaded `list_nulls`. +#[derive(Debug, Clone)] +pub struct LabelListIndexState { + bitmap_state: BitmapIndexState, + list_nulls: Arc, +} + +impl DeepSizeOf for LabelListIndexState { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.bitmap_state.deep_size_of_children(context) + + self.list_nulls.deep_size_of_children(context) + } +} + +impl LabelListIndexState { + fn from_index(index: &LabelListIndex) -> Result { + Ok(Self { + bitmap_state: BitmapIndexState::from_index(&index.values_index)?, + list_nulls: index.list_nulls.clone(), + }) + } + + fn into_label_list_index( + self, + store: Arc, + index_cache: &LanceCache, + frag_reuse_index: Option>, + ) -> Result> { + let bitmap = self + .bitmap_state + .into_bitmap_index(store, index_cache, frag_reuse_index)?; + Ok(Arc::new(LabelListIndex::new(bitmap, self.list_nulls))) + } +} + +impl CacheCodecImpl for LabelListIndexState { + /// Wire format: + /// ```text + /// [u64 list_nulls_len][list_nulls bytes] + /// [bitmap state bytes (self-delimiting)] + /// ``` + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + let mut nulls_bytes = Vec::with_capacity(self.list_nulls.serialized_size()); + 
self.list_nulls.serialize_into(&mut nulls_bytes)?; + write_len_prefixed_bytes(writer, &nulls_bytes)?; + self.bitmap_state.serialize(writer)?; + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + let mut offset = 0; + let nulls_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let list_nulls = Arc::new(RowAddrTreeMap::deserialize_from(nulls_bytes.as_ref())?); + // The bitmap state is self-delimiting (length-prefixed null map + + // Arrow IPC stream with EOS marker), so we can hand the remaining + // tail to it directly. + let bitmap_state = BitmapIndexState::deserialize(&data.slice(offset..))?; + Ok(Self { + bitmap_state, + list_nulls, + }) + } +} + +struct LabelListIndexStateKey; + +impl CacheKey for LabelListIndexStateKey { + type ValueType = LabelListIndexState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "LabelListIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } +} + #[derive(Debug, Default)] pub struct LabelListIndexPlugin; @@ -606,4 +694,34 @@ impl ScalarIndexPlugin for LabelListIndexPlugin { as Arc, ) } + + async fn get_from_cache( + &self, + index_store: Arc, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + let Some(state) = cache.get_with_key(&LabelListIndexStateKey).await else { + return Ok(None); + }; + let state = (*state).clone(); + let index = state.into_label_list_index(index_store, cache, frag_reuse_index)?; + Ok(Some(index as Arc)) + } + + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + let label_list = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal( + "LabelListIndexPlugin::put_in_cache called with a non-label-list index", + ) + })?; + let state = LabelListIndexState::from_index(label_list)?; + cache + .insert_with_key(&LabelListIndexStateKey, Arc::new(state)) + .await; + Ok(()) + } } diff --git a/rust/lance/src/dataset/tests/dataset_index.rs 
b/rust/lance/src/dataset/tests/dataset_index.rs index 4fd1e2fcfd1..e785de7bee4 100644 --- a/rust/lance/src/dataset/tests/dataset_index.rs +++ b/rust/lance/src/dataset/tests/dataset_index.rs @@ -2297,6 +2297,199 @@ async fn test_btree_prewarm_with_serializing_backend_serves_query_with_no_io() { ); } +/// Bitmap analogue of `test_btree_prewarm_with_serializing_backend_serves_query_with_no_io`: +/// after prewarming a Bitmap scalar index through a serializing cache backend, +/// an indexed-filter query serves results without any further IO. The +/// serializing backend forces every cache hit through the `BitmapIndexState` +/// (top-level state) and `RowAddrTreeMap` (per-value bitmap) `CacheCodec` +/// impls, so this exercises both round-trip paths. +#[tokio::test] +async fn test_bitmap_prewarm_with_serializing_backend_serves_query_with_no_io() { + use lance_io::assert_io_eq; + + use fts_serializing_backend::SerializingBackend; + + let tmpdir = TempStrDir::default(); + let uri = tmpdir.to_owned(); + drop(tmpdir); + + // Low-cardinality column so the index has several per-value bitmaps to + // round-trip through the per-key codec. 
+ let num_rows: i32 = 8_000; + let values = Int32Array::from_iter_values((0..num_rows).map(|i| i % 16)); + let ids = UInt64Array::from_iter_values(0..num_rows as u64); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("value", DataType::Int32, false), + arrow_schema::Field::new("id", DataType::UInt64, false), + ]) + .into(), + vec![Arc::new(values) as ArrayRef, Arc::new(ids) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, &uri, None).await.unwrap(); + dataset + .create_index( + &["value"], + IndexType::Bitmap, + Some("value_idx".to_owned()), + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + let backend = Arc::new(SerializingBackend::new()); + let session = Arc::new(Session::with_index_cache_backend( + backend.clone(), + 128 * 1024 * 1024, + Arc::new(lance_io::object_store::ObjectStoreRegistry::default()), + )); + let dataset = DatasetBuilder::from_uri(&uri) + .with_session(session) + .load() + .await + .unwrap(); + + dataset.object_store.as_ref().io_stats_incremental(); + dataset.prewarm_index("value_idx").await.unwrap(); + + let serialized_after_prewarm = backend.serialized_entry_count().await; + assert!( + serialized_after_prewarm > 0, + "prewarm should have routed the bitmap state and per-value bitmaps through \ + CacheCodec, but the serializing store was empty" + ); + + dataset.object_store.as_ref().io_stats_incremental(); + let result = dataset + .scan() + .project(&[ROW_ID]) + .unwrap() + .filter("value = 7") + .unwrap() + .try_into_batch() + .await + .unwrap(); + let expected = (num_rows as usize) / 16; + assert_eq!( + result.num_rows(), + expected, + "indexed bitmap filter should return correct results after deserialization" + ); + + let stats = dataset.object_store.as_ref().io_stats_incremental(); + assert_io_eq!( + stats, + read_iops, + 0, + "Bitmap 
filter query should not perform IO after prewarm; the serializing \ + cache backend must serve the index state and every per-value bitmap from memory" + ); +} + +/// LabelList analogue: after prewarming, an `array_has_any` query against a +/// `LabelList` index serves results without any further IO. Exercises the +/// `LabelListIndexState` codec (which embeds the inner bitmap state and the +/// list-nulls bitmap) plus the same per-value bitmap codec. +#[tokio::test] +async fn test_label_list_prewarm_with_serializing_backend_serves_query_with_no_io() { + use lance_io::assert_io_eq; + + use fts_serializing_backend::SerializingBackend; + + let tmpdir = TempStrDir::default(); + let uri = tmpdir.to_owned(); + drop(tmpdir); + + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; + + let mut dataset = gen_batch() + .col( + "labels", + lance_datagen::array::rand_list_any( + lance_datagen::array::cycle::(vec![1, 2, 3, 4, 5]), + false, + ), + ) + .into_dataset(&uri, FragmentCount::from(2), FragmentRowCount::from(2000)) + .await + .unwrap(); + dataset + .create_index( + &["labels"], + IndexType::LabelList, + Some("labels_idx".to_owned()), + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + let expected = dataset + .scan() + .project(&[ROW_ID]) + .unwrap() + .filter("array_has_any(labels, [3])") + .unwrap() + .try_into_batch() + .await + .unwrap() + .num_rows(); + assert!( + expected > 0, + "test dataset must contain at least one row whose labels include 3" + ); + + let backend = Arc::new(SerializingBackend::new()); + let session = Arc::new(Session::with_index_cache_backend( + backend.clone(), + 128 * 1024 * 1024, + Arc::new(lance_io::object_store::ObjectStoreRegistry::default()), + )); + let dataset = DatasetBuilder::from_uri(&uri) + .with_session(session) + .load() + .await + .unwrap(); + + dataset.object_store.as_ref().io_stats_incremental(); + dataset.prewarm_index("labels_idx").await.unwrap(); + + let serialized_after_prewarm = 
backend.serialized_entry_count().await; + assert!( + serialized_after_prewarm > 0, + "prewarm should have routed the label-list state and per-value bitmaps through \ + CacheCodec, but the serializing store was empty" + ); + + dataset.object_store.as_ref().io_stats_incremental(); + let result = dataset + .scan() + .project(&[ROW_ID]) + .unwrap() + .filter("array_has_any(labels, [3])") + .unwrap() + .try_into_batch() + .await + .unwrap(); + assert_eq!( + result.num_rows(), + expected, + "indexed label-list filter should return correct results after deserialization" + ); + + let stats = dataset.object_store.as_ref().io_stats_incremental(); + assert_io_eq!( + stats, + read_iops, + 0, + "LabelList filter query should not perform IO after prewarm; the serializing \ + cache backend must serve the index state and every per-value bitmap from memory" + ); +} + #[tokio::test] async fn test_fts_phrase_query_with_removed_stop_words() { let tmpdir = TempStrDir::default(); diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index a7856f2306f..18c218ef4f7 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -389,10 +389,6 @@ pub async fn open_scalar_index( let index_details = fetch_index_details(dataset, column, index).await?; let plugin = SCALAR_INDEX_PLUGIN_REGISTRY.get_plugin_by_details(index_details.as_ref())?; - if index_details.type_url.ends_with("LabelListIndexDetails") { - validate_label_list_index_compatibility(dataset, column, index, &index_store).await?; - } - let frag_reuse_index = dataset.open_frag_reuse_index(metrics).await?; let index_cache = dataset @@ -403,9 +399,16 @@ pub async fn open_scalar_index( .get_from_cache(index_store.clone(), frag_reuse_index.clone(), &index_cache) .await? { + // Compatibility check is only needed on first load; a cache hit means + // the index was already validated when it was originally opened in + // this session, so we can skip the extra `open_index_file` IOP. 
return Ok(index); } + if index_details.type_url.ends_with("LabelListIndexDetails") { + validate_label_list_index_compatibility(dataset, column, index, &index_store).await?; + } + let index = plugin .load_index(index_store, &index_details, frag_reuse_index, &index_cache) .await?; From e808eb135547220cdb7b9774bd7173fdcfc716ff Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Wed, 20 May 2026 15:34:03 -0500 Subject: [PATCH 11/23] feat(mem_wal): cache opened L0 flushed-generation datasets (#6816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem In the LSM scanner, every query against an L0 (frozen/flushed) generation re-opens that generation's Lance dataset from object storage. There are **three** identical cold-open sites — scan (`planner.rs`), point lookup (`point_lookup.rs`), and vector search (`vector_search.rs`) — each doing `DatasetBuilder::from_uri(path).load()` with no session. Per query, per flushed generation, this pays: manifest version discovery + manifest read + decode, file-metadata decode, and scalar/vector index load. For an LSM tree, frozen generations are the single best caching target, yet they were the only data source paying full cold-open cost on every query. ## Key invariant Flush writes each generation **once** to a globally-unique, content-addressed path (`memtable/flush.rs`). Same path ⟹ same bytes, forever — a cache hit can never be stale. This is the rare cache that needs **no TTL and no invalidation for correctness**; pruning is desirable only to reclaim memory. ## Changes (OSS `lance`) Two complementary, independently-useful, opt-in pieces — defaults preserve existing behavior exactly: 1. **`with_session` plumbing** — thread an existing `Arc` into the scanner/planners so the first open of each generation populates and reuses the shared index + file-metadata caches. `LsmScanner::new` defaults this to the base table's session; `without_base_table` defaults to `None`. 2. 
**`FlushedMemTableCache`** — a `moka`-backed, single-flight cache of `Arc<Dataset>` keyed by resolved flushed path, owned and sized by the consumer and injected per-request. After the first open, every subsequent query for that generation is a pure `Arc::clone` with zero object-store I/O. `retain_paths(live_paths)` prunes retired generations at compaction (memory-only; correctness never depends on it). A single shared `open_flushed_dataset(path, session, cache)` helper replaces all three cold-open sites (repo rule: dedupe logic in 2+ places). `None`/`None` reproduces the original behavior precisely, so no existing test changes. `data_source.rs` / `collector.rs` are untouched — opening stays lazy inside the planner, preserving bloom-filter pruning on point lookups. Planner wiring uses chainable `with_session`/`with_flushed_cache` builder methods rather than constructor changes, keeping `new()` signatures (and every existing test/bench) untouched. ## Testing - New unit tests for `FlushedMemTableCache`: miss opens once; hit returns the same `Arc` (pointer eq); 16-way concurrent `get_or_open` opens exactly once (single-flight); `retain_paths` drops the right keys; no-cache path cold-opens each call. - Regression: full `mem_wal::scanner` suite (78 tests) passes untouched. - `cargo clippy -p lance --tests --benches` clean; `cargo fmt` clean. ## Notes The `sophon` consumer side (process-bootstrap cache ownership, scanner wiring, compaction `retain_paths`) is out of scope for this PR. Phase 1 (`with_session`) is independently shippable ahead of the cache. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- Cargo.lock | 1 + python/Cargo.lock | 1 + rust/lance/Cargo.toml | 1 + rust/lance/src/dataset/mem_wal/scanner.rs | 2 + .../src/dataset/mem_wal/scanner/builder.rs | 49 +++- .../dataset/mem_wal/scanner/flushed_cache.rs | 276 ++++++++++++++++++ .../src/dataset/mem_wal/scanner/planner.rs | 28 +- .../dataset/mem_wal/scanner/point_lookup.rs | 28 +- .../dataset/mem_wal/scanner/vector_search.rs | 28 +- 9 files changed, 404 insertions(+), 10 deletions(-) create mode 100644 rust/lance/src/dataset/mem_wal/scanner/flushed_cache.rs diff --git a/Cargo.lock b/Cargo.lock index 1b4f587ba2c..6e20ba7bf77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4386,6 +4386,7 @@ dependencies = [ "log", "lzma-sys", "mock_instant", + "moka", "object_store", "parquet", "permutation", diff --git a/python/Cargo.lock b/python/Cargo.lock index 5e13dd8ac51..882344b8c4f 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -4023,6 +4023,7 @@ dependencies = [ "lance-table", "lance-tokenizer", "log", + "moka", "object_store", "permutation", "pin-project", diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 1f7c4af0bd1..ca2bfdeaf91 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -55,6 +55,7 @@ deepsize.workspace = true # matches arrow-rs use half.workspace = true itertools.workspace = true +moka.workspace = true object_store = { workspace = true } aws-credential-types.workspace = true aws-credential-types.optional = true diff --git a/rust/lance/src/dataset/mem_wal/scanner.rs b/rust/lance/src/dataset/mem_wal/scanner.rs index b97221cc81e..ec179653096 100644 --- a/rust/lance/src/dataset/mem_wal/scanner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner.rs @@ -35,6 +35,7 @@ mod builder; mod collector; mod data_source; pub mod exec; +mod flushed_cache; mod planner; mod point_lookup; mod projection; @@ -45,6 +46,7 @@ pub use collector::{ ActiveMemTableRef, 
InMemoryMemTableRef, InMemoryMemTables, LsmDataSourceCollector, }; pub use data_source::{FlushedGeneration, LsmDataSource, LsmGeneration, ShardSnapshot}; +pub use flushed_cache::FlushedMemTableCache; pub use point_lookup::LsmPointLookupPlanner; pub use projection::DISTANCE_COLUMN; pub use vector_search::LsmVectorSearchPlanner; diff --git a/rust/lance/src/dataset/mem_wal/scanner/builder.rs b/rust/lance/src/dataset/mem_wal/scanner/builder.rs index ae337c924c1..570a3e0cfc9 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/builder.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/builder.rs @@ -18,8 +18,10 @@ use uuid::Uuid; use super::collector::{InMemoryMemTableRef, InMemoryMemTables, LsmDataSourceCollector}; use super::data_source::ShardSnapshot; +use super::flushed_cache::FlushedMemTableCache; use super::planner::LsmScanPlanner; use crate::dataset::Dataset; +use crate::session::Session; /// Either a base Lance table, or an explicit base path used to resolve /// flushed-generation directories when no base dataset is configured. @@ -73,6 +75,14 @@ pub struct LsmScanner { // Primary key columns (required for deduplication) pk_columns: Vec, + + /// Session threaded into flushed-generation opens so the first open of + /// each generation populates the shared index / file-metadata caches. + /// Defaults to the base table's session when one is present. + session: Option>, + /// Cache of opened flushed-generation datasets. When set, repeated + /// queries against the same generation skip the manifest read entirely. + flushed_cache: Option>, } impl LsmScanner { @@ -90,6 +100,10 @@ impl LsmScanner { ) -> Self { let lance_schema = base_table.schema(); let arrow_schema: arrow_schema::Schema = lance_schema.into(); + // Default the session to the base table's so the common path reuses + // the shared index / metadata caches without extra wiring. An + // explicit `with_session` still overrides this. 
+ let session = Some(base_table.session()); Self { base: BaseSource::Table(base_table), schema: Arc::new(arrow_schema), @@ -102,6 +116,8 @@ impl LsmScanner { with_row_address: false, with_memtable_gen: false, pk_columns, + session, + flushed_cache: None, } } @@ -138,6 +154,8 @@ impl LsmScanner { with_row_address: false, with_memtable_gen: false, pk_columns, + session: None, + flushed_cache: None, } } @@ -170,6 +188,29 @@ impl LsmScanner { self } + /// Thread an existing session into flushed-generation opens. + /// + /// The first open of each flushed generation then populates the shared + /// index / file-metadata caches, so later queries skip re-decoding them. + /// When a base table is configured this defaults to its session; call + /// this to override (e.g. on a fresh-tier-only scanner that owns its own + /// long-lived session). + pub fn with_session(mut self, session: Arc) -> Self { + self.session = Some(session); + self + } + + /// Inject a cache of opened flushed-generation datasets. + /// + /// With a cache, repeated queries against the same generation become a + /// pure `Arc::clone` with no manifest read or object-store I/O. The cache + /// is owned and sized by the caller (see [`FlushedMemTableCache`]); not + /// set by default, so behavior is unchanged unless opted in. + pub fn with_flushed_cache(mut self, cache: Arc) -> Self { + self.flushed_cache = Some(cache); + self + } + /// Project specific columns. /// /// If not called, all columns from the base schema are included. 
@@ -239,7 +280,13 @@ impl LsmScanner { pub async fn create_plan(&self) -> Result> { let collector = self.build_collector(); let base_schema = self.schema(); - let planner = LsmScanPlanner::new(collector, self.pk_columns.clone(), base_schema); + let mut planner = LsmScanPlanner::new(collector, self.pk_columns.clone(), base_schema); + if let Some(session) = &self.session { + planner = planner.with_session(session.clone()); + } + if let Some(cache) = &self.flushed_cache { + planner = planner.with_flushed_cache(cache.clone()); + } planner .plan_scan( diff --git a/rust/lance/src/dataset/mem_wal/scanner/flushed_cache.rs b/rust/lance/src/dataset/mem_wal/scanner/flushed_cache.rs new file mode 100644 index 00000000000..94f004399b0 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/flushed_cache.rs @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Cache of opened flushed-generation datasets for the LSM scanner. +//! +//! Flushed generations are written exactly once to a globally-unique, +//! content-addressed path (see `memtable/flush.rs`): a fresh random hash per +//! flush invocation means the same path always maps to the same immutable +//! bytes. A cached `Arc` therefore can never go stale and needs no +//! TTL or invalidation for correctness — pruning entries is a pure memory +//! optimization driven by the consumer at compaction time. +//! +//! ```text +//! query ──> open_flushed_dataset(path, session, cache) +//! │ +//! cache.is_some() ──────┤────── cache.is_none() +//! │ │ +//! FlushedMemTableCache::get_or_open DatasetBuilder::from_uri +//! (single-flight, shared Arc) (cold open every call) +//! ``` + +use std::collections::HashSet; +use std::sync::Arc; + +use lance_core::{Error, Result}; + +use crate::dataset::{Dataset, DatasetBuilder}; +use crate::session::Session; + +/// Cache of opened flushed-generation datasets, keyed by resolved path. 
+/// +/// Flushed generations live at a globally-unique, immutable path, so cached +/// entries are never stale and require no TTL. Intended to be held by a +/// long-lived owner (one per process or per table) and injected into +/// per-request scanners via [`crate::dataset::mem_wal::scanner::LsmScanner`] +/// (and the point-lookup / vector-search planners). +/// +/// The key is the resolved absolute flushed path +/// (`{base}/_mem_wal/{shard}/{folder}`), which is globally unique, so a single +/// cache can safely span multiple tables. +pub struct FlushedMemTableCache { + // `moka`'s async cache gives a bounded size plus single-flight + // `try_get_with`, so concurrent first-queries on a just-flushed + // generation open the dataset exactly once. + inner: moka::future::Cache>, +} + +impl FlushedMemTableCache { + /// Create a cache holding at most `max_entries` opened datasets. + /// + /// Eviction is size-only (no TTL): an evicted-then-re-requested generation + /// simply re-opens, which is always correct because the path is immutable. + pub fn new(max_entries: u64) -> Self { + Self { + inner: moka::future::Cache::builder() + .max_capacity(max_entries) + // Required for `retain_paths`: moka silently ignores + // `invalidate_entries_if` unless closure support is opted + // into at build time. + .support_invalidation_closures() + .build(), + } + } + + /// Get the dataset for `path`, opening it (exactly once) on a miss. + /// + /// `session` is threaded into the open so the first open populates the + /// shared index / file-metadata caches; subsequent hits are a pure + /// `Arc::clone` with zero object-store I/O. Concurrent callers for the + /// same path share a single open via `moka`'s single-flight + /// `try_get_with`. 
+ pub async fn get_or_open( + &self, + path: &str, + session: Option>, + ) -> Result> { + self.inner + .try_get_with(path.to_string(), async move { + let mut builder = DatasetBuilder::from_uri(path); + if let Some(session) = session { + builder = builder.with_session(session); + } + builder.load().await.map(Arc::new) + }) + .await + // `try_get_with` hands losing racers an `Arc`; the original + // error keeps full context, clones collapse to `Error::Cloned`. + .map_err(|e: Arc| Error::cloned(e.to_string())) + } + + /// Drop cached entries whose path is not in `live_paths`. + /// + /// Called by the consumer after compaction retires generations. Purely a + /// memory optimization: stale entries are unobservable because a retired + /// generation's path never reappears in a shard snapshot, so correctness + /// never depends on this running. Invalidation is applied lazily by + /// `moka` during its next maintenance cycle. + pub fn retain_paths(&self, live_paths: &HashSet) { + let live = live_paths.clone(); + // The only error is exceeding moka's registered-predicate cap, which + // would just defer reclamation — never a correctness issue. + let _ = self + .inner + .invalidate_entries_if(move |path, _| !live.contains(path)); + } +} + +impl std::fmt::Debug for FlushedMemTableCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FlushedMemTableCache") + .field("entry_count", &self.inner.entry_count()) + .finish() + } +} + +/// Open a flushed-generation dataset, shared by all three LSM open sites +/// (scan, point lookup, vector search). +/// +/// - `cache` present: route through [`FlushedMemTableCache`] (single-flight, +/// shared `Arc`, manifest read amortized across queries). +/// - `cache` absent: cold open via [`DatasetBuilder`]. Passing `session` +/// still reuses the shared index / metadata caches; `None`/`None` +/// reproduces the original per-query cold-open behavior exactly. 
+pub(super) async fn open_flushed_dataset( + path: &str, + session: Option<&Arc>, + cache: Option<&Arc>, +) -> Result> { + match cache { + Some(cache) => cache.get_or_open(path, session.cloned()).await, + None => { + let mut builder = DatasetBuilder::from_uri(path); + if let Some(session) = session { + builder = builder.with_session(session.clone()); + } + Ok(Arc::new(builder.load().await?)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::atomic::{AtomicUsize, Ordering}; + + use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + + use crate::dataset::WriteParams; + + async fn write_dataset(uri: &str, ids: &[i32]) { + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(ids.to_vec()))], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + Dataset::write(reader, uri, Some(WriteParams::default())) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_hit_returns_same_arc() { + let temp_dir = tempfile::tempdir().unwrap(); + let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap()); + write_dataset(&uri, &[1, 2, 3]).await; + + let cache = FlushedMemTableCache::new(8); + let first = cache.get_or_open(&uri, None).await.unwrap(); + let second = cache.get_or_open(&uri, None).await.unwrap(); + + assert!( + Arc::ptr_eq(&first, &second), + "a cache hit must return the same Arc, not re-open" + ); + assert_eq!(cache.inner.entry_count(), 0); // not yet flushed to count + cache.inner.run_pending_tasks().await; + assert_eq!(cache.inner.entry_count(), 1); + } + + #[tokio::test] + async fn test_concurrent_get_or_open_single_flight() { + // moka's `try_get_with` must collapse concurrent first-queries on the + // same path into exactly one open. 
We can't count opens through the + // public API, so wrap the cache call: every task that observes the + // same returned Arc proves they shared one open. + let temp_dir = tempfile::tempdir().unwrap(); + let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap()); + write_dataset(&uri, &[1, 2, 3]).await; + + let cache = Arc::new(FlushedMemTableCache::new(8)); + let calls = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + for _ in 0..16 { + let cache = cache.clone(); + let uri = uri.clone(); + let calls = calls.clone(); + handles.push(tokio::spawn(async move { + calls.fetch_add(1, Ordering::SeqCst); + cache.get_or_open(&uri, None).await.unwrap() + })); + } + + let datasets: Vec> = futures::future::try_join_all(handles).await.unwrap(); + + assert_eq!(calls.load(Ordering::SeqCst), 16, "all tasks ran"); + let first = &datasets[0]; + for ds in &datasets { + assert!( + Arc::ptr_eq(first, ds), + "all concurrent callers must share one opened dataset" + ); + } + cache.inner.run_pending_tasks().await; + assert_eq!(cache.inner.entry_count(), 1, "exactly one entry cached"); + } + + #[tokio::test] + async fn test_retain_paths_drops_unreferenced() { + let temp_dir = tempfile::tempdir().unwrap(); + let base = temp_dir.path().to_str().unwrap(); + let keep_uri = format!("{}/gen_1", base); + let drop_uri = format!("{}/gen_2", base); + write_dataset(&keep_uri, &[1]).await; + write_dataset(&drop_uri, &[2]).await; + + let cache = FlushedMemTableCache::new(8); + cache.get_or_open(&keep_uri, None).await.unwrap(); + cache.get_or_open(&drop_uri, None).await.unwrap(); + cache.inner.run_pending_tasks().await; + assert_eq!(cache.inner.entry_count(), 2); + + let live: HashSet = [keep_uri.clone()].into_iter().collect(); + cache.retain_paths(&live); + cache.inner.run_pending_tasks().await; + + assert_eq!(cache.inner.entry_count(), 1, "only live path retained"); + assert!(cache.inner.contains_key(&keep_uri)); + assert!(!cache.inner.contains_key(&drop_uri)); + } + + 
#[tokio::test] + async fn test_open_flushed_dataset_no_cache_matches_direct_open() { + // The `None`/`None` path must reproduce a plain cold open: same data, + // independent Arc per call (no caching). + let temp_dir = tempfile::tempdir().unwrap(); + let uri = format!("{}/gen_1", temp_dir.path().to_str().unwrap()); + write_dataset(&uri, &[7, 8, 9]).await; + + let a = open_flushed_dataset(&uri, None, None).await.unwrap(); + let b = open_flushed_dataset(&uri, None, None).await.unwrap(); + assert!( + !Arc::ptr_eq(&a, &b), + "no-cache path must cold-open each call" + ); + assert_eq!(a.count_rows(None).await.unwrap(), 3); + + // With a cache, the second call is a shared clone. + let cache = Arc::new(FlushedMemTableCache::new(8)); + let c = open_flushed_dataset(&uri, None, Some(&cache)) + .await + .unwrap(); + let d = open_flushed_dataset(&uri, None, Some(&cache)) + .await + .unwrap(); + assert!(Arc::ptr_eq(&c, &d), "cached path must reuse the Arc"); + } +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/planner.rs b/rust/lance/src/dataset/mem_wal/scanner/planner.rs index 2f93f1c35a7..98daf2bdda9 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/planner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/planner.rs @@ -19,9 +19,11 @@ use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; use super::exec::{DeduplicateExec, MEMTABLE_GEN_COLUMN, MemtableGenTagExec, ROW_ADDRESS_COLUMN}; +use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical, }; +use crate::session::Session; /// Plans scan queries over LSM data. pub struct LsmScanPlanner { @@ -31,6 +33,10 @@ pub struct LsmScanPlanner { pk_columns: Vec, /// Schema of the base table. base_schema: SchemaRef, + /// Session threaded into flushed-generation opens (shared caches). + session: Option>, + /// Cache of opened flushed-generation datasets. 
+ flushed_cache: Option>, } impl LsmScanPlanner { @@ -44,9 +50,25 @@ impl LsmScanPlanner { collector, pk_columns, base_schema, + session: None, + flushed_cache: None, } } + /// Thread a session into flushed-generation opens so the first open + /// populates the shared index / file-metadata caches. + pub fn with_session(mut self, session: Arc) -> Self { + self.session = Some(session); + self + } + + /// Inject a cache of opened flushed-generation datasets, making repeated + /// queries against the same generation a pure `Arc::clone`. + pub fn with_flushed_cache(mut self, cache: Arc) -> Self { + self.flushed_cache = Some(cache); + self + } + /// Create scan plan with deduplication. /// /// # Arguments @@ -351,9 +373,9 @@ impl LsmScanPlanner { scanner.create_plan().await } LsmDataSource::FlushedMemTable { path, .. } => { - let dataset = crate::dataset::DatasetBuilder::from_uri(path) - .load() - .await?; + let dataset = + open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref()) + .await?; let mut scanner = dataset.scan(); let cols = diff --git a/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs b/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs index 91b19ff480d..a70f200f14e 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs @@ -19,10 +19,12 @@ use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; use super::exec::{BloomFilterGuardExec, CoalesceFirstExec, compute_pk_hash_from_scalars}; +use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ build_scanner_projection, canonical_output_schema, project_to_canonical, wants_row_address, wants_row_id, }; +use crate::session::Session; /// Plans point lookup queries over LSM data. /// @@ -70,6 +72,10 @@ pub struct LsmPointLookupPlanner { /// Bloom filters for each memtable generation. 
/// Map: generation -> bloom filter bloom_filters: std::collections::HashMap>, + /// Session threaded into flushed-generation opens (shared caches). + session: Option>, + /// Cache of opened flushed-generation datasets. + flushed_cache: Option>, } impl LsmPointLookupPlanner { @@ -90,9 +96,25 @@ impl LsmPointLookupPlanner { pk_columns, base_schema, bloom_filters: std::collections::HashMap::new(), + session: None, + flushed_cache: None, } } + /// Thread a session into flushed-generation opens so the first open + /// populates the shared index / file-metadata caches. + pub fn with_session(mut self, session: Arc) -> Self { + self.session = Some(session); + self + } + + /// Inject a cache of opened flushed-generation datasets, making repeated + /// lookups against the same generation a pure `Arc::clone`. + pub fn with_flushed_cache(mut self, cache: Arc) -> Self { + self.flushed_cache = Some(cache); + self + } + /// Add a bloom filter for a generation. /// /// Bloom filters are optional but improve performance by skipping @@ -234,9 +256,9 @@ impl LsmPointLookupPlanner { scanner.create_plan().await? } LsmDataSource::FlushedMemTable { path, .. 
} => { - let dataset = crate::dataset::DatasetBuilder::from_uri(path) - .load() - .await?; + let dataset = + open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref()) + .await?; let mut scanner = dataset.scan(); scanner.project(&cols.iter().map(|s| s.as_str()).collect::>())?; scanner.filter_expr(filter.clone()); diff --git a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs index ab7b4b7b50e..f6d50f101d0 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs @@ -24,10 +24,12 @@ use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; use super::exec::{FilterStaleExec, GenerationBloomFilter, MemtableGenTagExec}; +use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ DISTANCE_COLUMN, build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical, wants_row_id, }; +use crate::session::Session; /// Plans vector search queries over LSM data. /// @@ -90,6 +92,10 @@ pub struct LsmVectorSearchPlanner { /// original vectors. Set this to make cross-source distance comparison /// across the LSM merge fully exact. base_table_refine_factor: Option, + /// Session threaded into flushed-generation opens (shared caches). + session: Option>, + /// Cache of opened flushed-generation datasets. + flushed_cache: Option>, } impl LsmVectorSearchPlanner { @@ -117,9 +123,25 @@ impl LsmVectorSearchPlanner { vector_column, distance_type, base_table_refine_factor: None, + session: None, + flushed_cache: None, } } + /// Thread a session into flushed-generation opens so the first open + /// populates the shared index / file-metadata caches. 
+ pub fn with_session(mut self, session: Arc) -> Self { + self.session = Some(session); + self + } + + /// Inject a cache of opened flushed-generation datasets, making repeated + /// searches against the same generation a pure `Arc::clone`. + pub fn with_flushed_cache(mut self, cache: Arc) -> Self { + self.flushed_cache = Some(cache); + self + } + /// Enable base-table refine. /// /// When set, the base-table arm of the KNN plan asks the scanner for @@ -315,9 +337,9 @@ impl LsmVectorSearchPlanner { scanner.create_plan().await } LsmDataSource::FlushedMemTable { path, .. } => { - let dataset = crate::dataset::DatasetBuilder::from_uri(path) - .load() - .await?; + let dataset = + open_flushed_dataset(path, self.session.as_ref(), self.flushed_cache.as_ref()) + .await?; let mut scanner = dataset.scan(); let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns); From bf681716b986617b5a89280a032a327c08211f51 Mon Sep 17 00:00:00 2001 From: Jack Ye Date: Wed, 20 May 2026 14:57:32 -0700 Subject: [PATCH 12/23] feat: add MemWAL sharding evaluator (#6854) Adds an Arrow-native MemWAL sharding evaluator and exposes it through the Java API/JNI. - Evaluates MemWAL sharding specs against Arrow RecordBatch values for bucket, identity, and unsharded fields. - Resolves sharding source IDs through a Java-provided source-id-to-column map. - Adds Java-facing ShardingEvaluator returning an Arrow reader for the evaluated sharding key batch. This is needed by lance-spark to route writes using Lance's sharding semantics instead of duplicating Spark-side bucket logic. 
--- java/lance-jni/src/mem_wal.rs | 143 ++++- .../org/lance/memwal/ShardingEvaluator.java | 124 ++++ .../java/org/lance/memwal/MemWalTest.java | 105 ++++ python/python/lance/__init__.py | 18 +- python/python/lance/dataset.py | 38 +- python/python/lance/lance/__init__.pyi | 76 +++ python/python/lance/mem_wal.py | 164 +++--- python/python/tests/test_mem_wal.py | 112 +++- python/src/dataset.rs | 15 +- python/src/lib.rs | 5 +- python/src/mem_wal.rs | 183 +++--- rust/lance/src/dataset/mem_wal.rs | 5 + rust/lance/src/dataset/mem_wal/api.rs | 64 ++- rust/lance/src/dataset/mem_wal/sharding.rs | 542 ++++++++++++++++++ rust/lance/src/dataset/mem_wal/write.rs | 17 +- 15 files changed, 1375 insertions(+), 236 deletions(-) create mode 100644 java/src/main/java/org/lance/memwal/ShardingEvaluator.java create mode 100644 rust/lance/src/dataset/mem_wal/sharding.rs diff --git a/java/lance-jni/src/mem_wal.rs b/java/lance-jni/src/mem_wal.rs index d5f9da750ab..6457e5b1419 100644 --- a/java/lance-jni/src/mem_wal.rs +++ b/java/lance-jni/src/mem_wal.rs @@ -7,6 +7,7 @@ //! [`ShardWriter`], an LSM-aware [`LsmScanner`], an [`ExecutionPlan`] wrapper, //! and the point-lookup / vector-search planners. 
+use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -21,7 +22,7 @@ use datafusion::common::ScalarValue; use datafusion::physical_plan::{ExecutionPlan, collect, displayable}; use datafusion::prelude::SessionContext; use jni::JNIEnv; -use jni::objects::{JClass, JObject, JString, JValueGen}; +use jni::objects::{JClass, JMap, JObject, JString, JValueGen}; use jni::sys::{jint, jlong}; use lance::dataset::Dataset as LanceDataset; use lance::dataset::mem_wal::scanner::{ @@ -30,6 +31,7 @@ use lance::dataset::mem_wal::scanner::{ use lance::dataset::mem_wal::write::{MemTableStats, WriteStatsSnapshot}; use lance::dataset::mem_wal::{ DatasetMemWalExt, LsmScanner, ShardSnapshot, ShardWriter, ShardWriterConfig, + evaluate_sharding_spec_with_source_columns, }; use lance::dataset::scanner::DatasetRecordBatchStream; use lance_index::mem_wal::{MemWalIndexDetails, ShardManifest, ShardingField, ShardingSpec}; @@ -42,6 +44,7 @@ use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; use crate::traits::{IntoJava, export_vec, import_vec, import_vec_to_rust}; +use crate::utils::to_rust_map; const NATIVE_SHARD_WRITER: &str = "nativeShardWriterHandle"; const NATIVE_LSM_SCANNER: &str = "nativeLsmScannerHandle"; @@ -599,6 +602,63 @@ fn inner_plan_open_stream(env: &mut JNIEnv, this: JObject, stream_addr: jlong) - Ok(()) } +////////////////////////// +// Sharding evaluation // +////////////////////////// + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_memwal_ShardingEvaluator_nativeEvaluate<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + arrow_array_addr: jlong, + arrow_schema_addr: jlong, + sharding_spec: JObject<'local>, + source_id_to_column: JObject<'local>, + stream_addr: jlong, +) { + ok_or_throw_without_return!( + env, + inner_evaluate_sharding( + &mut env, + arrow_array_addr, + arrow_schema_addr, + sharding_spec, + source_id_to_column, + stream_addr, + ) + ); 
+} + +fn inner_evaluate_sharding( + env: &mut JNIEnv, + arrow_array_addr: jlong, + arrow_schema_addr: jlong, + sharding_spec: JObject, + source_id_to_column: JObject, + stream_addr: jlong, +) -> Result<()> { + let input = import_ffi_array(arrow_array_addr, arrow_schema_addr)?; + let struct_array = input + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::input_error( + "ShardingEvaluator expects a VectorSchemaRoot struct array".to_string(), + ) + })?; + let input_batch = RecordBatch::from(struct_array.clone()); + let spec = sharding_spec_from_java(env, &sharding_spec)?; + let source_id_to_column = i32_string_map_from_java(env, source_id_to_column)?; + let result = + evaluate_sharding_spec_with_source_columns(&input_batch, &spec, &source_id_to_column)?; + let schema = result.schema(); + let reader: Box = + Box::new(RecordBatchIterator::new([Ok(result)].into_iter(), schema)); + let ffi_stream = FFI_ArrowArrayStream::new(reader); + unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } + Ok(()) +} + #[unsafe(no_mangle)] pub extern "system" fn Java_org_lance_memwal_ExecutionPlan_releaseNativeExecutionPlan( mut env: JNIEnv, @@ -1301,6 +1361,87 @@ fn sharding_field_to_java<'a>(env: &mut JNIEnv<'a>, field: &ShardingField) -> Re )?) } +fn sharding_spec_from_java(env: &mut JNIEnv, spec: &JObject) -> Result { + if spec.is_null() { + return Err(Error::input_error( + "ShardingSpec must not be null".to_string(), + )); + } + + env.with_local_frame(32, |env| { + let spec_id = env.call_method(spec, "specId", "()I", &[])?.i()? as u32; + let fields_obj = env + .call_method(spec, "fields", "()Ljava/util/List;", &[])? + .l()?; + let fields_list = env.get_list(&fields_obj)?; + let mut iter = fields_list.iter(env)?; + let mut fields = Vec::with_capacity(fields_list.size(env)? as usize); + while let Some(field_obj) = iter.next(env)? 
{ + fields.push(sharding_field_from_java(env, &field_obj)?); + } + + Ok::<_, Error>(ShardingSpec { spec_id, fields }) + }) +} + +fn sharding_field_from_java(env: &mut JNIEnv, field: &JObject) -> Result { + env.with_local_frame(32, |env| { + let field_id = string_from_method(env, field, "fieldId")?; + let source_ids_obj = env + .call_method(field, "sourceIds", "()Ljava/util/List;", &[])? + .l()?; + let source_ids = env.get_integers(&source_ids_obj)?; + let transform = env.get_optional_string_from_method(field, "transform")?; + let expression = env.get_optional_string_from_method(field, "expression")?; + let result_type = string_from_method(env, field, "resultType")?; + let parameters_obj = env + .call_method(field, "parameters", "()Ljava/util/Map;", &[])? + .l()?; + let parameters = if parameters_obj.is_null() { + HashMap::new() + } else { + let parameters_map = JMap::from_env(env, &parameters_obj)?; + to_rust_map(env, &parameters_map)? + }; + + Ok::<_, Error>(ShardingField { + field_id, + source_ids, + transform, + expression, + result_type, + parameters, + }) + }) +} + +fn i32_string_map_from_java(env: &mut JNIEnv, map_obj: JObject) -> Result> { + if map_obj.is_null() { + return Ok(HashMap::new()); + } + + env.with_local_frame(32, |env| { + let jmap = JMap::from_env(env, &map_obj)?; + let mut iter = jmap.iter(env)?; + let mut map = HashMap::new(); + while let Some((key, value)) = iter.next(env)? { + let key = env.call_method(&key, "intValue", "()I", &[])?.i()?; + let value = JString::from(value); + let value: String = env.get_string(&value)?.into(); + map.insert(key, value); + } + Ok::<_, Error>(map) + }) +} + +fn string_from_method(env: &mut JNIEnv, obj: &JObject, method: &str) -> Result { + let value = env + .call_method(obj, method, "()Ljava/lang/String;", &[])? 
+ .l()?; + let value = JString::from(value); + Ok(env.get_string(&value)?.into()) +} + fn int_list_to_java<'a>(env: &mut JNIEnv<'a>, ints: &[i32]) -> Result> { let list = env.new_object("java/util/ArrayList", "()V", &[])?; for &value in ints { diff --git a/java/src/main/java/org/lance/memwal/ShardingEvaluator.java b/java/src/main/java/org/lance/memwal/ShardingEvaluator.java new file mode 100644 index 00000000000..50e6cdb1ea3 --- /dev/null +++ b/java/src/main/java/org/lance/memwal/ShardingEvaluator.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.memwal; + +import org.lance.JniLoader; +import org.lance.schema.LanceField; +import org.lance.schema.LanceSchema; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** Evaluates MemWAL sharding specs against Arrow record batches. */ +public final class ShardingEvaluator { + static { + JniLoader.ensureLoaded(); + } + + private ShardingEvaluator() {} + + /** + * Evaluate {@code spec} against {@code root}. 
+ * + * @param allocator allocator used for Arrow C data interface structs + * @param root input record batch + * @param spec MemWAL sharding spec to evaluate + * @param schema Lance table schema used to resolve spec source field IDs to input column names + * @return an Arrow reader containing one result batch with the derived sharding fields + */ + public static ArrowReader evaluate( + BufferAllocator allocator, VectorSchemaRoot root, ShardingSpec spec, LanceSchema schema) { + Preconditions.checkNotNull(allocator, "allocator must not be null"); + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(spec, "spec must not be null"); + Preconditions.checkNotNull(schema, "schema must not be null"); + + try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); + ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); + nativeEvaluate( + arrowArray.memoryAddress(), + arrowSchema.memoryAddress(), + spec, + sourceIdToColumnMap(schema), + stream.memoryAddress()); + return Data.importArrayStream(allocator, stream); + } + } + + /** Evaluate {@code spec} against {@code root} when the spec embeds enough column information. 
*/ + public static ArrowReader evaluate( + BufferAllocator allocator, VectorSchemaRoot root, ShardingSpec spec) { + return evaluate(allocator, root, spec, Collections.emptyMap()); + } + + private static ArrowReader evaluate( + BufferAllocator allocator, + VectorSchemaRoot root, + ShardingSpec spec, + Map sourceIdToColumn) { + Preconditions.checkNotNull(allocator, "allocator must not be null"); + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(spec, "spec must not be null"); + Preconditions.checkNotNull(sourceIdToColumn, "sourceIdToColumn must not be null"); + + try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); + ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); + nativeEvaluate( + arrowArray.memoryAddress(), + arrowSchema.memoryAddress(), + spec, + sourceIdToColumn, + stream.memoryAddress()); + return Data.importArrayStream(allocator, stream); + } + } + + private static Map sourceIdToColumnMap(LanceSchema schema) { + Map result = new HashMap<>(); + for (LanceField field : schema.fields()) { + collectFieldIds(field, "", result); + } + return result; + } + + private static void collectFieldIds( + LanceField field, String prefix, Map result) { + String fullName = prefix.isEmpty() ? field.getName() : prefix + "." 
+ field.getName(); + result.put(field.getId(), fullName); + for (LanceField child : field.getChildren()) { + collectFieldIds(child, fullName, result); + } + } + + private static native void nativeEvaluate( + long arrowArrayAddress, + long arrowSchemaAddress, + ShardingSpec spec, + Map sourceIdToColumn, + long streamAddress); +} diff --git a/java/src/test/java/org/lance/memwal/MemWalTest.java b/java/src/test/java/org/lance/memwal/MemWalTest.java index d3aaebf8891..ee26932dd59 100644 --- a/java/src/test/java/org/lance/memwal/MemWalTest.java +++ b/java/src/test/java/org/lance/memwal/MemWalTest.java @@ -194,6 +194,111 @@ void testInitializeMemWalBucketShardingWithoutPrimaryKey(@TempDir Path tempDir) } } + @Test + void testInitializeMemWalBucketShardingUsesConfiguredColumn(@TempDir Path tempDir) + throws Exception { + String path = tempDir.resolve("base").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeLookupDataset(allocator, path, new long[] {1, 2, 3}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("name", 4)); + + MemWalIndexDetails details = dataset.memWalIndexDetails().get(); + ShardingField field = details.shardingSpecs().get(0).fields().get(0); + int nameFieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals("name")) + .findFirst() + .get() + .getId(); + assertEquals("bucket", field.transform().get()); + assertEquals(nameFieldId, field.sourceIds().get(0)); + } + } + + @Test + void testShardingEvaluatorBucketAndIdentity(@TempDir Path tempDir) throws Exception { + String path = tempDir.resolve("append_only").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeAppendOnlyDataset(allocator, path, new long[] {1}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("id", 4)); + ShardingSpec bucketSpec = dataset.memWalIndexDetails().get().shardingSpecs().get(0); + ShardingField 
bucketField = bucketSpec.fields().get(0); + + try (VectorSchemaRoot root = appendOnlyRoot(allocator, new long[] {1, 2, 3}, "eval"); + ArrowReader reader = + ShardingEvaluator.evaluate(allocator, root, bucketSpec, dataset.getLanceSchema())) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + IntVector buckets = (IntVector) result.getVector(bucketField.fieldId()); + assertEquals(3, result.getRowCount()); + assertEquals(0, buckets.get(0)); + assertEquals(0, buckets.get(1)); + assertEquals(3, buckets.get(2)); + assertFalse(reader.loadNextBatch()); + } + + int nameFieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals("name")) + .findFirst() + .get() + .getId(); + ShardingSpec identitySpec = + new ShardingSpec( + 7, + Collections.singletonList( + new ShardingField( + "name_identity", + Collections.singletonList(nameFieldId), + "identity", + null, + "utf8", + Collections.emptyMap()))); + try (VectorSchemaRoot root = appendOnlyRoot(allocator, new long[] {1}, "eval"); + ArrowReader reader = + ShardingEvaluator.evaluate(allocator, root, identitySpec, dataset.getLanceSchema())) { + assertTrue(reader.loadNextBatch()); + VarCharVector names = + (VarCharVector) reader.getVectorSchemaRoot().getVector("name_identity"); + assertEquals("eval_1", new String(names.get(0), StandardCharsets.UTF_8)); + assertFalse(reader.loadNextBatch()); + } + + Map stringBucketParameters = new HashMap<>(); + stringBucketParameters.put("column", "key"); + stringBucketParameters.put("num_buckets", "8"); + ShardingSpec stringBucketSpec = + new ShardingSpec( + 8, + Collections.singletonList( + new ShardingField( + "key_bucket", + Collections.emptyList(), + "bucket", + null, + "int32", + stringBucketParameters))); + Schema stringSchema = + new Schema(Collections.singletonList(Field.nullable("key", new ArrowType.Utf8()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(stringSchema, allocator)) { + VarCharVector 
keyVector = (VarCharVector) root.getVector("key"); + keyVector.allocateNew(); + keyVector.setSafe(0, "a".getBytes(StandardCharsets.UTF_8)); + keyVector.setSafe(1, "b".getBytes(StandardCharsets.UTF_8)); + keyVector.setNull(2); + root.setRowCount(3); + try (ArrowReader reader = ShardingEvaluator.evaluate(allocator, root, stringBucketSpec)) { + assertTrue(reader.loadNextBatch()); + IntVector buckets = (IntVector) reader.getVectorSchemaRoot().getVector("key_bucket"); + assertEquals(1, buckets.get(0)); + assertEquals(5, buckets.get(1)); + assertEquals(0, buckets.get(2)); + assertFalse(reader.loadNextBatch()); + } + } + } + } + @Test void testInitializeMemWalRejectsConflictingSharding(@TempDir Path tempDir) throws Exception { String path = tempDir.resolve("base").toString(); diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py index 78f97418133..f58b169a47a 100644 --- a/python/python/lance/__init__.py +++ b/python/python/lance/__init__.py @@ -39,10 +39,11 @@ LsmScanner, LsmVectorSearchPlanner, MergedGeneration, - RegionField, - RegionSnapshot, - RegionSpec, - RegionWriter, + ShardingField, + ShardingSpec, + ShardSnapshot, + ShardWriter, + evaluate_sharding_spec, ) from .namespace import ( DescribeTableRequest, @@ -99,10 +100,11 @@ "LsmScanner", "LsmVectorSearchPlanner", "MergedGeneration", - "RegionField", - "RegionSpec", - "RegionSnapshot", - "RegionWriter", + "ShardSnapshot", + "ShardWriter", + "ShardingField", + "ShardingSpec", + "evaluate_sharding_spec", ] diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index a57781480c1..5737ec013b5 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -4740,19 +4740,19 @@ def initialize_mem_wal( ) -> None: """Initialize MemWAL on this dataset. - Must be called once before any calls to `mem_wal_writer`. The dataset - schema must have at least one field annotated with the - ``lance-schema:unenforced-primary-key`` Arrow field metadata. 
+ Must be called once before any calls to `mem_wal_writer`. Append-only + tables may omit primary-key metadata; primary keys are only required + for primary-key lookup and last-write-wins deduplication workflows. At most one sharding mode may be selected: bucket sharding (``bucket_column`` + ``num_buckets``), identity sharding (``identity_column``), or ``unsharded``. With none selected, shards are - managed manually by passing region IDs to `mem_wal_writer`. + managed manually by passing shard IDs to `mem_wal_writer`. Any writer-configuration keyword arguments (``durable_write``, ``max_memtable_size``, ``max_wal_flush_interval_ms``, etc. — the same knobs accepted by `mem_wal_writer`) are recorded as the default - `~lance.mem_wal.RegionWriter` configuration in the MemWAL index, so + `~lance.mem_wal.ShardWriter` configuration in the MemWAL index, so every writer starts from the same defaults. Parameters @@ -4761,8 +4761,7 @@ def initialize_mem_wal( Names of existing indexes to keep updated as data is written through the MemWAL. Must reference indexes that already exist. bucket_column : str, optional - With ``num_buckets``, hash-bucket writes by this column, which must - be the single-column unenforced primary key. + With ``num_buckets``, hash-bucket writes by this scalar column. num_buckets : int, optional Number of hash buckets (shards). Required with ``bucket_column``. identity_column : str, optional @@ -4773,7 +4772,6 @@ def initialize_mem_wal( Raises ------ IOError - - Dataset has no ``lance-schema:unenforced-primary-key`` field. - An entry in *maintained_indexes* does not exist on the dataset. - MemWAL has already been initialized on this dataset. 
ValueError @@ -4815,7 +4813,7 @@ def mem_wal_index_details(self) -> Optional[dict]: def mem_wal_writer( self, - region_id: str, + shard_id: str, *, durable_write: Optional[bool] = None, sync_indexed_write: Optional[bool] = None, @@ -4830,17 +4828,17 @@ def mem_wal_writer( async_index_interval_ms: Optional[int] = None, backpressure_log_interval_ms: Optional[int] = None, stats_log_interval_ms: Optional[int] = None, - ) -> "mem_wal.RegionWriter": - """Get a RegionWriter for the specified region. + ) -> "mem_wal.ShardWriter": + """Get a ShardWriter for the specified shard. `initialize_mem_wal` must be called before using this method. - Each *region* is an independent write shard; use different region IDs + Each shard is an independent write path; use different shard IDs to achieve parallel ingestion without writer contention. Parameters ---------- - region_id : str - UUID string identifying the write region (e.g. + shard_id : str + UUID string identifying the write shard (e.g. ``str(uuid.uuid4())``). durable_write : bool, optional Whether to fsync WAL writes (default: ``True``). @@ -4873,8 +4871,8 @@ def mem_wal_writer( Returns ------- - RegionWriter - A context-manager-compatible writer for the specified region. + ShardWriter + A context-manager-compatible writer for the specified shard. Examples -------- @@ -4893,9 +4891,9 @@ def mem_wal_writer( ... tmpdir, ... ) ... ds.initialize_mem_wal() - ... region_id = str(uuid.uuid4()) + ... shard_id = str(uuid.uuid4()) ... new_data = pa.table({"id": [2], "val": [0.2]}, schema=schema) - ... with ds.mem_wal_writer(region_id) as writer: + ... with ds.mem_wal_writer(shard_id) as writer: ... 
writer.put(new_data) """ import lance.mem_wal as _mw @@ -4919,8 +4917,8 @@ def mem_wal_writer( ] if val is not None } - raw = self._ds.mem_wal_writer(region_id, **kwargs) - return _mw.RegionWriter(raw) + raw = self._ds.mem_wal_writer(shard_id, **kwargs) + return _mw.ShardWriter(raw) class SqlQuery: diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 6edcef5e080..39f22e8aded 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -578,6 +578,82 @@ def _json_to_schema(schema_json: str) -> pa.Schema: ... def _schema_to_json(schema: pa.Schema) -> str: ... def _parse_field_path(path: str) -> list[str]: ... def _format_field_path(segments: list[str]) -> str: ... +def _evaluate_sharding_spec( + batch: pa.RecordBatch, + spec: Dict[str, Any], + schema: LanceSchema, +) -> pa.RecordBatch: ... + +class _MergedGeneration: + shard_id: str + generation: int + def __init__(self, shard_id: str, generation: int) -> None: ... + +class _ShardSnapshot: + shard_id: str + def __init__(self, shard_id: str) -> None: ... + def with_spec_id(self, spec_id: int) -> Self: ... + def with_current_generation(self, generation: int) -> Self: ... + def with_flushed_generation(self, generation: int, path: str) -> Self: ... + +class _ShardWriter: + shard_id: str + def put(self, data: Any) -> None: ... + def close(self) -> None: ... + def stats(self) -> Dict[str, Any]: ... + def memtable_stats(self) -> Dict[str, Any]: ... + def lsm_scanner( + self, shard_snapshots: Optional[List[_ShardSnapshot]] = None + ) -> _LsmScanner: ... + +class _LsmScanner: + @staticmethod + def from_snapshots( + dataset: _Dataset, shard_snapshots: List[_ShardSnapshot] + ) -> _LsmScanner: ... + def project(self, columns: List[str]) -> Self: ... + def filter(self, expr: str) -> Self: ... + def limit(self, n: int, offset: Optional[int] = None) -> Self: ... + def with_row_address(self) -> Self: ... + def with_memtable_gen(self) -> Self: ... 
+ def to_batch(self) -> pa.RecordBatch: ... + def to_batches(self) -> List[pa.RecordBatch]: ... + def count_rows(self) -> int: ... + +class _ExecutionPlan: + schema: pa.Schema + dataset_schema: pa.Schema + def explain(self) -> str: ... + def to_reader(self) -> pa.RecordBatchReader: ... + def to_batches(self) -> List[pa.RecordBatch]: ... + +class _LsmPointLookupPlanner: + def __init__( + self, + dataset: _Dataset, + shard_snapshots: List[_ShardSnapshot], + pk_columns: Optional[List[str]] = None, + ) -> None: ... + def plan_lookup( + self, pk_value: pa.Array, columns: Optional[List[str]] = None + ) -> _ExecutionPlan: ... + +class _LsmVectorSearchPlanner: + def __init__( + self, + dataset: _Dataset, + shard_snapshots: List[_ShardSnapshot], + vector_column: str, + pk_columns: Optional[List[str]] = None, + distance_type: Optional[str] = None, + ) -> None: ... + def plan_search( + self, + query: pa.Array, + k: int = 10, + nprobes: int = 20, + columns: Optional[List[str]] = None, + ) -> _ExecutionPlan: ... 
class _Hnsw: @staticmethod diff --git a/python/python/lance/mem_wal.py b/python/python/lance/mem_wal.py index 87609435e53..2ca293d790d 100644 --- a/python/python/lance/mem_wal.py +++ b/python/python/lance/mem_wal.py @@ -15,19 +15,20 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Union import pyarrow as pa from .lance import ( + _evaluate_sharding_spec, _ExecutionPlan, _LsmPointLookupPlanner, _LsmScanner, _LsmVectorSearchPlanner, _MergedGeneration, - _RegionSnapshot, - _RegionWriter, + _ShardSnapshot, + _ShardWriter, ) from .types import _coerce_reader @@ -35,11 +36,12 @@ import lance __all__ = [ - "RegionField", - "RegionSpec", + "ShardingField", + "ShardingSpec", + "evaluate_sharding_spec", "MergedGeneration", - "RegionSnapshot", - "RegionWriter", + "ShardSnapshot", + "ShardWriter", "LsmScanner", "ExecutionPlan", "LsmPointLookupPlanner", @@ -48,13 +50,13 @@ # --------------------------------------------------------------------------- -# RegionSpec +# ShardingSpec # --------------------------------------------------------------------------- @dataclass -class RegionField: - """Defines one derived field used in region partitioning. +class ShardingField: + """Defines one MemWAL sharding field. Parameters ---------- @@ -81,11 +83,43 @@ class RegionField: @dataclass -class RegionSpec: - """Partitioning specification for deriving MemWAL region IDs.""" +class ShardingSpec: + """Specification for deriving MemWAL shard routing values.""" spec_id: int - fields: List[RegionField] + fields: List[ShardingField] + + +def evaluate_sharding_spec( + batch: pa.RecordBatch, + spec: Union[ShardingSpec, Mapping[str, object]], + schema: "lance.schema.LanceSchema", +) -> pa.RecordBatch: + """Evaluate a MemWAL sharding spec against one PyArrow RecordBatch. 
+ + Parameters + ---------- + batch : pyarrow.RecordBatch + Input batch containing the sharding source columns. + spec : ShardingSpec or dict + MemWAL sharding spec to evaluate. + schema : LanceSchema + Lance table schema used to resolve source field IDs in the spec to + input batch column names. + """ + if not isinstance(batch, pa.RecordBatch): + raise TypeError(f"Expected pyarrow.RecordBatch, got {type(batch)!r}") + return _evaluate_sharding_spec( + batch, + _sharding_spec_to_dict(spec), + schema, + ) + + +def _sharding_spec_to_dict(spec: Union[ShardingSpec, Mapping[str, object]]) -> dict: + if isinstance(spec, Mapping): + return dict(spec) + return asdict(spec) @dataclass @@ -97,45 +131,45 @@ class MergedGeneration: Parameters ---------- - region_id : str - UUID string for the write region. + shard_id : str + UUID string for the write shard. generation : int Generation number (from - :attr:`RegionSnapshot.flushed_generations`). + :attr:`ShardSnapshot.flushed_generations`). """ - region_id: str + shard_id: str generation: int -class RegionSnapshot: - """Snapshot of a MemWAL region's state, used when constructing scanners. +class ShardSnapshot: + """Snapshot of a MemWAL shard's state, used when constructing scanners. Parameters ---------- - region_id : str - UUID string for the write region. + shard_id : str + UUID string for the write shard. 
""" - def __init__(self, region_id: str) -> None: - self._raw = _RegionSnapshot(region_id) + def __init__(self, shard_id: str) -> None: + self._raw = _ShardSnapshot(shard_id) @property - def region_id(self) -> str: - """UUID string for this region.""" - return self._raw.region_id + def shard_id(self) -> str: + """UUID string for this shard.""" + return self._raw.shard_id - def with_spec_id(self, spec_id: int) -> "RegionSnapshot": - """Set the RegionSpec ID.""" + def with_spec_id(self, spec_id: int) -> "ShardSnapshot": + """Set the sharding spec ID.""" self._raw = self._raw.with_spec_id(spec_id) return self - def with_current_generation(self, generation: int) -> "RegionSnapshot": + def with_current_generation(self, generation: int) -> "ShardSnapshot": """Set the current (active) generation number.""" self._raw = self._raw.with_current_generation(generation) return self - def with_flushed_generation(self, generation: int, path: str) -> "RegionSnapshot": + def with_flushed_generation(self, generation: int, path: str) -> "ShardSnapshot": """Add a flushed generation with its storage path.""" self._raw = self._raw.with_flushed_generation(generation, path) return self @@ -144,28 +178,28 @@ def __repr__(self) -> str: return repr(self._raw) -class RegionWriter: - """Stateful writer for one MemWAL region. +class ShardWriter: + """Stateful writer for one MemWAL shard. Obtain an instance via mem_wal_writer. Use as a context manager so the writer is closed automatically:: - with dataset.mem_wal_writer(region_id) as writer: + with dataset.mem_wal_writer(shard_id) as writer: writer.put(batch) Parameters ---------- - _raw : _RegionWriter + _raw : _ShardWriter Internal PyO3 object — do not construct directly. 
""" - def __init__(self, _raw: _RegionWriter) -> None: + def __init__(self, _raw: _ShardWriter) -> None: self._raw = _raw @property - def region_id(self) -> str: - """UUID string for this writer's region.""" - return self._raw.region_id + def shard_id(self) -> str: + """UUID string for this writer's shard.""" + return self._raw.shard_id def put(self, data, *, schema: Optional[pa.Schema] = None) -> None: """Write data to the MemWAL. @@ -222,7 +256,7 @@ def memtable_stats(self) -> dict: return self._raw.memtable_stats() def lsm_scanner( - self, region_snapshots: Optional[List[RegionSnapshot]] = None + self, shard_snapshots: Optional[List[ShardSnapshot]] = None ) -> "LsmScanner": """Create an LSM scanner that includes the active MemTable. @@ -232,18 +266,18 @@ def lsm_scanner( Parameters ---------- - region_snapshots : list of RegionSnapshot, optional - Snapshots of other regions to include. This writer's own region + shard_snapshots : list of ShardSnapshot, optional + Snapshots of other shards to include. This writer's own shard is automatically included. Returns ------- LsmScanner """ - raw_snaps = [s._raw for s in (region_snapshots or [])] + raw_snaps = [s._raw for s in (shard_snapshots or [])] return LsmScanner(self._raw.lsm_scanner(raw_snaps)) - def __enter__(self) -> "RegionWriter": + def __enter__(self) -> "ShardWriter": return self def __exit__(self, exc_type, exc_val, exc_tb) -> bool: @@ -257,7 +291,7 @@ class LsmScanner: Deduplicates by primary key, always returning the newest version of each row across base table, flushed MemTables, and the active MemTable. - Obtain an instance from `RegionWriter.lsm_scanner` (includes + Obtain an instance from `ShardWriter.lsm_scanner` (includes active MemTable) or `LsmScanner.from_snapshots` (flushed only). 
The builder methods (`project`, `filter`, `limit`) @@ -276,23 +310,21 @@ def __init__(self, _raw: _LsmScanner) -> None: @staticmethod def from_snapshots( dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], ) -> "LsmScanner": - """Create a scanner from dataset and region snapshots. + """Create a scanner from dataset and shard snapshots. Does **not** include the active MemTable; use - `RegionWriter.lsm_scanner` for that. + `ShardWriter.lsm_scanner` for that. Parameters ---------- dataset : LanceDataset The base dataset to scan. - region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. """ - raw = _LsmScanner.from_snapshots( - dataset._ds, [s._raw for s in region_snapshots] - ) + raw = _LsmScanner.from_snapshots(dataset._ds, [s._raw for s in shard_snapshots]) return LsmScanner(raw) def project(self, columns: List[str]) -> "LsmScanner": @@ -389,8 +421,8 @@ class LsmPointLookupPlanner: ---------- dataset : LanceDataset The base dataset. - region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. pk_columns : list of str, optional Primary key column names. Inferred from schema metadata if omitted. @@ -404,12 +436,12 @@ class LsmPointLookupPlanner: def __init__( self, dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], pk_columns: Optional[List[str]] = None, ) -> None: self._raw = _LsmPointLookupPlanner( dataset._ds, - [s._raw for s in region_snapshots], + [s._raw for s in shard_snapshots], pk_columns, ) @@ -448,8 +480,8 @@ class LsmVectorSearchPlanner: ---------- dataset : LanceDataset The base dataset. 
- region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. vector_column : str Name of the ``FixedSizeList`` vector column. pk_columns : list of str, optional @@ -470,7 +502,7 @@ class LsmVectorSearchPlanner: def __init__( self, dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], vector_column: str, pk_columns: Optional[List[str]] = None, distance_type: Optional[str] = None, @@ -482,7 +514,7 @@ def __init__( kwargs["distance_type"] = distance_type self._raw = _LsmVectorSearchPlanner( dataset._ds, - [s._raw for s in region_snapshots], + [s._raw for s in shard_snapshots], vector_column, **kwargs, ) @@ -517,16 +549,16 @@ def plan_search( return ExecutionPlan(self._raw.plan_search(query, k, nprobes, columns)) -def _unwrap_region_id(region_id: str) -> str: - """Validate region_id is a UUID string.""" +def _unwrap_shard_id(shard_id: str) -> str: + """Validate shard_id is a UUID string.""" import uuid as _uuid - _uuid.UUID(region_id) # raises ValueError if invalid - return region_id + _uuid.UUID(shard_id) # raises ValueError if invalid + return shard_id def _to_raw_merged_generations( generations: Iterable[MergedGeneration], ) -> list: """Convert Python MergedGeneration list to PyO3 _MergedGeneration list.""" - return [_MergedGeneration(g.region_id, g.generation) for g in generations] + return [_MergedGeneration(g.shard_id, g.generation) for g in generations] diff --git a/python/python/tests/test_mem_wal.py b/python/python/tests/test_mem_wal.py index 88397e94167..b8c859cb637 100644 --- a/python/python/tests/test_mem_wal.py +++ b/python/python/tests/test_mem_wal.py @@ -11,8 +11,12 @@ from lance.mem_wal import ( LsmPointLookupPlanner, LsmScanner, - RegionSnapshot, + ShardingField, + ShardingSpec, + ShardSnapshot, + evaluate_sharding_spec, ) +from lance.schema import 
LanceSchema _PK_META = {"lance-schema:unenforced-primary-key": "true"} _LOOKUP_SCHEMA = pa.schema( @@ -51,13 +55,13 @@ def _append_only_table(ids, prefix: str) -> pa.Table: ) -def _write_flushed_gen(base_path: str, region_id: str, gen_folder: str, data: pa.Table): +def _write_flushed_gen(base_path: str, shard_id: str, gen_folder: str, data: pa.Table): """Write a flushed-generation Lance dataset at the expected sub-path. The collector resolves flushed generation paths as: - {base_dataset_path}/_mem_wal/{region_id}/{gen_folder} + {base_dataset_path}/_mem_wal/{shard_id}/{gen_folder} """ - gen_path = os.path.join(base_path, "_mem_wal", region_id, gen_folder) + gen_path = os.path.join(base_path, "_mem_wal", shard_id, gen_folder) lance.write_dataset(data, gen_path, schema=_LOOKUP_SCHEMA) @@ -71,10 +75,10 @@ def test_point_lookup_with_memtables(tmp_path): base : ids [1, 2, 3] names ["base_1", "base_2", "base_3"] gen_1 : ids [2] names ["gen1_2"] ← update to id=2 - RegionSnapshot: flushed_generation(gen=1, path="gen_1"), current_generation=2 + ShardSnapshot: flushed_generation(gen=1, path="gen_1"), current_generation=2 """ ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) # --- Base dataset --- base_ds = lance.write_dataset( @@ -83,11 +87,11 @@ def test_point_lookup_with_memtables(tmp_path): base_ds.initialize_mem_wal() # --- Flushed generation: overwrites id=2 --- - _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + _write_flushed_gen(ds_path, shard_id, "gen_1", _lookup_table([2], "gen1")) - # --- RegionSnapshot describing the flushed state --- + # --- ShardSnapshot describing the flushed state --- snap = ( - RegionSnapshot(region_id) + ShardSnapshot(shard_id) .with_flushed_generation(1, "gen_1") .with_current_generation(2) ) @@ -127,17 +131,17 @@ def test_lsm_scanner_with_memtables(tmp_path): Expected result: 3 unique rows — id=2 from gen_1, id=1 and id=3 from base. 
""" ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) base_ds = lance.write_dataset( _lookup_table([1, 2, 3], "base"), ds_path, schema=_LOOKUP_SCHEMA ) base_ds.initialize_mem_wal() - _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + _write_flushed_gen(ds_path, shard_id, "gen_1", _lookup_table([2], "gen1")) snap = ( - RegionSnapshot(region_id) + ShardSnapshot(shard_id) .with_flushed_generation(1, "gen_1") .with_current_generation(2) ) @@ -153,14 +157,14 @@ def test_lsm_scanner_with_memtables(tmp_path): assert name_by_id[3] == "base_3" -def test_region_writer_lsm_scanner_includes_own_flushed_generations(tmp_path): +def test_shard_writer_lsm_scanner_includes_own_flushed_generations(tmp_path): ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) ds = lance.write_dataset(_lookup_table([0], "base"), ds_path, schema=_LOOKUP_SCHEMA) ds.initialize_mem_wal() with ds.mem_wal_writer( - region_id, + shard_id, durable_write=True, max_wal_buffer_size=1, max_wal_flush_interval_ms=10, @@ -249,15 +253,15 @@ def _e2e_batch(schema, start_id: int, num_rows: int) -> pa.RecordBatch: ) -def test_region_writer_e2e_correctness(tmp_path): +def test_shard_writer_e2e_correctness(tmp_path): """ - End-to-end correctness test for RegionWriter covering: + End-to-end correctness test for ShardWriter covering: - Multi-round writes that trigger WAL and MemTable flushes - - File-system layout verification (_mem_wal//wal/ and manifest/) + - File-system layout verification (_mem_wal//wal/ and manifest/) - Flushed generation data readable via LsmScanner - New writer created after close can write and scan correctly - Mirrors Rust test: region_writer_tests::test_region_writer_e2e_correctness + Mirrors Rust test: shard_writer_tests::test_shard_writer_e2e_correctness """ schema = _e2e_schema() ds_path = str(tmp_path / "ds") @@ -272,9 +276,9 @@ def test_region_writer_e2e_correctness(tmp_path): 
ds.initialize_mem_wal(maintained_indexes=["id_btree"]) # Small buffers to trigger WAL and MemTable flushes during the test - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) writer = ds.mem_wal_writer( - region_id, + shard_id, durable_write=True, sync_indexed_write=True, max_wal_buffer_size=10 * 1024, # 10 KB @@ -304,7 +308,7 @@ def test_region_writer_e2e_correctness(tmp_path): assert closed_memtable_stats["generation"] >= 1 # === File-system layout === - mem_wal_dir = os.path.join(ds_path, "_mem_wal", region_id) + mem_wal_dir = os.path.join(ds_path, "_mem_wal", shard_id) assert os.path.isdir(mem_wal_dir), f"MemWAL directory missing: {mem_wal_dir}" wal_dir = os.path.join(mem_wal_dir, "wal") @@ -325,9 +329,9 @@ def test_region_writer_e2e_correctness(tmp_path): # === New writer: write and read back via active MemTable scanner === ds2 = lance.dataset(ds_path) - region_id2 = str(uuid.uuid4()) + shard_id2 = str(uuid.uuid4()) with ds2.mem_wal_writer( - region_id2, durable_write=False, sync_indexed_write=True + shard_id2, durable_write=False, sync_indexed_write=True ) as writer2: verify_batch = _e2e_batch(schema, start_id=10000, num_rows=10) writer2.put(pa.Table.from_batches([verify_batch])) @@ -375,13 +379,22 @@ def test_initialize_mem_wal_unsharded(tmp_path): def test_initialize_mem_wal_bucket_sharding(tmp_path): ds = _mem_wal_dataset(tmp_path) - ds.initialize_mem_wal(bucket_column="id", num_buckets=8) + ds.initialize_mem_wal(bucket_column="name", num_buckets=8) details = ds.mem_wal_index_details() assert details["num_shards"] == 8 field = details["sharding_specs"][0]["fields"][0] assert field["transform"] == "bucket" + assert "expression" in field assert field["parameters"]["num_buckets"] == "8" + assert len(field["source_ids"]) == 1 + + batch = _lookup_table([1, 2, 3], "base").to_batches()[0] + result = evaluate_sharding_spec( + batch, details["sharding_specs"][0], LanceSchema.from_pyarrow(batch.schema) + ) + assert result.column_names == 
[field["field_id"]] + assert result.num_rows == batch.num_rows def test_initialize_mem_wal_bucket_sharding_without_primary_key(tmp_path): @@ -456,6 +469,55 @@ def test_initialize_mem_wal_rejects_partial_bucket(tmp_path): ds.initialize_mem_wal(bucket_column="id") +def test_evaluate_sharding_spec_python_binding(): + batch = pa.record_batch( + [pa.array([1, 2, None, 3], type=pa.int32())], + names=["id"], + ) + spec = ShardingSpec( + 1, + [ + ShardingField( + field_id="bucket", + source_ids=[0], + transform="bucket", + result_type="int32", + parameters={"num_buckets": "8"}, + ) + ], + ) + + result = evaluate_sharding_spec(batch, spec, LanceSchema.from_pyarrow(batch.schema)) + + assert result.column_names == ["bucket"] + assert result.column(0).to_pylist() == [2, 7, 0, 1] + + +def test_evaluate_sharding_spec_python_binding_column_parameter(): + batch = pa.record_batch( + [pa.array(["a", "b", None], type=pa.utf8())], + names=["key"], + ) + spec = { + "spec_id": 1, + "fields": [ + { + "field_id": "key_bucket", + "source_ids": [0], + "transform": "bucket", + "expression": None, + "result_type": "int32", + "parameters": {"num_buckets": "8"}, + } + ], + } + + result = evaluate_sharding_spec(batch, spec, LanceSchema.from_pyarrow(batch.schema)) + + assert result.column_names == ["key_bucket"] + assert result.column(0).to_pylist() == [1, 5, 0] + + def test_mem_wal_index_details_none_before_init(tmp_path): ds = _mem_wal_dataset(tmp_path) assert ds.mem_wal_index_details() is None diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 4c7164e2ce8..c868504e87c 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -3375,6 +3375,7 @@ impl Dataset { field_dict.set_item("field_id", &field.field_id)?; field_dict.set_item("source_ids", field.source_ids.clone())?; field_dict.set_item("transform", field.transform.clone())?; + field_dict.set_item("expression", field.expression.clone())?; field_dict.set_item("result_type", &field.result_type)?; 
field_dict.set_item("parameters", field.parameters.clone())?; fields.append(field_dict)?; @@ -3386,12 +3387,12 @@ impl Dataset { Ok(Some(dict)) } - /// Get a RegionWriter for the specified region. + /// Get a ShardWriter for the specified shard. /// /// `initialize_mem_wal()` must be called before using this method. #[allow(clippy::too_many_arguments)] #[pyo3(signature=( - region_id, + shard_id, *, durable_write=None, sync_indexed_write=None, @@ -3410,7 +3411,7 @@ impl Dataset { fn mem_wal_writer( &self, py: Python<'_>, - region_id: String, + shard_id: String, durable_write: Option, sync_indexed_write: Option, max_wal_buffer_size: Option, @@ -3424,11 +3425,11 @@ impl Dataset { async_index_interval_ms: Option, backpressure_log_interval_ms: Option, stats_log_interval_ms: Option, - ) -> PyResult { + ) -> PyResult { use lance::dataset::mem_wal::DatasetMemWalExt; - let uuid = uuid::Uuid::parse_str(&region_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + let uuid = uuid::Uuid::parse_str(&shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; let config = writer_config_from_kwargs( durable_write, @@ -3455,7 +3456,7 @@ impl Dataset { )? 
.map_err(|e| PyIOError::new_err(e.to_string()))?; - Ok(crate::mem_wal::PyRegionWriter::new( + Ok(crate::mem_wal::PyShardWriter::new( writer, uuid, self.ds.clone(), diff --git a/python/src/lib.rs b/python/src/lib.rs index f1d384d275e..cf29b26c46a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -286,12 +286,13 @@ fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; // MemWAL classes m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_wrapped(wrap_pyfunction!(mem_wal::py_evaluate_sharding_spec))?; m.add_wrapped(wrap_pyfunction!(bfloat16_array))?; m.add_wrapped(wrap_pyfunction!(write_dataset))?; m.add_wrapped(wrap_pyfunction!(write_fragments))?; diff --git a/python/src/mem_wal.rs b/python/src/mem_wal.rs index 68bd57a58bc..b9aca4e15a1 100644 --- a/python/src/mem_wal.rs +++ b/python/src/mem_wal.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow::ffi_stream::ArrowArrayStreamReader; @@ -20,10 +21,10 @@ use lance::dataset::mem_wal::scanner::{ FlushedGeneration, LsmDataSourceCollector, LsmPointLookupPlanner, LsmVectorSearchPlanner, }; use lance::dataset::mem_wal::write::{MemTableStats, WriteStatsSnapshot}; -use lance::dataset::mem_wal::{ - LsmScanner, ShardSnapshot as RegionSnapshot, ShardWriter as RegionWriter, +use lance::dataset::mem_wal::{LsmScanner, ShardSnapshot, ShardWriter, evaluate_sharding_spec}; +use lance_index::mem_wal::{ + MergedGeneration as LanceMergedGeneration, ShardingField, ShardingSpec, }; -use lance_index::mem_wal::MergedGeneration as LanceMergedGeneration; use lance_linalg::distance::DistanceType; use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; @@ -33,28 +34,82 @@ use uuid::Uuid; use crate::dataset::Dataset as PyDataset; use 
crate::rt; +use crate::schema::LanceSchema as PyLanceSchema; + +/// Evaluate a MemWAL sharding spec against one PyArrow RecordBatch. +#[pyfunction(name = "_evaluate_sharding_spec", signature = (batch, spec, schema))] +pub fn py_evaluate_sharding_spec<'py>( + py: Python<'py>, + batch: PyArrowType, + spec: &Bound<'_, PyAny>, + schema: &PyLanceSchema, +) -> PyResult> { + let PyArrowType(batch) = batch; + let spec = sharding_spec_from_py(spec)?; + let result = evaluate_sharding_spec(&batch, &spec, &schema.0) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + result.to_pyarrow(py) +} + +fn sharding_spec_from_py(spec: &Bound<'_, PyAny>) -> PyResult { + let spec_id = get_py_value(spec, "spec_id")?.extract::()?; + let fields_obj = get_py_value(spec, "fields")?; + let mut fields = Vec::new(); + for field_obj in fields_obj.try_iter()? { + fields.push(sharding_field_from_py(&field_obj?)?); + } + Ok(ShardingSpec { spec_id, fields }) +} + +fn sharding_field_from_py(field: &Bound<'_, PyAny>) -> PyResult { + Ok(ShardingField { + field_id: get_py_value(field, "field_id")?.extract::()?, + source_ids: get_py_value(field, "source_ids")?.extract::>()?, + transform: optional_string(get_py_value(field, "transform")?)?, + expression: optional_string(get_py_value(field, "expression")?)?, + result_type: get_py_value(field, "result_type")?.extract::()?, + parameters: get_py_value(field, "parameters")?.extract::>()?, + }) +} + +fn get_py_value<'py>(obj: &Bound<'py, PyAny>, name: &str) -> PyResult> { + if let Ok(dict) = obj.cast::() { + dict.get_item(name)? 
+ .ok_or_else(|| PyValueError::new_err(format!("Missing sharding spec field '{name}'"))) + } else { + obj.getattr(name) + } +} + +fn optional_string(value: Bound<'_, PyAny>) -> PyResult> { + if value.is_none() { + Ok(None) + } else { + value.extract::().map(Some) + } +} -/// Represents a single generation of a MemWAL region that has been merged +/// Represents a single generation of a MemWAL shard that has been merged /// into the base table. Used with `MergeInsertBuilder.mark_generations_as_merged()`. #[pyclass(name = "_MergedGeneration", module = "_lib")] pub struct PyMergedGeneration { - pub region_id: String, + pub shard_id: String, pub generation: u64, } #[pymethods] impl PyMergedGeneration { #[new] - pub fn new(region_id: String, generation: u64) -> Self { + pub fn new(shard_id: String, generation: u64) -> Self { Self { - region_id, + shard_id, generation, } } #[getter] - pub fn region_id(&self) -> &str { - &self.region_id + pub fn shard_id(&self) -> &str { + &self.shard_id } #[getter] @@ -64,42 +119,42 @@ impl PyMergedGeneration { pub fn __repr__(&self) -> String { format!( - "_MergedGeneration(region_id='{}', generation={})", - self.region_id, self.generation + "_MergedGeneration(shard_id='{}', generation={})", + self.shard_id, self.generation ) } } impl PyMergedGeneration { pub fn to_lance(&self) -> PyResult { - let uuid = Uuid::parse_str(&self.region_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + let uuid = Uuid::parse_str(&self.shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; Ok(LanceMergedGeneration::new(uuid, self.generation)) } } -/// Snapshot of a MemWAL region's state at a point in time. +/// Snapshot of a MemWAL shard's state at a point in time. /// /// Used to specify which flushed generations to include when creating an /// `_LsmScanner`. Supports a builder pattern for adding generations. 
-#[pyclass(name = "_RegionSnapshot", module = "_lib", skip_from_py_object)] +#[pyclass(name = "_ShardSnapshot", module = "_lib", skip_from_py_object)] #[derive(Clone)] -pub struct PyRegionSnapshot { - pub inner: RegionSnapshot, +pub struct PyShardSnapshot { + pub inner: ShardSnapshot, } #[pymethods] -impl PyRegionSnapshot { +impl PyShardSnapshot { #[new] - pub fn new(region_id: String) -> PyResult<Self> { - let uuid = Uuid::parse_str(&region_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + pub fn new(shard_id: String) -> PyResult<Self> { + let uuid = Uuid::parse_str(&shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; Ok(Self { - inner: RegionSnapshot::new(uuid), + inner: ShardSnapshot::new(uuid), }) } - /// Set the RegionSpec ID for this snapshot. + /// Set the sharding spec ID for this snapshot. pub fn with_spec_id(mut slf: PyRefMut<'_, Self>, spec_id: u32) -> PyRefMut<'_, Self> { slf.inner = slf.inner.clone().with_spec_id(spec_id); slf @@ -125,13 +180,13 @@ impl PyRegionSnapshot { } #[getter] - pub fn region_id(&self) -> String { + pub fn shard_id(&self) -> String { self.inner.shard_id.to_string() } pub fn __repr__(&self) -> String { format!( - "_RegionSnapshot(region_id='{}', current_gen={}, flushed_gens={})", + "_ShardSnapshot(shard_id='{}', current_gen={}, flushed_gens={})", self.inner.shard_id, self.inner.current_generation, self.inner.flushed_generations.len() @@ -139,26 +194,26 @@ impl PyRegionSnapshot { } } -/// Long-lived stateful writer for a MemWAL region. +/// Long-lived stateful writer for a MemWAL shard. /// /// Supports writing batches, querying statistics, creating LSM scanners, /// and graceful shutdown. Supports the Python context manager protocol. 
-#[pyclass(name = "_RegionWriter", module = "_lib")] -pub struct PyRegionWriter { - inner: Arc>>, - closed_state: Arc>>, - region_id: Uuid, +#[pyclass(name = "_ShardWriter", module = "_lib")] +pub struct PyShardWriter { + inner: Arc>>, + closed_state: Arc>>, + shard_id: Uuid, dataset: Arc, } #[derive(Clone)] -struct ClosedRegionWriterState { +struct ClosedShardWriterState { stats: WriteStatsSnapshot, memtable_stats: MemTableStats, } #[pymethods] -impl PyRegionWriter { +impl PyShardWriter { /// Write data batches to the MemWAL. /// /// Accepts any PyArrow-compatible data source (RecordBatch, Table, @@ -180,7 +235,7 @@ impl PyRegionWriter { match guard.as_ref() { Some(writer) => writer.put(batches).await.map(|_| ()), None => Err(lance_core::Error::invalid_input( - "RegionWriter is already closed", + "ShardWriter is already closed", )), } })? @@ -205,7 +260,7 @@ impl PyRegionWriter { writer.close().await?; let closed_memtable_stats = closed_memtable_stats(memtable_stats_before_close); let mut closed_guard = closed_state.lock().await; - *closed_guard = Some(ClosedRegionWriterState { + *closed_guard = Some(ClosedShardWriterState { stats: stats_snapshot, memtable_stats: closed_memtable_stats, }); @@ -235,7 +290,7 @@ impl PyRegionWriter { .as_ref() .map(|state| state.stats.clone()) .ok_or_else(|| { - lance_core::Error::invalid_input("RegionWriter is already closed") + lance_core::Error::invalid_input("ShardWriter is already closed") }) } })? @@ -262,7 +317,7 @@ impl PyRegionWriter { .as_ref() .map(|state| state.memtable_stats.clone()) .ok_or_else(|| { - lance_core::Error::invalid_input("RegionWriter is already closed") + lance_core::Error::invalid_input("ShardWriter is already closed") }) } } @@ -275,13 +330,13 @@ impl PyRegionWriter { /// Create an LSM scanner that includes the active MemTable for strong consistency. /// /// The scanner covers: base table + given flushed generations + current active MemTable. 
- #[pyo3(signature = (region_snapshots=vec![]))] + #[pyo3(signature = (shard_snapshots=vec![]))] pub fn lsm_scanner( &self, py: Python<'_>, - region_snapshots: Vec>, + shard_snapshots: Vec>, ) -> PyResult { - let mut snapshots: Vec = region_snapshots + let mut snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -289,7 +344,7 @@ impl PyRegionWriter { let pk_columns = get_pk_columns(&self.dataset)?; let inner = self.inner.clone(); let dataset = self.dataset.clone(); - let region_id = self.region_id; + let shard_id = self.shard_id; let (active_ref, writer_snapshot) = rt() .block_on(Some(py), async move { @@ -300,31 +355,31 @@ impl PyRegionWriter { let writer_snapshot = w .manifest() .await? - .map(region_snapshot_from_manifest) - .unwrap_or_else(|| RegionSnapshot::new(region_id)); + .map(shard_snapshot_from_manifest) + .unwrap_or_else(|| ShardSnapshot::new(shard_id)); Ok((active_ref, writer_snapshot)) } None => Err(lance_core::Error::invalid_input( - "RegionWriter is already closed", + "ShardWriter is already closed", )), } })? .map_err(|e| PyIOError::new_err(e.to_string()))?; - snapshots.retain(|snapshot| snapshot.shard_id != region_id); + snapshots.retain(|snapshot| snapshot.shard_id != shard_id); snapshots.push(writer_snapshot); let scanner = LsmScanner::new(dataset, snapshots, pk_columns) - .with_active_memtable(region_id, active_ref); + .with_active_memtable(shard_id, active_ref); Ok(PyLsmScanner { inner: Some(scanner), }) } - /// Return the region ID as a UUID string. + /// Return the shard ID as a UUID string. #[getter] - pub fn region_id(&self) -> String { - self.region_id.to_string() + pub fn shard_id(&self) -> String { + self.shard_id.to_string() } pub fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { @@ -343,13 +398,13 @@ impl PyRegionWriter { } } -impl PyRegionWriter { - /// Create from a Rust RegionWriter and dataset reference. 
- pub fn new(writer: RegionWriter, region_id: Uuid, dataset: Arc) -> Self { +impl PyShardWriter { + /// Create from a Rust ShardWriter and dataset reference. + pub fn new(writer: ShardWriter, shard_id: Uuid, dataset: Arc) -> Self { Self { inner: Arc::new(TokioMutex::new(Some(writer))), closed_state: Arc::new(TokioMutex::new(None)), - region_id, + shard_id, dataset, } } @@ -438,14 +493,14 @@ pub struct PyLsmScanner { #[pymethods] impl PyLsmScanner { - /// Create a scanner from dataset and region snapshots (without active MemTable). + /// Create a scanner from dataset and shard snapshots (without active MemTable). #[staticmethod] pub fn from_snapshots( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, ) -> PyResult { let ds = dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -584,14 +639,14 @@ pub struct PyLsmPointLookupPlanner { #[pymethods] impl PyLsmPointLookupPlanner { #[new] - #[pyo3(signature = (dataset, region_snapshots, pk_columns=None))] + #[pyo3(signature = (dataset, shard_snapshots, pk_columns=None))] pub fn new( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, pk_columns: Option>, ) -> PyResult { let ds = dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -649,16 +704,16 @@ pub struct PyLsmVectorSearchPlanner { #[pymethods] impl PyLsmVectorSearchPlanner { #[new] - #[pyo3(signature = (dataset, region_snapshots, vector_column, pk_columns=None, distance_type=None))] + #[pyo3(signature = (dataset, shard_snapshots, vector_column, pk_columns=None, distance_type=None))] pub fn new( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, vector_column: String, pk_columns: Option>, distance_type: Option, ) -> PyResult { let ds = 
dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -878,8 +933,8 @@ fn scalar_values_from_pk_value( Ok(pk_values) } -fn region_snapshot_from_manifest(manifest: lance_index::mem_wal::ShardManifest) -> RegionSnapshot { - RegionSnapshot { +fn shard_snapshot_from_manifest(manifest: lance_index::mem_wal::ShardManifest) -> ShardSnapshot { + ShardSnapshot { shard_id: manifest.shard_id, spec_id: manifest.shard_spec_id, current_generation: manifest.current_generation, diff --git a/rust/lance/src/dataset/mem_wal.rs b/rust/lance/src/dataset/mem_wal.rs index 1cea7e8b9e5..5f3bc2ed483 100644 --- a/rust/lance/src/dataset/mem_wal.rs +++ b/rust/lance/src/dataset/mem_wal.rs @@ -38,6 +38,7 @@ pub mod index; mod manifest; pub mod memtable; pub mod scanner; +pub mod sharding; pub mod util; mod wal; pub mod write; @@ -46,6 +47,10 @@ pub use api::{DatasetMemWalExt, InitializeMemWalBuilder}; pub use manifest::ShardManifestStore; pub use memtable::scanner::MemTableScanner; pub use scanner::{LsmDataSource, LsmGeneration, LsmScanner, ShardSnapshot}; +pub use sharding::{ + evaluate_sharding_spec, evaluate_sharding_spec_with_embedded_columns, + evaluate_sharding_spec_with_source_columns, +}; pub use wal::{WalAppendResult, WalAppender, WalReadEntry, WalTailer}; pub use write::ShardWriter; pub use write::ShardWriterConfig; diff --git a/rust/lance/src/dataset/mem_wal/api.rs b/rust/lance/src/dataset/mem_wal/api.rs index 65623c4f5e0..c8a3bc0441d 100644 --- a/rust/lance/src/dataset/mem_wal/api.rs +++ b/rust/lance/src/dataset/mem_wal/api.rs @@ -112,11 +112,9 @@ impl<'a> InitializeMemWalBuilder<'a> { /// Hash-bucket `column` into `num_buckets` shards. /// - /// For primary-key tables, `column` must name the dataset's single-column - /// unenforced primary key so every update for the same key routes to the - /// same shard. 
Append-only tables without a primary key may use any scalar - /// column. `num_buckets` must be in `[1, 1024]`. These constraints are - /// validated by [`execute`](Self::execute). + /// `column` must name a scalar dataset column that can be hash-bucketed. + /// `num_buckets` must be in `[1, 1024]`. These constraints are validated + /// by [`execute`](Self::execute). pub fn bucket_sharding(mut self, column: impl Into, num_buckets: u32) -> Self { self.sharding = Sharding::Bucket { column: column.into(), @@ -282,34 +280,15 @@ fn bucket_sharding_spec(dataset: &Dataset, column: &str, num_buckets: u32) -> Re ))); } - let pk_fields = dataset.schema().unenforced_primary_key(); - let source_field = match pk_fields.as_slice() { - [single] => { - let pk = *single; - if pk.name.as_str() != column { - return Err(Error::invalid_input(format!( - "bucket_sharding: column '{}' does not match the unenforced primary key column '{}'", - column, pk.name - ))); - } - pk - } - [] => dataset.schema().field(column).ok_or_else(|| { - Error::invalid_input(format!( - "bucket_sharding: column '{}' not found on the dataset", - column - )) - })?, - _ => { - return Err(Error::invalid_input( - "bucket_sharding requires a single-column unenforced primary key; \ - use unsharded() for a multi-column key", - )); - } - }; + let source_field = dataset.schema().field(column).ok_or_else(|| { + Error::invalid_input(format!( + "bucket_sharding: column '{}' not found on the dataset", + column + )) + })?; let data_type = source_field.data_type(); - if data_type.is_nested() || data_type.is_null() { + if !is_bucket_sharding_supported_type(&data_type) { return Err(Error::invalid_input(format!( "bucket_sharding: column '{}' has type {:?}, which cannot be used as a shard key", column, data_type @@ -359,6 +338,29 @@ fn identity_sharding_spec(dataset: &Dataset, column: &str) -> Result bool { + matches!( + data_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + 
| DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + ) +} + /// The Arrow type name for a scalar column usable as a shard key, or `None` /// for types that cannot be a shard key. fn scalar_result_type(data_type: &DataType) -> Option<&'static str> { diff --git a/rust/lance/src/dataset/mem_wal/sharding.rs b/rust/lance/src/dataset/mem_wal/sharding.rs new file mode 100644 index 00000000000..5982ce99ee9 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/sharding.rs @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Arrow-native evaluation for MemWAL sharding specs. + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::cast::as_primitive_array; +use arrow_array::types::{ + Date32Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use arrow_array::{ + Array, ArrayRef, BooleanArray, Int32Array, LargeStringArray, RecordBatch, StringArray, +}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use lance_core::{Error, Result, datatypes::Schema as LanceSchema}; +use lance_index::mem_wal::{ShardingField, ShardingSpec}; + +const BUCKET_TRANSFORM: &str = "bucket"; +const IDENTITY_TRANSFORM: &str = "identity"; +const UNSHARDED_TRANSFORM: &str = "unsharded"; +const NUM_BUCKETS_PARAM: &str = "num_buckets"; +const COLUMN_PARAM: &str = "column"; +const MURMUR3_SEED: i32 = 0; + +/// Evaluate a MemWAL sharding specification against an Arrow record batch. 
+/// +/// The returned batch has one column per [`ShardingField`] in `spec`, with the +/// column names taken from `field_id`. Bucket sharding returns `Int32` bucket +/// IDs. Identity sharding returns the source column unchanged. Unsharded +/// sharding returns `Int32` zeros. +pub fn evaluate_sharding_spec( + batch: &RecordBatch, + spec: &ShardingSpec, + schema: &LanceSchema, +) -> Result { + let source_id_to_column = source_id_to_column_map(schema); + evaluate_sharding_spec_with_source_columns(batch, spec, &source_id_to_column) +} + +/// Evaluate a MemWAL sharding specification that embeds source column names. +/// +/// Prefer [`evaluate_sharding_spec`] for table-bound evaluation. This helper is +/// for specs that carry a `column` parameter for each source-dependent field. +pub fn evaluate_sharding_spec_with_embedded_columns( + batch: &RecordBatch, + spec: &ShardingSpec, +) -> Result { + evaluate_sharding_spec_with_source_columns(batch, spec, &HashMap::new()) +} + +/// Evaluate a MemWAL sharding specification with an explicit field-id mapping. +/// +/// Prefer [`evaluate_sharding_spec`] for table-bound evaluation. This helper is +/// intended for binding layers that have already derived the mapping from a +/// table schema. +pub fn evaluate_sharding_spec_with_source_columns( + batch: &RecordBatch, + spec: &ShardingSpec, + source_id_to_column: &HashMap, +) -> Result { + let mut fields = Vec::with_capacity(spec.fields.len()); + let mut columns = Vec::with_capacity(spec.fields.len()); + + for field in &spec.fields { + let column = evaluate_sharding_field(batch, field, source_id_to_column)?; + fields.push(Field::new( + field.field_id.clone(), + column.data_type().clone(), + column.is_nullable(), + )); + columns.push(column); + } + + Ok(RecordBatch::try_new( + Arc::new(ArrowSchema::new(fields)), + columns, + )?) 
+} + +fn source_id_to_column_map(schema: &LanceSchema) -> HashMap { + schema + .fields_pre_order() + .map(|field| { + let column = schema + .field_ancestry_by_id(field.id) + .map(|path| { + path.iter() + .map(|field| field.name.as_str()) + .collect::>() + .join(".") + }) + .unwrap_or_else(|| field.name.clone()); + (field.id, column) + }) + .collect() +} + +fn evaluate_sharding_field( + batch: &RecordBatch, + field: &ShardingField, + source_id_to_column: &HashMap, +) -> Result { + match field.transform.as_deref() { + Some(BUCKET_TRANSFORM) => evaluate_bucket_sharding(batch, field, source_id_to_column), + Some(IDENTITY_TRANSFORM) => evaluate_identity_sharding(batch, field, source_id_to_column), + Some(UNSHARDED_TRANSFORM) => Ok(Arc::new(Int32Array::from(vec![0; batch.num_rows()]))), + other => Err(Error::invalid_input(format!( + "Unsupported MemWAL sharding transform for field '{}': {:?}", + field.field_id, other + ))), + } +} + +fn evaluate_identity_sharding( + batch: &RecordBatch, + field: &ShardingField, + source_id_to_column: &HashMap, +) -> Result { + let column_name = source_column_name(field, source_id_to_column)?; + Ok(batch + .column_by_name(&column_name) + .ok_or_else(|| { + Error::invalid_input(format!( + "Sharding source column '{}' not found in batch", + column_name + )) + })? + .clone()) +} + +fn evaluate_bucket_sharding( + batch: &RecordBatch, + field: &ShardingField, + source_id_to_column: &HashMap, +) -> Result { + let column_name = source_column_name(field, source_id_to_column)?; + let num_buckets = field + .parameters + .get(NUM_BUCKETS_PARAM) + .ok_or_else(|| { + Error::invalid_input(format!( + "Bucket sharding field '{}' missing '{}' parameter", + field.field_id, NUM_BUCKETS_PARAM + )) + })? 
+ .parse::() + .map_err(|e| { + Error::invalid_input(format!( + "Bucket sharding field '{}' has invalid num_buckets '{}': {}", + field.field_id, field.parameters[NUM_BUCKETS_PARAM], e + )) + })?; + if num_buckets <= 0 { + return Err(Error::invalid_input(format!( + "Bucket sharding field '{}' requires positive num_buckets, got {}", + field.field_id, num_buckets + ))); + } + + let column = batch.column_by_name(&column_name).ok_or_else(|| { + Error::invalid_input(format!( + "Sharding source column '{}' not found in batch", + column_name + )) + })?; + let mut bucket_ids = Vec::with_capacity(batch.num_rows()); + for row_idx in 0..batch.num_rows() { + let hash = hash_array_value(column.as_ref(), row_idx, MURMUR3_SEED)?; + bucket_ids.push((hash & i32::MAX) % num_buckets); + } + Ok(Arc::new(Int32Array::from(bucket_ids))) +} + +fn source_column_name( + field: &ShardingField, + source_id_to_column: &HashMap, +) -> Result { + if let Some(column) = field.parameters.get(COLUMN_PARAM) + && !column.trim().is_empty() + { + return Ok(column.clone()); + } + let Some(source_id) = field.source_ids.first() else { + return Err(Error::invalid_input(format!( + "MemWAL sharding field '{}' has no source column", + field.field_id + ))); + }; + source_id_to_column.get(source_id).cloned().ok_or_else(|| { + Error::invalid_input(format!( + "MemWAL sharding field '{}' source id {} was not mapped to a batch column", + field.field_id, source_id + )) + }) +} + +fn hash_array_value(array: &dyn Array, row_idx: usize, seed: i32) -> Result { + if array.is_null(row_idx) { + return Ok(seed); + } + match array.data_type() { + DataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(hash_int(if array.value(row_idx) { 1 } else { 0 }, seed)) + } + DataType::Int8 => hash_primitive_int::(array, row_idx, seed, |v| v as i32), + DataType::Int16 => hash_primitive_int::(array, row_idx, seed, |v| v as i32), + DataType::Int32 => hash_primitive_int::(array, row_idx, seed, |v| v), + 
DataType::Date32 => hash_primitive_int::(array, row_idx, seed, |v| v), + DataType::Int64 => hash_primitive::(array, row_idx, seed, |v| v), + DataType::UInt8 => hash_primitive_int::(array, row_idx, seed, |v| v as i32), + DataType::UInt16 => hash_primitive_int::(array, row_idx, seed, |v| v as i32), + DataType::UInt32 => hash_primitive_int::(array, row_idx, seed, |v| v as i32), + DataType::UInt64 => hash_primitive::(array, row_idx, seed, |v| v as i64), + DataType::Float32 => hash_primitive_int::(array, row_idx, seed, |v| { + canonical_f32_bits(v) as i32 + }), + DataType::Float64 => { + hash_primitive::(array, row_idx, seed, |v| canonical_f64_bits(v) as i64) + } + DataType::Timestamp(_, _) => hash_timestamp(array, row_idx, seed), + DataType::Time32(_) => hash_time32(array, row_idx, seed), + DataType::Time64(_) => hash_time64(array, row_idx, seed), + DataType::Utf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(hash_bytes(array.value(row_idx).as_bytes(), seed)) + } + DataType::LargeUtf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(hash_bytes(array.value(row_idx).as_bytes(), seed)) + } + other => Err(Error::invalid_input(format!( + "Unsupported bucket sharding column type: {:?}", + other + ))), + } +} + +fn hash_primitive_int(array: &dyn Array, row_idx: usize, seed: i32, convert: F) -> Result +where + T: arrow_array::types::ArrowPrimitiveType, + F: Fn(T::Native) -> i32, +{ + let array = as_primitive_array::(array); + Ok(hash_int(convert(array.value(row_idx)), seed)) +} + +fn hash_primitive(array: &dyn Array, row_idx: usize, seed: i32, convert: F) -> Result +where + T: arrow_array::types::ArrowPrimitiveType, + F: Fn(T::Native) -> i64, +{ + let array = as_primitive_array::(array); + Ok(hash_long(convert(array.value(row_idx)), seed)) +} + +fn hash_timestamp(array: &dyn Array, row_idx: usize, seed: i32) -> Result { + match array.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Second, _) => { + hash_primitive::(array, 
row_idx, seed, |v| v) + } + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => { + hash_primitive::(array, row_idx, seed, |v| v) + } + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => { + hash_primitive::(array, row_idx, seed, |v| v) + } + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => { + hash_primitive::(array, row_idx, seed, |v| v) + } + _ => unreachable!(), + } +} + +fn hash_time32(array: &dyn Array, row_idx: usize, seed: i32) -> Result { + match array.data_type() { + DataType::Time32(arrow_schema::TimeUnit::Second) => { + hash_primitive_int::(array, row_idx, seed, |v| v) + } + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + hash_primitive_int::(array, row_idx, seed, |v| v) + } + _ => unreachable!(), + } +} + +fn hash_time64(array: &dyn Array, row_idx: usize, seed: i32) -> Result { + match array.data_type() { + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + hash_primitive::(array, row_idx, seed, |v| v) + } + DataType::Time64(arrow_schema::TimeUnit::Nanosecond) => { + hash_primitive::(array, row_idx, seed, |v| v) + } + _ => unreachable!(), + } +} + +fn canonical_f32_bits(value: f32) -> u32 { + if value == 0.0 { + 0 + } else if value.is_nan() { + 0x7fc0_0000 + } else { + value.to_bits() + } +} + +fn canonical_f64_bits(value: f64) -> u64 { + if value == 0.0 { + 0 + } else if value.is_nan() { + 0x7ff8_0000_0000_0000 + } else { + value.to_bits() + } +} + +fn hash_int(value: i32, seed: i32) -> i32 { + fmix(mix_h1(seed, mix_k1(value)), 4) +} + +fn hash_long(value: i64, seed: i32) -> i32 { + let low = value as i32; + let high = (value >> 32) as i32; + let h1 = mix_h1(seed, mix_k1(low)); + let h1 = mix_h1(h1, mix_k1(high)); + fmix(h1, 8) +} + +fn hash_bytes(bytes: &[u8], seed: i32) -> i32 { + let mut h1 = seed; + let remainder = bytes.len() % 4; + let full_chunks_len = bytes.len() - remainder; + for chunk in bytes[..full_chunks_len].chunks_exact(4) { + let k1 = i32::from_le_bytes([chunk[0], chunk[1], 
chunk[2], chunk[3]]); + h1 = mix_h1(h1, mix_k1(k1)); + } + for byte in &bytes[full_chunks_len..] { + h1 = mix_h1(h1, mix_k1((*byte as i8) as i32)); + } + + fmix(h1, bytes.len() as i32) +} + +fn mix_k1(k1: i32) -> i32 { + let k1 = k1.wrapping_mul(0xcc9e_2d51u32 as i32); + k1.rotate_left(15).wrapping_mul(0x1b87_3593) +} + +fn mix_h1(h1: i32, k1: i32) -> i32 { + let h1 = h1 ^ k1; + h1.rotate_left(13) + .wrapping_mul(5) + .wrapping_add(0xe654_6b64u32 as i32) +} + +fn fmix(h1: i32, length: i32) -> i32 { + let mut h1 = h1 ^ length; + h1 ^= (h1 as u32 >> 16) as i32; + h1 = h1.wrapping_mul(0x85eb_ca6bu32 as i32); + h1 ^= (h1 as u32 >> 13) as i32; + h1 = h1.wrapping_mul(0xc2b2_ae35u32 as i32); + h1 ^ ((h1 as u32 >> 16) as i32) +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_array::{ + BooleanArray, Date32Array, Float32Array, Float64Array, Int32Array, StringArray, + }; + use lance_core::datatypes::Schema as LanceSchema; + use lance_index::mem_wal::ShardingField; + + fn single_field_spec(field: ShardingField) -> ShardingSpec { + ShardingSpec { + spec_id: 1, + fields: vec![field], + } + } + + fn bucket_field(num_buckets: i32) -> ShardingField { + ShardingField { + field_id: "bucket".to_string(), + source_ids: vec![0], + transform: Some(BUCKET_TRANSFORM.to_string()), + expression: None, + result_type: "int32".to_string(), + parameters: HashMap::from([(NUM_BUCKETS_PARAM.to_string(), num_buckets.to_string())]), + } + } + + fn lance_schema(batch: &RecordBatch) -> LanceSchema { + LanceSchema::try_from(batch.schema().as_ref()).unwrap() + } + + fn bucket_field_for_source(source_id: i32, num_buckets: i32) -> ShardingField { + ShardingField { + source_ids: vec![source_id], + ..bucket_field(num_buckets) + } + } + + #[test] + fn test_evaluate_bucket_sharding_int32() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(3)])) as ArrayRef, + )]) + .unwrap(); + let result = evaluate_sharding_spec( + &batch, + 
&single_field_spec(bucket_field(8)), + &lance_schema(&batch), + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[2, 7, 0, 1]); + } + + #[test] + fn test_evaluate_bucket_sharding_date32() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(Date32Array::from(vec![Some(1), Some(2), None, Some(3)])) as ArrayRef, + )]) + .unwrap(); + let result = evaluate_sharding_spec( + &batch, + &single_field_spec(bucket_field(8)), + &lance_schema(&batch), + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[2, 7, 0, 1]); + } + + #[test] + fn test_evaluate_bucket_sharding_string() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + )]) + .unwrap(); + let result = evaluate_sharding_spec( + &batch, + &single_field_spec(bucket_field(8)), + &lance_schema(&batch), + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[1, 5, 0]); + } + + #[test] + fn test_evaluate_bucket_sharding_scalar_types() { + let batch = RecordBatch::try_from_iter([ + ("bool", Arc::new(BooleanArray::from(vec![true])) as ArrayRef), + ( + "f32", + Arc::new(Float32Array::from(vec![1.25_f32])) as ArrayRef, + ), + ( + "f64", + Arc::new(Float64Array::from(vec![1.25_f64])) as ArrayRef, + ), + ]) + .unwrap(); + let schema = lance_schema(&batch); + let result = evaluate_sharding_spec( + &batch, + &single_field_spec(bucket_field_for_source(0, 8)), + &schema, + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[2]); + + let result = evaluate_sharding_spec( + &batch, + &single_field_spec(bucket_field_for_source(1, 8)), + &schema, + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[0]); + + let result = evaluate_sharding_spec( + 
&batch, + &single_field_spec(bucket_field_for_source(2, 8)), + &schema, + ) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[0]); + } + + #[test] + fn test_evaluate_identity_sharding() { + let batch = RecordBatch::try_from_iter([( + "id", + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef, + )]) + .unwrap(); + let spec = single_field_spec(ShardingField { + field_id: "identity".to_string(), + source_ids: vec![0], + transform: Some(IDENTITY_TRANSFORM.to_string()), + expression: None, + result_type: "utf8".to_string(), + parameters: HashMap::new(), + }); + let result = evaluate_sharding_spec(&batch, &spec, &lance_schema(&batch)).unwrap(); + assert_eq!(result.column(0).as_ref(), batch.column(0).as_ref()); + } + + #[test] + fn test_evaluate_bucket_sharding_embedded_column() { + let batch = RecordBatch::try_from_iter([( + "key", + Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])) as ArrayRef, + )]) + .unwrap(); + let mut field = bucket_field(8); + field.source_ids = Vec::new(); + field + .parameters + .insert(COLUMN_PARAM.to_string(), "key".to_string()); + let result = + evaluate_sharding_spec_with_embedded_columns(&batch, &single_field_spec(field)) + .unwrap(); + let buckets = as_primitive_array::(result.column(0).as_ref()); + assert_eq!(buckets.values(), &[1, 5, 0]); + } +} diff --git a/rust/lance/src/dataset/mem_wal/write.rs b/rust/lance/src/dataset/mem_wal/write.rs index 5f6a2af7243..85d0c92e78b 100644 --- a/rust/lance/src/dataset/mem_wal/write.rs +++ b/rust/lance/src/dataset/mem_wal/write.rs @@ -4344,20 +4344,9 @@ mod shard_writer_tests { .await; assert!(result.is_err(), "num_buckets = 0 should be rejected"); - // The bucket column must be the unenforced primary key column. 
- let result = dataset - .initialize_mem_wal() - .bucket_sharding("text", 8) - .execute() - .await; - assert!( - result.is_err(), - "a non-primary-key bucket column should be rejected" - ); - dataset .initialize_mem_wal() - .bucket_sharding("id", 8) + .bucket_sharding("text", 8) .execute() .await .expect("Failed to initialize MemWAL"); @@ -4376,6 +4365,10 @@ mod shard_writer_tests { field.parameters.get("num_buckets").map(String::as_str), Some("8") ); + assert_eq!(field.source_ids.len(), 1); + let source_id = field.source_ids[0]; + let source_field = dataset.schema().field("text").expect("text field exists"); + assert_eq!(source_id, source_field.id); } #[tokio::test] From c7c5626d4e830b46c239b5bf4e2a17e32ab901b8 Mon Sep 17 00:00:00 2001 From: Lance Release Bot Date: Wed, 20 May 2026 21:58:35 +0000 Subject: [PATCH 13/23] chore: release beta version 7.0.0-beta.17 --- .bumpversion.toml | 2 +- Cargo.lock | 56 +++++++++++++++++++-------------------- Cargo.toml | 40 ++++++++++++++-------------- java/lance-jni/Cargo.lock | 45 ++++++++++++++++--------------- java/lance-jni/Cargo.toml | 2 +- java/pom.xml | 2 +- python/Cargo.lock | 44 +++++++++++++++--------------- python/Cargo.toml | 2 +- 8 files changed, 97 insertions(+), 96 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 98be7158b5b..27b0534b009 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "7.0.0-beta.16" +current_version = "7.0.0-beta.17" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)(-(?P(beta|rc))\\.(?P\\d+))?" 
serialize = [ "{major}.{minor}.{patch}-{prerelease}.{prerelease_num}", diff --git a/Cargo.lock b/Cargo.lock index 6e20ba7bf77..8cbddff051d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2848,9 +2848,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encode_unicode" @@ -3093,7 +3093,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3189,9 +3189,9 @@ checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" +checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968" [[package]] name = "futures-util" @@ -4321,7 +4321,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "all_asserts", "approx", @@ -4422,7 +4422,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4470,7 +4470,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrayref", "paste", @@ -4479,7 +4479,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4516,7 +4516,7 @@ dependencies = [ [[package]] name = 
"lance-datafusion" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4549,7 +4549,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4569,7 +4569,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -4614,7 +4614,7 @@ dependencies = [ [[package]] name = "lance-examples" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "all_asserts", "arrow", @@ -4640,7 +4640,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -4680,7 +4680,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "datafusion", "geo-traits", @@ -4694,7 +4694,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "approx", "arc-swap", @@ -4771,7 +4771,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-arith", @@ -4820,7 +4820,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "approx", "arrow-array", @@ -4841,7 +4841,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "async-trait", @@ -4853,7 +4853,7 @@ dependencies = [ [[package]] name = "lance-namespace-datafusion" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-schema", @@ -4869,7 +4869,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", 
"arrow-ipc", @@ -4911,9 +4911,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4925,7 +4925,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4971,7 +4971,7 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "proc-macro2", "quote", @@ -4980,7 +4980,7 @@ dependencies = [ [[package]] name = "lance-testing" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-schema", @@ -4991,7 +4991,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "jieba-rs", "lindera", @@ -5002,7 +5002,7 @@ dependencies = [ [[package]] name = "lance-tools" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "clap", "lance-core", diff --git a/Cargo.toml b/Cargo.toml index 3a6c61acbe9..21c8420dcc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ resolver = "3" [workspace.package] -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" @@ -55,25 +55,25 @@ rust-version = "1.91.0" [workspace.dependencies] arc-swap = "1.7" libc = "0.2.176" -lance = { version = "=7.0.0-beta.16", path = "./rust/lance", default-features = false } -lance-arrow = { version = "=7.0.0-beta.16", path = "./rust/lance-arrow" } -lance-core = { version = "=7.0.0-beta.16", path = "./rust/lance-core" } -lance-datafusion = { version = "=7.0.0-beta.16", path = "./rust/lance-datafusion" } -lance-datagen = { version = 
"=7.0.0-beta.16", path = "./rust/lance-datagen" } -lance-encoding = { version = "=7.0.0-beta.16", path = "./rust/lance-encoding" } -lance-file = { version = "=7.0.0-beta.16", path = "./rust/lance-file" } -lance-geo = { version = "=7.0.0-beta.16", path = "./rust/lance-geo" } -lance-index = { version = "=7.0.0-beta.16", path = "./rust/lance-index" } -lance-io = { version = "=7.0.0-beta.16", path = "./rust/lance-io", default-features = false } -lance-linalg = { version = "=7.0.0-beta.16", path = "./rust/lance-linalg" } -lance-namespace = { version = "=7.0.0-beta.16", path = "./rust/lance-namespace" } -lance-namespace-impls = { version = "=7.0.0-beta.16", path = "./rust/lance-namespace-impls" } +lance = { version = "=7.0.0-beta.17", path = "./rust/lance", default-features = false } +lance-arrow = { version = "=7.0.0-beta.17", path = "./rust/lance-arrow" } +lance-core = { version = "=7.0.0-beta.17", path = "./rust/lance-core" } +lance-datafusion = { version = "=7.0.0-beta.17", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=7.0.0-beta.17", path = "./rust/lance-datagen" } +lance-encoding = { version = "=7.0.0-beta.17", path = "./rust/lance-encoding" } +lance-file = { version = "=7.0.0-beta.17", path = "./rust/lance-file" } +lance-geo = { version = "=7.0.0-beta.17", path = "./rust/lance-geo" } +lance-index = { version = "=7.0.0-beta.17", path = "./rust/lance-index" } +lance-io = { version = "=7.0.0-beta.17", path = "./rust/lance-io", default-features = false } +lance-linalg = { version = "=7.0.0-beta.17", path = "./rust/lance-linalg" } +lance-namespace = { version = "=7.0.0-beta.17", path = "./rust/lance-namespace" } +lance-namespace-impls = { version = "=7.0.0-beta.17", path = "./rust/lance-namespace-impls" } lance-namespace-datafusion = { version = "=7.0.0-beta.9", path = "./rust/lance-namespace-datafusion" } lance-namespace-reqwest-client = "0.7.5" -lance-tokenizer = { version = "=7.0.0-beta.16", path = "./rust/lance-tokenizer" } -lance-table = { 
version = "=7.0.0-beta.16", path = "./rust/lance-table" } -lance-test-macros = { version = "=7.0.0-beta.16", path = "./rust/lance-test-macros" } -lance-testing = { version = "=7.0.0-beta.16", path = "./rust/lance-testing" } +lance-tokenizer = { version = "=7.0.0-beta.17", path = "./rust/lance-tokenizer" } +lance-table = { version = "=7.0.0-beta.17", path = "./rust/lance-table" } +lance-test-macros = { version = "=7.0.0-beta.17", path = "./rust/lance-test-macros" } +lance-testing = { version = "=7.0.0-beta.17", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "58.0.0", optional = false, features = ["prettyprint"] } @@ -99,7 +99,7 @@ half = { "version" = "2.1", default-features = false, features = [ "num-traits", "std", ] } -lance-bitpacking = { version = "=7.0.0-beta.16", path = "./rust/compression/bitpacking" } +lance-bitpacking = { version = "=7.0.0-beta.17", path = "./rust/compression/bitpacking" } bitpacking = "0.9" bitvec = "1" bytes = "1.11.1" @@ -139,7 +139,7 @@ deepsize = "0.2.0" dirs = "6.0.0" either = "1.0" fst = { version = "0.4.7", features = ["levenshtein"] } -fsst = { version = "=7.0.0-beta.16", path = "./rust/compression/fsst" } +fsst = { version = "=7.0.0-beta.17", path = "./rust/compression/fsst" } futures = "0.3" geoarrow-array = "0.8" geoarrow-schema = "0.8" diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index 0e2059bc749..26469fceea2 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -2344,9 +2344,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encoding_rs" @@ -2509,7 +2509,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" 
[[package]] name = "fsst" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3616,7 +3616,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arc-swap", "arrow", @@ -3663,6 +3663,7 @@ dependencies = [ "lance-table", "lance-tokenizer", "log", + "moka", "object_store", "permutation", "pin-project", @@ -3686,7 +3687,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -3706,7 +3707,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrayref", "paste", @@ -3715,7 +3716,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -3750,7 +3751,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -3782,7 +3783,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -3800,7 +3801,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -3835,7 +3836,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -3866,7 +3867,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "datafusion", "geo-traits", @@ -3880,7 +3881,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ 
"arc-swap", "arrow", @@ -3947,7 +3948,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-arith", @@ -3989,7 +3990,7 @@ dependencies = [ [[package]] name = "lance-jni" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4025,7 +4026,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4041,7 +4042,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "async-trait", @@ -4053,7 +4054,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-ipc", @@ -4083,9 +4084,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4097,7 +4098,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4134,7 +4135,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "rust-stemmers", "serde", diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 5c246583a3e..6e6f0e02b77 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lance-jni" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" edition = "2024" authors = ["Lance Devs "] rust-version = "1.91" diff --git a/java/pom.xml b/java/pom.xml index 
db32abad4d7..5494905a9ef 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -7,7 +7,7 @@ org.lance lance-core Lance Core - 7.0.0-beta.16 + 7.0.0-beta.17 jar Lance Format Java API diff --git a/python/Cargo.lock b/python/Cargo.lock index 882344b8c4f..6bd0ed065a1 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -2679,9 +2679,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encoding_rs" @@ -2853,7 +2853,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3975,7 +3975,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arc-swap", "arrow", @@ -4047,7 +4047,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4067,7 +4067,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrayref", "paste", @@ -4076,7 +4076,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4111,7 +4111,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4143,7 +4143,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4161,7 +4161,7 @@ 
dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -4196,7 +4196,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-arith", "arrow-array", @@ -4227,7 +4227,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "datafusion", "geo-traits", @@ -4241,7 +4241,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arc-swap", "arrow", @@ -4309,7 +4309,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-arith", @@ -4351,7 +4351,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow-array", "arrow-buffer", @@ -4367,7 +4367,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "async-trait", @@ -4379,7 +4379,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-ipc", @@ -4409,9 +4409,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4423,7 +4423,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", @@ -4462,7 +4462,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = 
"7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "jieba-rs", "lindera", @@ -5882,7 +5882,7 @@ dependencies = [ [[package]] name = "pylance" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" dependencies = [ "arrow", "arrow-array", diff --git a/python/Cargo.toml b/python/Cargo.toml index 078e0ad6057..aadd5d9ddf0 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "7.0.0-beta.16" +version = "7.0.0-beta.17" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" From dbeeefbe77781023cb8e2011f6d6e47b6bae3221 Mon Sep 17 00:00:00 2001 From: Beinan Date: Wed, 20 May 2026 18:50:13 -0700 Subject: [PATCH 14/23] feat: expose granular trace event targets (#6853) ## Summary - Add granular tracing targets for Lance event categories under `lance::events::...`. - Preserve the original tracing target in Python log passthrough so `LANCE_LOG` can filter individual event types directly. - Document the new event targets and add Python regression coverage for target-specific log filtering. 
## Testing - `PATH=/Users/beinan/.rustup/toolchains/1.94.0-aarch64-apple-darwin/bin:$PATH cargo fmt --all -- --check` - `git diff --check` - `PATH=/Users/beinan/.rustup/toolchains/1.94.0-aarch64-apple-darwin/bin:$PATH cargo check -p lance-core -p lance-io -p lance-index -p lance -p lance-datafusion` - `PATH=/Users/beinan/.rustup/toolchains/1.94.0-aarch64-apple-darwin/bin:$PATH cargo check --manifest-path python/Cargo.toml` - `uv sync --python 3.12 --no-install-project` - `.venv/bin/maturin develop --uv -v` - `uv run --no-sync --python 3.12 pytest -q python/tests/test_log.py::test_lance_log_filters_trace_event_targets python/tests/test_tracing.py::test_tracing_callback` - `uv run --no-sync --python 3.12 ruff format --check python/tests/test_log.py python/tests/test_tracing.py` - `uv run --no-sync --python 3.12 ruff check python/tests/test_log.py python/tests/test_tracing.py` --------- Co-authored-by: Beinan Wang --- docs/src/guide/performance.md | 32 +++++++++++++++- python/python/tests/test_log.py | 30 +++++++++++++++ python/src/tracing.rs | 14 ++++++- rust/lance-core/src/utils/tracing.rs | 1 + rust/lance-io/src/object_store/throttle.rs | 43 +++++++++++++++++----- 5 files changed, 109 insertions(+), 11 deletions(-) diff --git a/docs/src/guide/performance.md b/docs/src/guide/performance.md index a25e6a1584d..5590e3ff571 100644 --- a/docs/src/guide/performance.md +++ b/docs/src/guide/performance.md @@ -21,9 +21,14 @@ The Python/Java logger can be configured with several environment variables: ## Trace Events -Lance uses tracing to log events. If you are running `pylance` then these events will be emitted to +Lance uses tracing to log events. If you are running `pylance` then these events will be emitted as log messages. For Rust connections you can use the `tracing` crate to capture these events. +Rust tracing targets are listed below. 
In `pylance` logs, trace events are emitted under a +`lance::events::` prefix so they can be filtered separately from normal log records. For example, +`LANCE_LOG="warn,lance::events::object_store::throttle=info"` shows storage throttling events +without enabling other Lance event logs. + ### File Audit File audit events are emitted when significant files are created or deleted. @@ -33,6 +38,31 @@ File audit events are emitted when significant files are created or deleted. | `lance::file_audit` | `mode` | The mode of I/O operation (create, delete, delete_unverified) | | `lance::file_audit` | `type` | The type of file affected (manifest, data file, index file, deletion file) | +### Dataset Events + +Dataset events are emitted when datasets are loaded, written, committed, deleted, compacted, or cleaned. + +| Event | Parameter | Description | +| ----------------------- | ----------- | ------------------------------------------------------------------------- | +| `lance::dataset_events` | `event` | The dataset event type (loading, writing, committed, deleting, and others) | +| `lance::dataset_events` | `uri` | The dataset URI | +| `lance::dataset_events` | `mode` | The write mode | +| `lance::dataset_events` | `operation` | The committed transaction operation | +| `lance::dataset_events` | `predicate` | The delete predicate | +| `lance::dataset_events` | `columns` | The removed columns | + +### Object Store Throttle Events + +Object store throttle events are emitted when Lance observes cloud storage throttle responses and +reduces or retries request rates. 
+ +| Event | Parameter | Description | +| -------------------------------- | --------------- | ---------------------------------------- | +| `lance::object_store::throttle` | `previous_rate` | The request rate before AIMD adjustment | +| `lance::object_store::throttle` | `new_rate` | The request rate after AIMD adjustment | +| `lance::object_store::throttle` | `attempt` | The retry attempt for retry debug events | +| `lance::object_store::throttle` | `error` | The underlying object store throttle error | + ### I/O Events I/O events are emitted when significant I/O operations are performed, particularly diff --git a/python/python/tests/test_log.py b/python/python/tests/test_log.py index b00fe5813a0..1eb02957044 100644 --- a/python/python/tests/test_log.py +++ b/python/python/tests/test_log.py @@ -148,6 +148,36 @@ def test_lance_log_file_with_directory_creation(tmp_path): assert len(log_content.strip()) > 0, "Log file is empty" +@pytest.mark.skipif( + sys.platform == "win32", + reason="subprocess does not work correctly in CI on Windows", +) +def test_lance_log_filters_trace_event_targets(tmp_path): + log_file = tmp_path / "lance_rust.log" + + result = subprocess.run( + [ + sys.executable, + "-c", + "import lance; import pyarrow as pa; " + "lance.write_dataset(pa.table({'x': range(10)}), 'memory://test')", + ], + capture_output=True, + env={ + "LANCE_LOG": "warn,lance::events::dataset_events=info", + "LANCE_LOG_FILE": str(log_file), + }, + ) + + assert result.returncode == 0, f"Command failed: {result.stderr.decode()}" + + log_content = log_file.read_text() + assert "lance::events::dataset_events" in log_content + assert 'target="lance::dataset_events"' in log_content + assert "lance::events::file_audit" not in log_content + assert "lance::file_audit" not in log_content + + @pytest.mark.skipif( sys.platform == "win32", reason="subprocess does not work correctly in CI on Windows", diff --git a/python/src/tracing.rs b/python/src/tracing.rs index 
fdeb3c49220..3aadae60060 100644 --- a/python/src/tracing.rs +++ b/python/src/tracing.rs @@ -219,6 +219,11 @@ impl Visit for EventToMap { } } +fn trace_event_log_target(target: &str) -> String { + let target = target.strip_prefix("lance::").unwrap_or(target); + format!("lance::events::{target}") +} + #[derive(Clone)] pub struct LoggingPassthroughRef(Arc>>); @@ -248,7 +253,14 @@ impl tracing_subscriber::Layer for LoggingPassthroughRef { let mut fields = EventToStr::default(); event.record(&mut fields); - log::log!(target: "lance::events", state.level, "target=\"{}\" {}", event.metadata().target(), fields.str); + let log_target = trace_event_log_target(event.metadata().target()); + log::log!( + target: &log_target, + state.level, + "target=\"{}\" {}", + event.metadata().target(), + fields.str + ); if let Some(callback_sender) = state.callback_sender.as_ref() { let mut args = EventToMap::default(); diff --git a/rust/lance-core/src/utils/tracing.rs b/rust/lance-core/src/utils/tracing.rs index e126fade06d..603a666e313 100644 --- a/rust/lance-core/src/utils/tracing.rs +++ b/rust/lance-core/src/utils/tracing.rs @@ -84,3 +84,4 @@ pub const DATASET_DELETING_EVENT: &str = "deleting"; pub const DATASET_COMPACTING_EVENT: &str = "compacting"; pub const DATASET_CLEANING_EVENT: &str = "cleaning"; pub const DATASET_LOADING_EVENT: &str = "loading"; +pub const TRACE_OBJECT_STORE_THROTTLE: &str = "lance::object_store::throttle"; diff --git a/rust/lance-io/src/object_store/throttle.rs b/rust/lance-io/src/object_store/throttle.rs index cbd678aa5cf..de33c5cd0d3 100644 --- a/rust/lance-io/src/object_store/throttle.rs +++ b/rust/lance-io/src/object_store/throttle.rs @@ -30,6 +30,7 @@ use bytes::Bytes; use futures::StreamExt; use futures::stream::BoxStream; use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome}; +use lance_core::utils::tracing::TRACE_OBJECT_STORE_THROTTLE; #[cfg(test)] use object_store::ObjectStoreExt; use object_store::path::Path; @@ -383,17 +384,25 @@ 
impl OperationThrottle { let outcome = match result { Ok(_) => RequestOutcome::Success, Err(err) if is_throttle_error(err) => { - debug!("Throttle error detected in stream"); + debug!( + target: TRACE_OBJECT_STORE_THROTTLE, + error = %err, + "Throttle error detected in stream" + ); RequestOutcome::Throttled } Err(_) => RequestOutcome::Success, }; let prev_rate = self.controller.current_rate(); let new_rate = self.controller.record_outcome(outcome); - if new_rate < prev_rate { + if new_rate < prev_rate + && let Err(err) = result.as_ref() + { warn!( + target: TRACE_OBJECT_STORE_THROTTLE, previous_rate = format!("{prev_rate:.1}"), new_rate = format!("{new_rate:.1}"), + error = %err, "AIMD throttle: rate reduced due to throttle errors" ); } @@ -416,17 +425,25 @@ impl OperationThrottle { let outcome = match &result { Ok(_) => RequestOutcome::Success, Err(err) if is_throttle_error(err) => { - debug!("Throttle error detected"); + debug!( + target: TRACE_OBJECT_STORE_THROTTLE, + error = %err, + "Throttle error detected" + ); RequestOutcome::Throttled } Err(_) => RequestOutcome::Success, // Non-throttle errors don't indicate capacity problems }; let prev_rate = self.controller.current_rate(); let new_rate = self.controller.record_outcome(outcome); - if new_rate < prev_rate { + if new_rate < prev_rate + && let Err(err) = result.as_ref() + { warn!( + target: TRACE_OBJECT_STORE_THROTTLE, previous_rate = format!("{prev_rate:.1}"), new_rate = format!("{new_rate:.1}"), + error = %err, "AIMD throttle: rate reduced due to throttle errors" ); } @@ -437,9 +454,11 @@ impl OperationThrottle { let backoff_ms = rand::rng().random_range(self.min_backoff_ms..=self.max_backoff_ms); debug!( + target: TRACE_OBJECT_STORE_THROTTLE, attempt = attempt + 1, max_retries = self.max_retries, backoff_ms, + error = %err, "Retrying after throttle error" ); tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await; @@ -723,6 +742,8 @@ mod tests { use std::collections::VecDeque; use 
std::sync::atomic::{AtomicU64, Ordering}; + const THROTTLE_ERROR_RESPONSE: &str = "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s - Server returned non-2xx status code: 503: x-ms-request-id: azure-request-id"; + fn make_generic_error(msg: &str) -> object_store::Error { object_store::Error::Generic { store: "test", @@ -1130,8 +1151,7 @@ mod tests { fn throttle_error() -> object_store::Error { object_store::Error::Generic { store: "RateLimitingMock", - source: "request failed, after 10 retries, max_retries: 10, retry_timeout: 180s" - .into(), + source: THROTTLE_ERROR_RESPONSE.into(), } } } @@ -1375,8 +1395,7 @@ mod tests { if should_error { Err(object_store::Error::Generic { store: "RetryTestMock", - source: "request failed, after 3 retries, max_retries: 3, retry_timeout: 30s" - .into(), + source: THROTTLE_ERROR_RESPONSE.into(), }) } else { self.inner.get_opts(location, options).await @@ -1448,7 +1467,13 @@ mod tests { let result = throttled.get(&path).await; assert!(result.is_err(), "Expected error after max retries"); - assert!(is_throttle_error(&result.unwrap_err())); + let err = result.unwrap_err(); + assert!(is_throttle_error(&err)); + + let lance_error = lance_core::Error::from(err); + let error_message = lance_error.to_string(); + assert!(error_message.contains("x-ms-request-id")); + assert!(error_message.contains("azure-request-id")); // Should have called get 4 times: initial attempt + 3 retries assert_eq!(mock.get_call_count.load(Ordering::Relaxed), 4); From 65a0e4105c2dbe188cf49b8d9ce3c770ae272ac7 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 21:58:01 -0700 Subject: [PATCH 15/23] fix(mem_wal): exact PK dedup for LSM vector search (#6881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split from #6856 — vector-search portion. 
A primary key written multiple times into one memtable, or into both a memtable and an older generation, used to leak through as distinct rows: HNSW indexes every insert as its own graph node, so KNN could return both V1 and V2 of the same PK from a single source. Each per-source KNN now runs through `LsmSourceTagExec`, which appends `(_memtable_gen, _freshness)`. A single `LsmGlobalPkDedupExec` over the union keeps the row with the largest tuple per PK — newer generations win, ties fall to the normalized within-source order. This replaces the bloom-based `FilterStaleExec` design and is exact (no false-positive recall loss, no top-k under-fill). After global dedup + sort + top-k, a `TakeExec` materializes any user-projected columns not in the per-source KNN output by fetching from the base dataset via `_rowid`. `plan_search()` also accepts `refine_factor` so callers can enable base-table refine. Exposed in the Python and Java bindings. Removes `FilterStaleExec` and `GenerationBloomFilter`. Part of splitting #6856 into focused PRs. Co-authored with @jackye1995. 
Co-authored-by: Jack Ye --- java/lance-jni/src/mem_wal.rs | 24 +- .../lance/memwal/LsmVectorSearchPlanner.java | 31 +- python/python/lance/mem_wal.py | 10 +- python/src/mem_wal.rs | 10 +- rust/lance/benches/mem_wal_read.rs | 5 +- rust/lance/benches/mem_wal_vector.rs | 5 +- .../lance/src/dataset/mem_wal/scanner/exec.rs | 9 +- .../mem_wal/scanner/exec/filter_stale.rs | 590 ------------------ .../mem_wal/scanner/exec/global_pk_dedup.rs | 459 ++++++++++++++ .../mem_wal/scanner/exec/source_tag.rs | 404 ++++++++++++ .../src/dataset/mem_wal/scanner/projection.rs | 30 + .../dataset/mem_wal/scanner/vector_search.rs | 575 ++++++++++++----- 12 files changed, 1391 insertions(+), 761 deletions(-) delete mode 100644 rust/lance/src/dataset/mem_wal/scanner/exec/filter_stale.rs create mode 100644 rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs create mode 100644 rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs diff --git a/java/lance-jni/src/mem_wal.rs b/java/lance-jni/src/mem_wal.rs index 6457e5b1419..f62b0274c8e 100644 --- a/java/lance-jni/src/mem_wal.rs +++ b/java/lance-jni/src/mem_wal.rs @@ -827,14 +827,15 @@ fn inner_create_vector_planner( let dist_type = parse_distance_type(distance_type.as_deref().unwrap_or("l2"))?; let vector_dim = get_vector_dim(&dataset, &vector_column)?; - let collector = LsmDataSourceCollector::new(dataset, snapshots); + let collector = LsmDataSourceCollector::new(dataset.clone(), snapshots); let planner = LsmVectorSearchPlanner::new( collector, pk_columns, base_schema.clone(), vector_column, dist_type, - ); + ) + .with_dataset(dataset); let blocking = BlockingLsmVectorSearchPlanner { planner, @@ -854,10 +855,25 @@ pub extern "system" fn Java_org_lance_memwal_LsmVectorSearchPlanner_nativePlanSe k: jint, nprobes: jint, columns: JObject<'local>, + refine_factor: jint, ) -> JObject<'local> { + let refine = if refine_factor > 0 { + Some(refine_factor as u32) + } else { + None + }; ok_or_throw!( env, - inner_plan_search(&mut env, 
this, array_addr, schema_addr, k, nprobes, columns) + inner_plan_search( + &mut env, + this, + array_addr, + schema_addr, + k, + nprobes, + columns, + refine + ) ) } @@ -870,6 +886,7 @@ fn inner_plan_search<'local>( k: jint, nprobes: jint, columns: JObject<'local>, + refine_factor: Option, ) -> Result> { let query = import_ffi_array(array_addr, schema_addr)?; let columns = env.get_strings_opt(&columns)?; @@ -901,6 +918,7 @@ fn inner_plan_search<'local>( k as usize, nprobes as usize, columns.as_deref(), + refine_factor, ))?; (plan, guard.dataset_schema.clone()) }; diff --git a/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java b/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java index b4567db3acd..1fe0e8bd020 100644 --- a/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java +++ b/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java @@ -93,9 +93,12 @@ private native void nativeCreate( * @param nprobes number of IVF partitions to probe * @param columns columns to project; pass {@code null} to return all columns plus {@code * _distance} + * @param refineFactor when positive, the base-table arm over-fetches {@code k * refineFactor} + * candidates and re-ranks them with exact distances. Pass {@code 0} or negative to disable. * @return an executable plan */ - public ExecutionPlan planSearch(Float4Vector query, int k, int nprobes, List columns) { + public ExecutionPlan planSearch( + Float4Vector query, int k, int nprobes, List columns, int refineFactor) { Preconditions.checkNotNull(query, "query must not be null"); Preconditions.checkArgument(k > 0, "k must be positive, got %s", k); Preconditions.checkArgument(nprobes > 0, "nprobes must be positive, got %s", nprobes); @@ -111,20 +114,40 @@ public ExecutionPlan planSearch(Float4Vector query, int k, int nprobes, List columns) { + return planSearch(query, k, nprobes, columns, 0); + } + /** Plan a KNN vector search with default {@code nprobes} of 20. 
*/ public ExecutionPlan planSearch(Float4Vector query, int k) { - return planSearch(query, k, 20, null); + return planSearch(query, k, 20, null, 0); } private native ExecutionPlan nativePlanSearch( - long arrayAddress, long schemaAddress, int k, int nprobes, Optional> columns); + long arrayAddress, + long schemaAddress, + int k, + int nprobes, + Optional> columns, + int refineFactor); /** * Close the planner and release native resources. If the planner is already closed, invoking this diff --git a/python/python/lance/mem_wal.py b/python/python/lance/mem_wal.py index 2ca293d790d..d2ccd463775 100644 --- a/python/python/lance/mem_wal.py +++ b/python/python/lance/mem_wal.py @@ -525,6 +525,7 @@ def plan_search( k: int = 10, nprobes: int = 20, columns: Optional[List[str]] = None, + refine_factor: Optional[int] = None, ) -> ExecutionPlan: """Plan a KNN vector search. @@ -539,6 +540,11 @@ def plan_search( columns : list of str, optional Columns to project. Returns all columns + ``_distance`` if omitted. + refine_factor : int, optional + When set, the base-table arm fetches ``k * refine_factor`` + candidates and re-ranks with exact distances. Useful when + the base index is approximate (e.g. IVF-PQ). Memtable arms + use exact HNSW and are unaffected. Returns ------- @@ -546,7 +552,9 @@ def plan_search( Physical plan for the vector search. Execute it via `to_table`, `to_reader`, or `to_batches`. 
""" - return ExecutionPlan(self._raw.plan_search(query, k, nprobes, columns)) + return ExecutionPlan( + self._raw.plan_search(query, k, nprobes, columns, refine_factor) + ) def _unwrap_shard_id(shard_id: str) -> str: diff --git a/python/src/mem_wal.rs b/python/src/mem_wal.rs index b9aca4e15a1..5a8516045ab 100644 --- a/python/src/mem_wal.rs +++ b/python/src/mem_wal.rs @@ -727,14 +727,15 @@ impl PyLsmVectorSearchPlanner { let vector_dim = get_vector_dim(&ds, &vector_column)?; - let collector = LsmDataSourceCollector::new(ds, snapshots); + let collector = LsmDataSourceCollector::new(ds.clone(), snapshots); let planner = LsmVectorSearchPlanner::new( collector, pk_cols, base_schema.clone(), vector_column, dist_type, - ); + ) + .with_dataset(ds); Ok(Self { planner, @@ -746,7 +747,7 @@ impl PyLsmVectorSearchPlanner { /// Plan a KNN vector search. /// /// `query` should be a flat PyArrow Float32Array with `vector_dim` elements. - #[pyo3(signature = (query, k=10, nprobes=20, columns=None))] + #[pyo3(signature = (query, k=10, nprobes=20, columns=None, refine_factor=None))] pub fn plan_search( &self, py: Python<'_>, @@ -754,6 +755,7 @@ impl PyLsmVectorSearchPlanner { k: usize, nprobes: usize, columns: Option>, + refine_factor: Option, ) -> PyResult { let query_array = make_array(query.0); let float32_array = query_array @@ -787,7 +789,7 @@ impl PyLsmVectorSearchPlanner { let plan = rt() .block_on(Some(py), async { planner_ref - .plan_search(&fsl, k, nprobes, columns.as_deref()) + .plan_search(&fsl, k, nprobes, columns.as_deref(), refine_factor) .await })? 
.map_err(|e| PyIOError::new_err(e.to_string()))?; diff --git a/rust/lance/benches/mem_wal_read.rs b/rust/lance/benches/mem_wal_read.rs index 31c2c4581ef..f00bcae7539 100644 --- a/rust/lance/benches/mem_wal_read.rs +++ b/rust/lance/benches/mem_wal_read.rs @@ -1005,7 +1005,10 @@ fn bench_vector_search(c: &mut Criterion) { "vector".to_string(), DistanceType::L2, ); - let plan = planner.plan_search(&query, k, nprobes, None).await.unwrap(); + let plan = planner + .plan_search(&query, k, nprobes, None, None) + .await + .unwrap(); let session_ctx = SessionContext::new(); let stream = plan.execute(0, session_ctx.task_ctx()).unwrap(); let batches: Vec = stream.try_collect().await.unwrap(); diff --git a/rust/lance/benches/mem_wal_vector.rs b/rust/lance/benches/mem_wal_vector.rs index 2ed94d5b9b5..63081801124 100644 --- a/rust/lance/benches/mem_wal_vector.rs +++ b/rust/lance/benches/mem_wal_vector.rs @@ -434,7 +434,10 @@ fn bench_vector_search(c: &mut Criterion) { "vector".to_string(), DistanceType::L2, ); - let plan = planner.plan_search(&query, k, nprobes, None).await.unwrap(); + let plan = planner + .plan_search(&query, k, nprobes, None, None) + .await + .unwrap(); let session_ctx = SessionContext::new(); let stream = plan.execute(0, session_ctx.task_ctx()).unwrap(); let batches: Vec = stream.try_collect().await.unwrap(); diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec.rs b/rust/lance/src/dataset/mem_wal/scanner/exec.rs index 705deaee631..c2ed01bb1f2 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec.rs @@ -10,16 +10,19 @@ //! - [`DeduplicateExec`]: Deduplicates by primary key, keeping newest version //! - [`BloomFilterGuardExec`]: Guards child execution with bloom filter check //! - [`CoalesceFirstExec`]: Returns first non-empty result with short-circuit -//! - [`FilterStaleExec`]: Filters out rows with newer versions in higher generations +//! 
- [`LsmSourceTagExec`]: Tags rows with `_memtable_gen` + `_freshness` for the vector-search global dedup +//! - [`LsmGlobalPkDedupExec`]: Single-pass cross-source PK dedup over the merged vector-search stream mod bloom_guard; mod coalesce_first; mod deduplicate; -mod filter_stale; mod generation_tag; +mod global_pk_dedup; +mod source_tag; pub use bloom_guard::{BloomFilterGuardExec, compute_pk_hash_from_scalars}; pub use coalesce_first::CoalesceFirstExec; pub use deduplicate::{DeduplicateExec, ROW_ADDRESS_COLUMN}; -pub use filter_stale::{FilterStaleExec, GenerationBloomFilter}; pub use generation_tag::{MEMTABLE_GEN_COLUMN, MemtableGenTagExec}; +pub use global_pk_dedup::LsmGlobalPkDedupExec; +pub use source_tag::{FRESHNESS_COLUMN, FreshnessPolarity, LsmSourceTagExec}; diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/filter_stale.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/filter_stale.rs deleted file mode 100644 index b6e16cd869c..00000000000 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/filter_stale.rs +++ /dev/null @@ -1,590 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! FilterStaleExec - Filters out rows that have newer versions in higher generations. -//! -//! Used in vector search and FTS queries to detect stale results across LSM levels. 
- -use std::any::Any; -use std::fmt; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow_array::{Array, RecordBatch, UInt64Array}; -use arrow_schema::SchemaRef; -use datafusion::error::Result as DFResult; -use datafusion::execution::TaskContext; -use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, - SendableRecordBatchStream, -}; -use futures::{Stream, StreamExt}; -use lance_index::scalar::bloomfilter::sbbf::Sbbf; - -use super::generation_tag::MEMTABLE_GEN_COLUMN; - -/// Bloom filter for a specific generation. -#[derive(Clone)] -pub struct GenerationBloomFilter { - /// Generation number (0 = base table, 1+ = memtables). - pub generation: u64, - /// The bloom filter. - pub bloom_filter: Arc, -} - -impl std::fmt::Debug for GenerationBloomFilter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("GenerationBloomFilter") - .field("generation", &self.generation) - .field( - "bloom_filter_size", - &self.bloom_filter.estimated_memory_size(), - ) - .finish() - } -} - -/// Filters out rows that have a newer version in a higher generation. -/// -/// For each candidate row with primary key `pk` from generation G, this node -/// checks bloom filters of all generations > G. If the bloom filter indicates -/// the key may exist in a newer generation, the candidate is filtered out. -/// -/// # Bloom Filter Behavior -/// -/// - False negatives: impossible (if key is in bloom filter, `check_hash` returns true) -/// - False positives: possible (may filter valid results that don't actually have newer versions) -/// -/// This is acceptable for approximate search workloads (vector, FTS) where some -/// loss of recall is tolerable. The false positive rate is typically < 0.1%. 
-/// -/// # Required Columns -/// -/// The input must have: -/// - `_memtable_gen` (UInt64): Generation number for each row -/// - Primary key columns: Used for bloom filter hash computation -/// -/// # Performance -/// -/// - O(G) bloom filter checks per row, where G = number of newer generations -/// - Bloom filter checks are O(1) -/// - Overall: O(N * G) where N = input rows -#[derive(Debug)] -pub struct FilterStaleExec { - /// Child execution plan. - input: Arc, - /// Primary key column names (for hash computation). - pk_columns: Vec, - /// Bloom filters for each generation, sorted by generation DESC. - bloom_filters: Vec, - /// Output schema. - schema: SchemaRef, - /// Plan properties. - properties: Arc, -} - -impl FilterStaleExec { - /// Create a new FilterStaleExec. - /// - /// # Arguments - /// - /// * `input` - Child plan producing rows with `_memtable_gen` column - /// * `pk_columns` - Primary key column names for bloom filter hash - /// * `bloom_filters` - Bloom filters for each generation (will be sorted by gen DESC) - pub fn new( - input: Arc, - pk_columns: Vec, - bloom_filters: Vec, - ) -> Self { - let schema = input.schema(); - - // Sort bloom filters by generation DESC for efficient lookup - let mut bloom_filters = bloom_filters; - bloom_filters.sort_by(|a, b| b.generation.cmp(&a.generation)); - - let properties = Arc::new(PlanProperties::new( - EquivalenceProperties::new(schema.clone()), - Partitioning::UnknownPartitioning(1), - input.pipeline_behavior(), - input.boundedness(), - )); - - Self { - input, - pk_columns, - bloom_filters, - schema, - properties, - } - } - - /// Get the primary key columns. - pub fn pk_columns(&self) -> &[String] { - &self.pk_columns - } - - /// Get the bloom filters. 
- pub fn bloom_filters(&self) -> &[GenerationBloomFilter] { - &self.bloom_filters - } -} - -impl DisplayAs for FilterStaleExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default - | DisplayFormatType::Verbose - | DisplayFormatType::TreeRender => { - let gens: Vec = self - .bloom_filters - .iter() - .map(|bf| bf.generation.to_string()) - .collect(); - write!( - f, - "FilterStaleExec: pk=[{}], generations=[{}]", - self.pk_columns.join(", "), - gens.join(", ") - ) - } - } - } -} - -impl ExecutionPlan for FilterStaleExec { - fn name(&self) -> &str { - "FilterStaleExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn properties(&self) -> &Arc { - &self.properties - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> DFResult> { - if children.len() != 1 { - return Err(datafusion::error::DataFusionError::Internal( - "FilterStaleExec requires exactly one child".to_string(), - )); - } - Ok(Arc::new(Self::new( - children[0].clone(), - self.pk_columns.clone(), - self.bloom_filters.clone(), - ))) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> DFResult { - let input_stream = self.input.execute(partition, context)?; - - Ok(Box::pin(FilterStaleStream::new( - input_stream, - self.pk_columns.clone(), - self.bloom_filters.clone(), - self.schema.clone(), - ))) - } -} - -/// Stream that filters out stale rows. -struct FilterStaleStream { - /// Input stream. - input: SendableRecordBatchStream, - /// Primary key column names. - pk_columns: Vec, - /// Bloom filters sorted by generation DESC. - bloom_filters: Vec, - /// Output schema. 
- schema: SchemaRef, -} - -impl FilterStaleStream { - fn new( - input: SendableRecordBatchStream, - pk_columns: Vec, - bloom_filters: Vec, - schema: SchemaRef, - ) -> Self { - Self { - input, - pk_columns, - bloom_filters, - schema, - } - } - - /// Check if a row is stale (has newer version in higher generation). - fn is_stale(&self, pk_hash: u64, row_generation: u64) -> bool { - for bf in &self.bloom_filters { - // Bloom filters are sorted DESC, so we can stop early - if bf.generation <= row_generation { - break; - } - if bf.bloom_filter.check_hash(pk_hash) { - return true; - } - } - false - } - - /// Process a batch and filter out stale rows. - fn filter_batch(&self, batch: RecordBatch) -> DFResult { - if batch.num_rows() == 0 { - return Ok(batch); - } - - let gen_col = batch.column_by_name(MEMTABLE_GEN_COLUMN).ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Column '{}' not found in batch", - MEMTABLE_GEN_COLUMN - )) - })?; - let gen_array = gen_col - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Column '{}' is not UInt64", - MEMTABLE_GEN_COLUMN - )) - })?; - - let pk_indices: Vec = self - .pk_columns - .iter() - .map(|col| { - batch - .schema() - .column_with_name(col) - .map(|(idx, _)| idx) - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Primary key column '{}' not found", - col - )) - }) - }) - .collect::>>()?; - - let mut keep_indices: Vec = Vec::new(); - - for row_idx in 0..batch.num_rows() { - let row_generation = gen_array.value(row_idx); - let pk_hash = compute_pk_hash(&batch, &pk_indices, row_idx); - - if !self.is_stale(pk_hash, row_generation) { - keep_indices.push(row_idx as u32); - } - } - - if keep_indices.len() == batch.num_rows() { - return Ok(batch); - } - - if keep_indices.is_empty() { - return Ok(RecordBatch::new_empty(self.schema.clone())); - } - - let indices = arrow_array::UInt32Array::from(keep_indices); - let columns: Vec> 
= batch - .columns() - .iter() - .map(|col| arrow_select::take::take(col.as_ref(), &indices, None)) - .collect::, _>>() - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?; - - RecordBatch::try_new(self.schema.clone(), columns) - .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) - } -} - -/// Compute hash for a row's primary key. -fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - - for &col_idx in pk_indices { - let col = batch.column(col_idx); - let is_null = col.is_null(row_idx); - is_null.hash(&mut hasher); - - if !is_null { - if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } - // Add more types as needed - } - } - - hasher.finish() -} - -impl Stream for FilterStaleStream { - type Item = DFResult; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.input.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - let filtered = self.filter_batch(batch); - Poll::Ready(Some(filtered)) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl datafusion::physical_plan::RecordBatchStream for FilterStaleStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - 
-#[cfg(test)] -mod tests { - use super::*; - use arrow_array::{Float32Array, Int32Array, StringArray}; - use arrow_schema::{DataType, Field, Schema}; - use datafusion::prelude::SessionContext; - use datafusion_physical_plan::test::TestMemoryExec; - use futures::TryStreamExt; - - fn create_test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - Field::new("_distance", DataType::Float32, true), - Field::new(MEMTABLE_GEN_COLUMN, DataType::UInt64, false), - ])) - } - - fn create_test_batch(schema: &Schema, ids: &[i32], generation: u64) -> RecordBatch { - let names: Vec = ids.iter().map(|id| format!("name_{}", id)).collect(); - let distances: Vec = ids.iter().map(|id| *id as f32 * 0.1).collect(); - let gens: Vec = vec![generation; ids.len()]; - - RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(ids.to_vec())), - Arc::new(StringArray::from(names)), - Arc::new(Float32Array::from(distances)), - Arc::new(UInt64Array::from(gens)), - ], - ) - .unwrap() - } - - fn create_bloom_filter_with_keys(ids: &[i32]) -> Arc { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut bf = Sbbf::with_ndv_fpp(100, 0.01).unwrap(); - for id in ids { - let mut hasher = DefaultHasher::new(); - false.hash(&mut hasher); // is_null = false - id.hash(&mut hasher); - let hash = hasher.finish(); - bf.insert_hash(hash); - } - Arc::new(bf) - } - - #[tokio::test] - async fn test_filter_stale_removes_rows_with_newer_versions() { - let schema = create_test_schema(); - - // Batch with rows from gen1: ids 1, 2, 3 - let batch = create_test_batch(&schema, &[1, 2, 3], 1); - - // Bloom filter for gen2 contains id=2 - let bf_gen2 = GenerationBloomFilter { - generation: 2, - bloom_filter: create_bloom_filter_with_keys(&[2]), - }; - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema.clone(), None).unwrap(); - let filter = 
FilterStaleExec::new(input, vec!["id".to_string()], vec![bf_gen2]); - - let ctx = SessionContext::new(); - let stream = filter.execute(0, ctx.task_ctx()).unwrap(); - let batches: Vec = stream.try_collect().await.unwrap(); - - // id=2 should be filtered (stale - exists in gen2) - // id=1 and id=3 should remain - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, 2); - - let ids: Vec = batches - .iter() - .flat_map(|b| { - b.column(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec() - }) - .collect(); - assert!(ids.contains(&1)); - assert!(!ids.contains(&2)); // filtered - assert!(ids.contains(&3)); - } - - #[tokio::test] - async fn test_filter_stale_respects_generation_order() { - let schema = create_test_schema(); - - // Batch from gen2 with ids 1, 2 - let batch = create_test_batch(&schema, &[1, 2], 2); - - // Bloom filter for gen1 (older) contains id=1 - // This should NOT filter id=1 because gen1 < gen2 - let bf_gen1 = GenerationBloomFilter { - generation: 1, - bloom_filter: create_bloom_filter_with_keys(&[1]), - }; - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema.clone(), None).unwrap(); - let filter = FilterStaleExec::new(input, vec!["id".to_string()], vec![bf_gen1]); - - let ctx = SessionContext::new(); - let stream = filter.execute(0, ctx.task_ctx()).unwrap(); - let batches: Vec = stream.try_collect().await.unwrap(); - - // No rows should be filtered - gen1 bloom filter is for older gen - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, 2); - } - - #[tokio::test] - async fn test_filter_stale_multiple_bloom_filters() { - let schema = create_test_schema(); - - // Batch from gen1 with ids 1, 2, 3, 4 - let batch = create_test_batch(&schema, &[1, 2, 3, 4], 1); - - // gen2 contains id=2, gen3 contains id=4 - let bf_gen2 = GenerationBloomFilter { - generation: 2, - bloom_filter: create_bloom_filter_with_keys(&[2]), - }; - let bf_gen3 = 
GenerationBloomFilter { - generation: 3, - bloom_filter: create_bloom_filter_with_keys(&[4]), - }; - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema.clone(), None).unwrap(); - let filter = FilterStaleExec::new(input, vec!["id".to_string()], vec![bf_gen2, bf_gen3]); - - let ctx = SessionContext::new(); - let stream = filter.execute(0, ctx.task_ctx()).unwrap(); - let batches: Vec = stream.try_collect().await.unwrap(); - - // id=2 and id=4 should be filtered - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, 2); - - let ids: Vec = batches - .iter() - .flat_map(|b| { - b.column(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec() - }) - .collect(); - assert!(ids.contains(&1)); - assert!(ids.contains(&3)); - } - - #[tokio::test] - async fn test_filter_stale_no_bloom_filters() { - let schema = create_test_schema(); - let batch = create_test_batch(&schema, &[1, 2, 3], 1); - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema.clone(), None).unwrap(); - let filter = FilterStaleExec::new(input, vec!["id".to_string()], vec![]); - - let ctx = SessionContext::new(); - let stream = filter.execute(0, ctx.task_ctx()).unwrap(); - let batches: Vec = stream.try_collect().await.unwrap(); - - // No bloom filters = nothing filtered - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, 3); - } - - #[tokio::test] - async fn test_filter_stale_empty_batch() { - let schema = create_test_schema(); - let batch = RecordBatch::new_empty(schema.clone()); - - let bf = GenerationBloomFilter { - generation: 2, - bloom_filter: create_bloom_filter_with_keys(&[1]), - }; - - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema.clone(), None).unwrap(); - let filter = FilterStaleExec::new(input, vec!["id".to_string()], vec![bf]); - - let ctx = SessionContext::new(); - let stream = filter.execute(0, ctx.task_ctx()).unwrap(); - let batches: Vec = 
stream.try_collect().await.unwrap(); - - assert_eq!(batches.len(), 1); - assert_eq!(batches[0].num_rows(), 0); - } - - #[test] - fn test_display() { - let schema = create_test_schema(); - let batch = RecordBatch::new_empty(schema.clone()); - let input = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); - - let bf = GenerationBloomFilter { - generation: 2, - bloom_filter: create_bloom_filter_with_keys(&[1]), - }; - - let filter = FilterStaleExec::new(input, vec!["id".to_string()], vec![bf]); - - // Verify it doesn't panic - let _ = format!("{:?}", filter); - } -} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs new file mode 100644 index 00000000000..f35de7771b0 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs @@ -0,0 +1,459 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Global, exact primary-key deduplication for the LSM vector-search +//! pipeline. +//! +//! Replaces the older two-step `WithinSourceDedupExec` + `FilterStaleExec` +//! design with a single streaming hash-by-PK pass over the merged stream. +//! For each PK the row with the largest `(generation, freshness)` tuple +//! wins — generation is the source identity (base = 0, memtable gens 1..N, +//! active = N+1) and freshness is the per-source row order normalized so +//! that "larger = newer" (see [`super::LsmSourceTagExec`]). +//! +//! Compared with the bloom-based staleness filter this is: +//! +//! - Exact (no false-positive recall loss, no top-k under-fill, no +//! missing-bloom footgun). +//! - One node instead of two (no separate per-source dedup wrap). +//! - O(unique PKs in the merged stream) state — typically far smaller +//! than the n_sources · k upper bound because most PKs collide across +//! sources for typical LSM update workloads. 
+ +use std::any::Any; +use std::collections::HashMap; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow_array::{Array, RecordBatch, UInt64Array}; +use arrow_schema::SchemaRef; +use datafusion::error::Result as DFResult; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, +}; +use futures::{Stream, StreamExt, ready}; + +/// Cross-source PK dedup. Keeps one row per primary key — the one with +/// the largest `(generation, freshness)` tuple. +/// +/// # Required input columns +/// +/// - `pk_columns` — the primary key columns. +/// - `generation_column` (UInt64, NOT NULL) — typically +/// [`super::MEMTABLE_GEN_COLUMN`]. +/// - `freshness_column` (UInt64, nullable) — typically +/// [`super::FRESHNESS_COLUMN`]. NULL-freshness rows are skipped (they +/// can't be ordered against real values). +/// +/// The output schema is unchanged from the input. Callers that need to +/// drop the generation / freshness columns from the final output should +/// compose this node with a downstream `project_to_canonical`. 
+#[derive(Debug)] +pub struct LsmGlobalPkDedupExec { + input: Arc, + pk_columns: Vec, + generation_column: String, + freshness_column: String, + schema: SchemaRef, + properties: Arc, +} + +impl LsmGlobalPkDedupExec { + pub fn new( + input: Arc, + pk_columns: Vec, + generation_column: impl Into, + freshness_column: impl Into, + ) -> Self { + let schema = input.schema(); + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + input.pipeline_behavior(), + input.boundedness(), + )); + Self { + input, + pk_columns, + generation_column: generation_column.into(), + freshness_column: freshness_column.into(), + schema, + properties, + } + } + + pub fn pk_columns(&self) -> &[String] { + &self.pk_columns + } + + pub fn generation_column(&self) -> &str { + &self.generation_column + } + + pub fn freshness_column(&self) -> &str { + &self.freshness_column + } +} + +impl DisplayAs for LsmGlobalPkDedupExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + write!( + f, + "LsmGlobalPkDedupExec: pk=[{}], gen={}, freshness={}", + self.pk_columns.join(", "), + self.generation_column, + self.freshness_column, + ) + } + } + } +} + +impl ExecutionPlan for LsmGlobalPkDedupExec { + fn name(&self) -> &str { + "LsmGlobalPkDedupExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + if children.len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "LsmGlobalPkDedupExec requires exactly one child".to_string(), + )); + } + Ok(Arc::new(Self::new( + children[0].clone(), + self.pk_columns.clone(), + 
self.generation_column.clone(), + self.freshness_column.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let input_stream = self.input.execute(partition, context)?; + Ok(Box::pin(GlobalPkDedupStream { + input: input_stream, + pk_columns: self.pk_columns.clone(), + generation_column: self.generation_column.clone(), + freshness_column: self.freshness_column.clone(), + schema: self.schema.clone(), + winners: HashMap::new(), + emitted: false, + })) + } +} + +struct Winner { + batch: RecordBatch, + generation: u64, + freshness: u64, +} + +struct GlobalPkDedupStream { + input: SendableRecordBatchStream, + pk_columns: Vec, + generation_column: String, + freshness_column: String, + schema: SchemaRef, + winners: HashMap, + emitted: bool, +} + +impl GlobalPkDedupStream { + fn consume_batch(&mut self, batch: RecordBatch) -> DFResult<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + let pk_indices = resolve_pk_indices(&batch, &self.pk_columns)?; + let gen_arr = batch + .column_by_name(&self.generation_column) + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Generation column '{}' not found in batch", + self.generation_column + )) + })? + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Generation column '{}' is not UInt64", + self.generation_column + )) + })?; + let fresh_arr = batch + .column_by_name(&self.freshness_column) + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Freshness column '{}' not found in batch", + self.freshness_column + )) + })? + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Freshness column '{}' is not UInt64", + self.freshness_column + )) + })?; + + for row_idx in 0..batch.num_rows() { + if fresh_arr.is_null(row_idx) { + // A NULL freshness can't be ordered against a real value; + // skip rather than guess. 
Callers tag with a real value + // for every row eligible to win. + continue; + } + let generation = gen_arr.value(row_idx); + let fresh = fresh_arr.value(row_idx); + let pk_hash = compute_pk_hash(&batch, &pk_indices, row_idx); + + let take_row = match self.winners.get(&pk_hash) { + None => true, + Some(existing) => (generation, fresh) > (existing.generation, existing.freshness), + }; + + if take_row { + let single = batch.slice(row_idx, 1); + self.winners.insert( + pk_hash, + Winner { + batch: single, + generation, + freshness: fresh, + }, + ); + } + } + Ok(()) + } + + fn finalize(&mut self) -> DFResult { + if self.winners.is_empty() { + return Ok(RecordBatch::new_empty(self.schema.clone())); + } + let batches: Vec = self.winners.drain().map(|(_, w)| w.batch).collect(); + let batch_refs: Vec<&RecordBatch> = batches.iter().collect(); + arrow_select::concat::concat_batches(&self.schema, batch_refs) + .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) + } +} + +fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResult> { + pk_columns + .iter() + .map(|col| { + batch + .schema() + .column_with_name(col) + .map(|(idx, _)| idx) + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Primary key column '{}' not found", + col + )) + }) + }) + .collect() +} + +/// Hash a row's primary key. Mirrors the variants supported by +/// [`super::WithinSourceDedupExec`] / `BloomFilterGuardExec`, so a single +/// PK produces the same hash everywhere in the LSM scanner. 
+fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + for &col_idx in pk_indices { + let col = batch.column(col_idx); + let is_null = col.is_null(row_idx); + is_null.hash(&mut hasher); + + if !is_null { + if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } + } + } + hasher.finish() +} + +impl Stream for GlobalPkDedupStream { + type Item = DFResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.emitted { + return Poll::Ready(None); + } + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if let Err(e) = self.consume_batch(batch) { + self.emitted = true; + return Poll::Ready(Some(Err(e))); + } + } + Some(Err(e)) => { + self.emitted = true; + return Poll::Ready(Some(Err(e))); + } + None => { + self.emitted = true; + return Poll::Ready(Some(self.finalize())); + } + } + } + } +} + +impl datafusion::physical_plan::RecordBatchStream for GlobalPkDedupStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Int32Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::prelude::SessionContext; + use datafusion_physical_plan::test::TestMemoryExec; + use futures::TryStreamExt; + + fn test_schema() -> SchemaRef { + 
Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("_memtable_gen", DataType::UInt64, false), + Field::new("_freshness", DataType::UInt64, true), + ])) + } + + fn batch(ids: &[i32], gens: &[u64], fresh: &[Option]) -> RecordBatch { + RecordBatch::try_new( + test_schema(), + vec![ + Arc::new(Int32Array::from(ids.to_vec())), + Arc::new(UInt64Array::from(gens.to_vec())), + Arc::new(UInt64Array::from(fresh.to_vec())), + ], + ) + .unwrap() + } + + async fn run(batches: Vec) -> Vec { + let schema = test_schema(); + let input = TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap(); + let exec = + LsmGlobalPkDedupExec::new(input, vec!["id".to_string()], "_memtable_gen", "_freshness"); + let ctx = SessionContext::new(); + let stream = exec.execute(0, ctx.task_ctx()).unwrap(); + stream.try_collect().await.unwrap() + } + + fn extract(batches: &[RecordBatch]) -> Vec<(i32, u64, Option)> { + let mut rows = Vec::new(); + for b in batches { + let ids = b.column(0).as_any().downcast_ref::().unwrap(); + let gens = b.column(1).as_any().downcast_ref::().unwrap(); + let fresh = b.column(2).as_any().downcast_ref::().unwrap(); + for i in 0..b.num_rows() { + rows.push(( + ids.value(i), + gens.value(i), + if fresh.is_null(i) { + None + } else { + Some(fresh.value(i)) + }, + )); + } + } + rows.sort_by_key(|r| r.0); + rows + } + + #[tokio::test] + async fn keeps_higher_freshness_within_single_generation() { + let b = batch(&[1, 1, 2], &[3, 3, 3], &[Some(10), Some(99), Some(5)]); + let rows = extract(&run(vec![b]).await); + assert_eq!(rows, vec![(1, 3, Some(99)), (2, 3, Some(5))]); + } + + #[tokio::test] + async fn higher_generation_beats_higher_freshness() { + let b = batch(&[1, 1, 2], &[1, 2, 2], &[Some(u64::MAX), Some(0), Some(5)]); + // id=1 in gen=2 with freshness 0 wins over gen=1 with freshness MAX. 
+ let rows = extract(&run(vec![b]).await); + assert_eq!(rows, vec![(1, 2, Some(0)), (2, 2, Some(5))]); + } + + #[tokio::test] + async fn dedup_across_batches() { + let b1 = batch(&[1, 2], &[1, 2], &[Some(5), Some(5)]); + let b2 = batch(&[1, 3], &[3, 1], &[Some(0), Some(1)]); + // id=1: gen=3 wins. id=2: only gen=2 row. id=3: only gen=1 row. + let rows = extract(&run(vec![b1, b2]).await); + assert_eq!( + rows, + vec![(1, 3, Some(0)), (2, 2, Some(5)), (3, 1, Some(1))], + ); + } + + #[tokio::test] + async fn null_freshness_skipped() { + let b = batch(&[1, 1], &[5, 5], &[None, Some(0)]); + // The null-freshness row is dropped; the real one wins by default. + let rows = extract(&run(vec![b]).await); + assert_eq!(rows, vec![(1, 5, Some(0))]); + } + + #[tokio::test] + async fn empty_input() { + let total: usize = run(vec![]).await.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 0); + } +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs new file mode 100644 index 00000000000..29eac385381 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/source_tag.rs @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Per-source tagging for the LSM vector-search dedup pipeline. +//! +//! `LsmSourceTagExec` appends two columns to each row of a per-source scan: +//! - `_memtable_gen` (UInt64): the source's generation number (base = 0, +//! flushed gens 1..N, active memtable = N+1). +//! - `_freshness` (UInt64): a within-source "newness" indicator normalized +//! so that *larger value = newer insert* regardless of which side +//! produced it. The active memtable stores rows in insert order +//! (`_freshness = _rowid`), while flushed memtables are reverse-written +//! (`_freshness = u64::MAX - _rowid`). +//! +//! Together, the two columns let [`super::LsmGlobalPkDedupExec`] decide a +//! 
winner per primary key via a single lexicographic `(gen, freshness)` +//! comparison across the merged stream — no separate within-source dedup +//! and no bloom-based staleness filtering needed. + +use std::any::Any; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow_array::{Array, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::error::Result as DFResult; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, +}; +use futures::{Stream, StreamExt}; + +use crate::dataset::mem_wal::scanner::data_source::LsmGeneration; + +use super::generation_tag::MEMTABLE_GEN_COLUMN; + +/// Column name for the normalized within-source freshness. Higher = newer. +pub const FRESHNESS_COLUMN: &str = "_freshness"; + +/// Polarity for translating a source's row-id column into `_freshness`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FreshnessPolarity { + /// `_freshness = row_id`. Used by sources that store rows in insert + /// order (active memtable; also base table where duplicates aren't + /// expected but the polarity must still be consistent). + InsertOrder, + /// `_freshness = u64::MAX - row_id`. Used by flushed memtables, which + /// are reverse-written so a smaller `_rowid` is the newer insert. + ReverseWrite, +} + +/// Tag every row of a per-source scan with `_memtable_gen` + `_freshness`. +/// +/// # Required input columns +/// +/// - `row_id_column` (UInt64) — typically `_rowid`. Must be present on +/// every row; NULLs are propagated as NULL `_freshness` and will be +/// skipped by the downstream dedup. +/// +/// # Output schema +/// +/// Input schema + `_memtable_gen` (UInt64, NOT NULL) + `_freshness` +/// (UInt64, nullable to mirror the source's `_rowid` nullability). 
+#[derive(Debug)] +pub struct LsmSourceTagExec { + input: Arc, + generation: LsmGeneration, + polarity: FreshnessPolarity, + row_id_column: String, + schema: SchemaRef, + properties: Arc, +} + +impl LsmSourceTagExec { + pub fn new( + input: Arc, + generation: LsmGeneration, + polarity: FreshnessPolarity, + row_id_column: impl Into, + ) -> Self { + let input_schema = input.schema(); + let mut fields: Vec> = input_schema.fields().iter().cloned().collect(); + fields.push(Arc::new(Field::new( + MEMTABLE_GEN_COLUMN, + DataType::UInt64, + false, + ))); + fields.push(Arc::new(Field::new( + FRESHNESS_COLUMN, + DataType::UInt64, + true, + ))); + let schema = Arc::new(Schema::new(fields)); + + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + )); + + Self { + input, + generation, + polarity, + row_id_column: row_id_column.into(), + schema, + properties, + } + } + + pub fn generation(&self) -> LsmGeneration { + self.generation + } + + pub fn polarity(&self) -> FreshnessPolarity { + self.polarity + } + + pub fn row_id_column(&self) -> &str { + &self.row_id_column + } +} + +impl DisplayAs for LsmSourceTagExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + write!( + f, + "LsmSourceTagExec: gen={}, polarity={:?}, row_id_col={}", + self.generation, self.polarity, self.row_id_column, + ) + } + } + } +} + +impl ExecutionPlan for LsmSourceTagExec { + fn name(&self) -> &str { + "LsmSourceTagExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + if 
children.len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "LsmSourceTagExec requires exactly one child".to_string(), + )); + } + Ok(Arc::new(Self::new( + children[0].clone(), + self.generation, + self.polarity, + self.row_id_column.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let input_stream = self.input.execute(partition, context)?; + Ok(Box::pin(SourceTagStream { + input: input_stream, + generation: self.generation.as_u64(), + polarity: self.polarity, + row_id_column: self.row_id_column.clone(), + schema: self.schema.clone(), + })) + } +} + +struct SourceTagStream { + input: SendableRecordBatchStream, + generation: u64, + polarity: FreshnessPolarity, + row_id_column: String, + schema: SchemaRef, +} + +impl SourceTagStream { + fn tag_batch(&self, batch: RecordBatch) -> DFResult { + let num_rows = batch.num_rows(); + let gen_col: Arc = Arc::new(UInt64Array::from(vec![self.generation; num_rows])); + + let row_id_arr = batch + .column_by_name(&self.row_id_column) + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Row id column '{}' not found in batch — LsmSourceTagExec needs the per-source row id to derive _freshness", + self.row_id_column + )) + })? 
+ .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Row id column '{}' is not UInt64", + self.row_id_column + )) + })?; + + let freshness: Arc = match self.polarity { + FreshnessPolarity::InsertOrder => Arc::new(row_id_arr.clone()), + FreshnessPolarity::ReverseWrite => { + let mut builder = arrow_array::builder::UInt64Builder::with_capacity(num_rows); + for i in 0..num_rows { + if row_id_arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(u64::MAX - row_id_arr.value(i)); + } + } + Arc::new(builder.finish()) + } + }; + + let mut columns: Vec> = batch.columns().to_vec(); + columns.push(gen_col); + columns.push(freshness); + + RecordBatch::try_new(self.schema.clone(), columns) + .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) + } +} + +impl Stream for SourceTagStream { + type Item = DFResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + let tagged = self.tag_batch(batch); + Poll::Ready(Some(tagged)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl datafusion::physical_plan::RecordBatchStream for SourceTagStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::Int32Array; + use datafusion::prelude::SessionContext; + use datafusion_physical_plan::test::TestMemoryExec; + use futures::TryStreamExt; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("_rowid", DataType::UInt64, true), + ])) + } + + fn batch(ids: &[i32], row_ids: &[Option]) -> RecordBatch { + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(Int32Array::from(ids.to_vec())), + Arc::new(UInt64Array::from(row_ids.to_vec())), + 
], + ) + .unwrap() + } + + async fn run( + b: RecordBatch, + generation: LsmGeneration, + polarity: FreshnessPolarity, + ) -> Vec { + let schema = b.schema(); + let input = TestMemoryExec::try_new_exec(&[vec![b]], schema, None).unwrap(); + let exec = LsmSourceTagExec::new(input, generation, polarity, "_rowid"); + let ctx = SessionContext::new(); + let stream = exec.execute(0, ctx.task_ctx()).unwrap(); + stream.try_collect().await.unwrap() + } + + fn columns(batches: &[RecordBatch]) -> (Vec, Vec>) { + let mut gens = Vec::new(); + let mut fresh = Vec::new(); + for b in batches { + let g = b + .column_by_name(MEMTABLE_GEN_COLUMN) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let f = b + .column_by_name(FRESHNESS_COLUMN) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + gens.push(g.value(i)); + fresh.push(if f.is_null(i) { None } else { Some(f.value(i)) }); + } + } + (gens, fresh) + } + + #[tokio::test] + async fn insert_order_passes_row_id_through() { + let b = batch(&[1, 2, 3], &[Some(0), Some(5), Some(99)]); + let out = run( + b, + LsmGeneration::memtable(7), + FreshnessPolarity::InsertOrder, + ) + .await; + let (gens, fresh) = columns(&out); + assert_eq!(gens, vec![7, 7, 7]); + assert_eq!(fresh, vec![Some(0), Some(5), Some(99)]); + } + + #[tokio::test] + async fn reverse_write_flips_row_id() { + let b = batch(&[1, 2, 3], &[Some(0), Some(5), Some(99)]); + let out = run( + b, + LsmGeneration::memtable(2), + FreshnessPolarity::ReverseWrite, + ) + .await; + let (gens, fresh) = columns(&out); + assert_eq!(gens, vec![2, 2, 2]); + // Under reverse-write, smaller row_id = newer ⇒ larger _freshness. 
+ assert_eq!( + fresh, + vec![Some(u64::MAX), Some(u64::MAX - 5), Some(u64::MAX - 99)], + ); + } + + #[tokio::test] + async fn null_row_id_yields_null_freshness() { + let b = batch(&[1, 2], &[None, Some(3)]); + let out = run( + b, + LsmGeneration::memtable(1), + FreshnessPolarity::ReverseWrite, + ) + .await; + let (_, fresh) = columns(&out); + assert_eq!(fresh, vec![None, Some(u64::MAX - 3)]); + } + + #[tokio::test] + async fn base_table_generation_is_zero() { + let b = batch(&[1], &[Some(0)]); + let out = run(b, LsmGeneration::BASE_TABLE, FreshnessPolarity::InsertOrder).await; + let (gens, _) = columns(&out); + assert_eq!(gens, vec![0]); + } + + #[tokio::test] + async fn empty_batch_passthrough() { + let schema = input_schema(); + let empty = RecordBatch::new_empty(schema); + let out = run( + empty, + LsmGeneration::memtable(1), + FreshnessPolarity::InsertOrder, + ) + .await; + let total: usize = out.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 0); + } +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/projection.rs b/rust/lance/src/dataset/mem_wal/scanner/projection.rs index 20d1b1a403d..00c05056a18 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/projection.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/projection.rs @@ -125,6 +125,36 @@ pub fn canonical_output_schema( Arc::new(Schema::new(fields)) } +/// Like [`canonical_output_schema`] but with the internal LSM bookkeeping +/// columns appended: `_memtable_gen` (UInt64, NOT NULL) and `_freshness` +/// (UInt64, nullable). Used by the vector-search pipeline to carry source +/// identity + per-source row order through the union and the global +/// dedup; both columns are dropped by a downstream `project_to_canonical` +/// before returning to the caller. 
+pub fn canonical_internal_schema( + user_projection: Option<&[String]>, + base_schema: &SchemaRef, + pk_columns: &[String], + include_distance: bool, +) -> SchemaRef { + use crate::dataset::mem_wal::scanner::exec::{FRESHNESS_COLUMN, MEMTABLE_GEN_COLUMN}; + + let canonical = + canonical_output_schema(user_projection, base_schema, pk_columns, include_distance); + let mut fields: Vec> = canonical.fields().iter().cloned().collect(); + fields.push(Arc::new(Field::new( + MEMTABLE_GEN_COLUMN, + DataType::UInt64, + false, + ))); + fields.push(Arc::new(Field::new( + FRESHNESS_COLUMN, + DataType::UInt64, + true, + ))); + Arc::new(Schema::new(fields)) +} + /// Wrap `plan` so the named columns become typed NULL literals; all /// other columns are forwarded unchanged. Schema is preserved (same /// fields, same dtypes). Useful for stripping the *value* of an diff --git a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs index f6d50f101d0..a78bf197a23 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/vector_search.rs @@ -13,85 +13,96 @@ use arrow_schema::SortOptions; use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion::physical_plan::ExecutionPlan; +#[allow(deprecated)] +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; use lance_core::Result; -use lance_index::scalar::bloomfilter::sbbf::Sbbf; +use lance_core::datatypes::OnMissing; use tracing::instrument; +use crate::dataset::Dataset; +use crate::io::exec::TakeExec; + use super::collector::LsmDataSourceCollector; use 
super::data_source::LsmDataSource; -use super::exec::{FilterStaleExec, GenerationBloomFilter, MemtableGenTagExec}; +use super::exec::{FreshnessPolarity, LsmGlobalPkDedupExec, LsmSourceTagExec}; use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ - DISTANCE_COLUMN, build_scanner_projection, canonical_output_schema, null_columns, - project_to_canonical, wants_row_id, + DISTANCE_COLUMN, build_scanner_projection, canonical_internal_schema, canonical_output_schema, + null_columns, project_to_canonical, wants_row_id, }; use crate::session::Session; /// Plans vector search queries over LSM data. /// -/// Vector search queries are executed across all LSM levels and results -/// are merged with staleness detection. The query plan uses: -/// -/// 1. **FilterStaleExec**: Filters out results with newer versions in higher generations -/// 2. **UnionExec**: Combines results from all sources -/// 3. **SortExec (per partition, fetch=k)**: Sorts each source's candidates by distance, in parallel -/// 4. **SortPreservingMergeExec (fetch=k)**: K-way merge of sorted streams, early-terminates at k +/// Each source independently runs KNN, then results are unioned and run +/// through a single global PK dedup that picks the row with the largest +/// `(generation, freshness)` tuple per primary key. Generation is the +/// source identity (base = 0, memtable gens 1..N, active = N+1) and +/// freshness is the per-source row order normalized so larger = newer +/// (see [`LsmSourceTagExec`]). 
/// /// # Query Plan Structure /// /// ```text -/// SortPreservingMergeExec: order_by=[_distance ASC], fetch=k -/// SortExec: order_by=[_distance ASC], fetch=k (per partition, parallel) -/// FilterStaleExec: bloom_filters=[gen3, gen2, gen1] -/// UnionExec -/// MemtableGenTagExec: gen=3 -/// KNNExec: memtable_gen_3, k=k -/// MemtableGenTagExec: gen=2 -/// KNNExec: flushed_gen_2, k=k (fast_search) -/// MemtableGenTagExec: gen=1 -/// KNNExec: flushed_gen_1, k=k (fast_search) -/// MemtableGenTagExec: gen=0 -/// KNNExec: base_table, k=k (fast_search) +/// TakeExec (optional: fetch user-projected cols from base dataset) +/// SortPreservingMergeExec: order_by=[_distance ASC], fetch=k +/// SortExec: order_by=[_distance ASC], fetch=k (per partition, parallel) +/// ProjectionExec (drops _memtable_gen, _freshness) +/// LsmGlobalPkDedupExec: pk=[…], gen=_memtable_gen, freshness=_freshness +/// CoalescePartitionsExec +/// UnionExec +/// ProjectionExec (canonical internal schema) +/// ProjectionExec (null_columns _rowid) (non-base only) +/// LsmSourceTagExec: gen=N+1, polarity=InsertOrder (active) +/// KNNExec: active memtable, k=k +/// ProjectionExec (canonical internal schema) +/// ProjectionExec (null_columns _rowid) +/// LsmSourceTagExec: gen=N, polarity=ReverseWrite (flushed) +/// KNNExec: flushed gen N, k=k (fast_search) +/// … one per flushed gen … +/// ProjectionExec (canonical internal schema) +/// LsmSourceTagExec: gen=0, polarity=InsertOrder (base) +/// KNNExec: base table, k=k (fast_search)[.refine()?] /// ``` /// /// # Index-Only Search (fast_search) /// -/// For base table and flushed memtables, we use `fast_search()` to only search -/// indexed data. This is correct because: -/// - Each flushed memtable has its own vector index built during flush -/// - The active memtable covers any unindexed data -/// - Searching unindexed data in base/flushed would be redundant +/// For base table and flushed memtables we use `fast_search()` to only +/// search indexed data. 
This is correct because: +/// - Each flushed memtable has its own vector index built during flush. +/// - The active memtable covers any unindexed data. +/// - Searching unindexed data in base/flushed would be redundant. /// -/// # Staleness Detection +/// # Dedup semantics /// -/// For each candidate result from generation G, FilterStaleExec checks if the -/// primary key exists in bloom filters of generations > G. If found, the result -/// is filtered out because a newer version exists. +/// `LsmGlobalPkDedupExec` keeps the row whose `(generation, freshness)` +/// tuple is largest, so newer generations always win and ties within a +/// generation fall to the source-local freshness (larger row offset for +/// active memtables; smaller `_rowid` for flushed memtables, flipped by +/// `LsmSourceTagExec` so the comparison stays uniform). pub struct LsmVectorSearchPlanner { /// Data source collector. collector: LsmDataSourceCollector, - /// Primary key column names (for staleness detection). + /// Primary key column names (used by the global dedup). pk_columns: Vec, /// Schema of the base table. base_schema: SchemaRef, - /// Bloom filters for each memtable generation. - bloom_filters: Vec, /// Vector column name. vector_column: String, /// Distance metric type (L2, Cosine, Dot, etc.). distance_type: lance_linalg::distance::DistanceType, - /// Refine factor applied to the base-table KNN scan. + /// Base dataset reference for post-rerank take. /// - /// `None` (default): no refine — base distances may be approximate - /// (e.g. when the base table is indexed with IVF-PQ). `Some(n)`: fetch - /// `k * n` candidates and re-rank with exact distances using the - /// original vectors. Set this to make cross-source distance comparison - /// across the LSM merge fully exact. 
- base_table_refine_factor: Option, + /// After the global PK dedup and sort, a `TakeExec` against this + /// dataset materializes any user-projected columns that were not + /// part of the per-source KNN output. Rows from memtables already + /// carry all columns; the take only fetches additional data for + /// base-table rows (which have a real `_rowid`). + dataset: Option>, /// Session threaded into flushed-generation opens (shared caches). session: Option>, /// Cache of opened flushed-generation datasets. @@ -119,10 +130,9 @@ impl LsmVectorSearchPlanner { collector, pk_columns, base_schema, - bloom_filters: Vec::new(), vector_column, distance_type, - base_table_refine_factor: None, + dataset: None, session: None, flushed_cache: None, } @@ -142,39 +152,15 @@ impl LsmVectorSearchPlanner { self } - /// Enable base-table refine. - /// - /// When set, the base-table arm of the KNN plan asks the scanner for - /// `k * factor` candidates and re-ranks them with exact distances. This - /// is useful when the base table uses an approximate index (IVF-PQ) and - /// you need exact distances for cross-source merging in the LSM scan. + /// Set the base dataset for post-rerank take. /// - /// Default: disabled (base table returns approximate distances). - pub fn with_base_table_refine_factor(mut self, factor: u32) -> Self { - self.base_table_refine_factor = Some(factor); - self - } - - /// Add a bloom filter for staleness detection. - pub fn with_bloom_filter(mut self, generation: u64, bloom_filter: Arc) -> Self { - self.bloom_filters.push(GenerationBloomFilter { - generation, - bloom_filter, - }); - self - } - - /// Add multiple bloom filters. 
- pub fn with_bloom_filters( - mut self, - bloom_filters: impl IntoIterator)>, - ) -> Self { - for (generation, bf) in bloom_filters { - self.bloom_filters.push(GenerationBloomFilter { - generation, - bloom_filter: bf, - }); - } + /// After global PK dedup and sort, a `TakeExec` against this dataset + /// materializes any user-projected columns that were not part of the + /// per-source KNN output. This is necessary because per-source KNN + /// only returns the columns needed for dedup and ranking; the take + /// step fetches the full user projection for the final top-k rows. + pub fn with_dataset(mut self, dataset: Arc) -> Self { + self.dataset = Some(dataset); self } @@ -186,6 +172,11 @@ impl LsmVectorSearchPlanner { /// * `k` - Number of nearest neighbors to return /// * `nprobes` - Number of IVF partitions to search (for IVF-based indexes) /// * `projection` - Columns to include in output (None = all columns) + /// * `refine_factor` - When set, the base-table arm of the KNN plan fetches + /// `k * refine_factor` candidates and re-ranks them with exact distances. + /// Useful when the base table uses an approximate index (IVF-PQ) so that + /// cross-source distance comparison is exact. Memtable arms use exact + /// HNSW search and do not need refine. /// /// # Returns /// @@ -198,6 +189,7 @@ impl LsmVectorSearchPlanner { k: usize, nprobes: usize, projection: Option<&[String]>, + refine_factor: Option, ) -> Result> { let sources = self.collector.collect()?; @@ -205,60 +197,75 @@ impl LsmVectorSearchPlanner { return self.empty_plan(projection); } - let has_bloom = !self.bloom_filters.is_empty(); let canonical_schema = canonical_output_schema( projection, &self.base_schema, &self.pk_columns, true, // include _distance — KNN always produces it ); + // The internal schema carries `_memtable_gen` + `_freshness` + // through the union and the global dedup; both are dropped + // afterwards by a project back to the canonical output schema. 
+ let internal_schema = + canonical_internal_schema(projection, &self.base_schema, &self.pk_columns, true); let mut knn_plans = Vec::new(); for source in &sources { let generation = source.generation(); let is_base = matches!(source, LsmDataSource::BaseTable { .. }); let knn = self - .build_knn_plan(source, query_vector, k, nprobes, projection) + .build_knn_plan(source, query_vector, k, nprobes, projection, refine_factor) .await?; - // Normalize each source to the canonical schema. - // Lance's `fast_search()` always produces `_rowid` whether or - // not we asked for it. For non-base arms that value is local to - // the per-source dataset and would collide with base IDs, so we - // NULL it before merging. (no-op if `_rowid` isn't in the - // source schema, e.g. the active arm.) - let knn = if is_base { - knn - } else { - null_columns(knn, &[lance_core::ROW_ID])? + // Tag rows with `(_memtable_gen, _freshness)`. Polarity differs + // per source — see [`LsmSourceTagExec`] / [`FreshnessPolarity`]: + // * active memtable: insert order, larger `_rowid` = newer + // * flushed memtable: reverse-written, smaller `_rowid` = newer + // * base table: no duplicates expected; polarity moot + let polarity = match source { + LsmDataSource::FlushedMemTable { .. } => FreshnessPolarity::ReverseWrite, + LsmDataSource::ActiveMemTable { .. } | LsmDataSource::BaseTable { .. } => { + FreshnessPolarity::InsertOrder + } }; - let normalized = project_to_canonical(knn, &canonical_schema)?; - let plan: Arc = if has_bloom { - Arc::new(MemtableGenTagExec::new(normalized, generation)) + let tagged: Arc = Arc::new(LsmSourceTagExec::new( + knn, + generation, + polarity, + lance_core::ROW_ID, + )); + // Lance's `fast_search()` always produces `_rowid` whether or + // not we asked for it; the active arm also produces `_rowid` + // when we ask for it (to drive freshness). 
For non-base arms + // the per-source value would collide with base row ids in the + // canonical output, so NULL it before stitching into the + // internal schema. The dedup has already consumed it via + // `_freshness`. + let after_null = if is_base { + tagged } else { - normalized + null_columns(tagged, &[lance_core::ROW_ID])? }; - knn_plans.push(plan); + // Normalize each source to the internal canonical schema + // (canonical user cols + `_memtable_gen` + `_freshness`). + let normalized = project_to_canonical(after_null, &internal_schema)?; + knn_plans.push(normalized); } #[allow(deprecated)] let union: Arc = Arc::new(UnionExec::new(knn_plans)); - let merged: Arc = if has_bloom { - // FilterStaleExec declares one output partition but only reads partition 0 - // of its input — without coalescing first, every union partition past the - // base table is silently dropped on the bloom-filter path. - let coalesced_in: Arc = Arc::new(CoalescePartitionsExec::new(union)); - let filtered: Arc = Arc::new(FilterStaleExec::new( - coalesced_in, - self.pk_columns.clone(), - self.bloom_filters.clone(), - )); - // FilterStaleExec needs `_memtable_gen`; drop it before returning so the - // output matches `empty_plan` and excludes internal LSM bookkeeping cols. - project_to_canonical(filtered, &canonical_schema)? - } else { - union - }; + // LsmGlobalPkDedupExec declares one output partition but only + // reads partition 0 of its input — coalesce first or partitions + // past the base table get silently dropped. + let coalesced: Arc = Arc::new(CoalescePartitionsExec::new(union)); + let deduped: Arc = Arc::new(LsmGlobalPkDedupExec::new( + coalesced, + self.pk_columns.clone(), + super::exec::MEMTABLE_GEN_COLUMN, + super::exec::FRESHNESS_COLUMN, + )); + // Drop `_memtable_gen` and `_freshness` — they're internal-only. 
+ let merged: Arc = project_to_canonical(deduped, &canonical_schema)?; let distance_idx = merged.schema().index_of(DISTANCE_COLUMN).map_err(|_| { lance_core::Error::invalid_input(format!( @@ -298,7 +305,30 @@ impl LsmVectorSearchPlanner { SortPreservingMergeExec::new(lex_ordering, per_partition_sorted).with_fetch(Some(k)), ); - Ok(merged_sorted) + // After global rerank, take any user-projected columns that the + // per-source KNN didn't return. This fetches from the base dataset + // using `_rowid`; memtable rows (NULL `_rowid`) already carry all + // their data so the take is a no-op for them. + #[allow(deprecated)] + let result = if let Some(dataset) = &self.dataset { + let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns); + let output_projection = dataset + .empty_projection() + .union_columns(cols, OnMissing::Ignore)?; + let coalesced: Arc = + Arc::new(CoalesceBatchesExec::new(merged_sorted.clone(), 8192)); + if let Some(take_plan) = + TakeExec::try_new(dataset.clone(), coalesced, output_projection)? + { + Arc::new(take_plan) as Arc + } else { + merged_sorted + } + } else { + merged_sorted + }; + + Ok(result) } /// Build KNN plan for a single data source. @@ -309,6 +339,7 @@ impl LsmVectorSearchPlanner { k: usize, nprobes: usize, projection: Option<&[String]>, + refine_factor: Option, ) -> Result> { match source { LsmDataSource::BaseTable { dataset } => { @@ -331,7 +362,7 @@ impl LsmVectorSearchPlanner { scanner.fast_search(); // Re-rank base candidates with exact distances when set, so // they're directly comparable to MemTable distances in the merge. 
- if let Some(factor) = self.base_table_refine_factor { + if let Some(factor) = refine_factor { scanner.refine(factor); } scanner.create_plan().await @@ -367,8 +398,14 @@ impl LsmVectorSearchPlanner { let cols = build_scanner_projection(projection, &self.base_schema, &self.pk_columns); scanner.project(&cols.iter().map(|s| s.as_str()).collect::>()); - // No `with_row_id/address`: MemTableScanner returns BatchStore - // positions, not Lance row ids. + // Expose `_rowid` (BatchStore row offset, monotonic with + // insert order) so [`WithinSourceDedupExec`] can collapse + // duplicate-PK rows to the newest insert. The value is + // per-source and NULL'd before reaching the canonical merge. + // (VectorIndexExec only plumbs `with_row_id`, not + // `with_row_address`, but the two yield identical values + // for an active memtable so either would work.) + scanner.with_row_id(); let query_arr: Arc = Arc::new(query_vector.clone()); scanner.nearest(&self.vector_column, query_arr, k); scanner.nprobes(nprobes); @@ -495,7 +532,7 @@ mod tests { ); let query = create_query_vector(); - let plan = planner.plan_search(&query, 10, 8, None).await; + let plan = planner.plan_search(&query, 10, 8, None, None).await; // Plan construction must succeed. Execution against empty data is a // separate concern handled by integration tests. 
@@ -583,7 +620,7 @@ mod tests { let query = create_query_vector(); let plan = planner - .plan_search(&query, 3, 1, None) + .plan_search(&query, 3, 1, None, None) .await .expect("planner should produce a plan"); @@ -704,7 +741,7 @@ mod tests { let query = create_query_vector(); let projection = vec!["vector".to_string()]; let plan = planner - .plan_search(&query, 3, 1, Some(&projection)) + .plan_search(&query, 3, 1, Some(&projection), None) .await .expect("planner should produce a plan"); @@ -786,7 +823,7 @@ mod tests { "_rowid".to_string(), ]; let plan = planner - .plan_search(&query, 3, 1, Some(&projection)) + .plan_search(&query, 3, 1, Some(&projection), None) .await .expect( "planner must accept `_distance`/`_rowid` in projection without breaking the plan", @@ -834,15 +871,21 @@ mod tests { } #[tokio::test] - async fn test_vector_search_with_bloom_filter_strips_memtable_gen() { - // Regression for: when bloom filters are configured, FilterStaleExec preserves - // its `_memtable_gen` input column. Without the post-filter projection that - // strips it, the column would leak into the user-visible output. + async fn test_vector_search_strips_internal_columns_and_preserves_active_rows() { + // Two regressions in one test: + // (1) `LsmGlobalPkDedupExec` consumes `_memtable_gen` and `_freshness` + // but the user-visible output must NOT contain them — the + // post-dedup `project_to_canonical` is what strips them, so a + // refactor that drops that projection would leak these columns. + // (2) `LsmGlobalPkDedupExec` declares one output partition but only + // reads partition 0 of its input. Without a `CoalescePartitionsExec` + // ahead of it, every union partition past partition 0 is silently + // dropped — i.e. active-memtable rows disappear when the union + // puts them in a non-zero partition. 
use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; use datafusion::prelude::SessionContext; use futures::TryStreamExt; - use lance_index::scalar::bloomfilter::sbbf::Sbbf; let schema = create_vector_schema(); let temp_dir = tempfile::tempdir().unwrap(); @@ -881,34 +924,28 @@ mod tests { }, ); - // An empty bloom filter — no IDs marked, so FilterStaleExec keeps everything. - // The point of this test isn't filtering correctness, it's verifying the - // post-filter projection strips `_memtable_gen` from the output. - let bloom = Arc::new(Sbbf::with_ndv_fpp(8, 0.01).unwrap()); - let planner = LsmVectorSearchPlanner::new( collector, vec!["id".to_string()], schema, "vector".to_string(), lance_linalg::distance::DistanceType::L2, - ) - .with_bloom_filter(2, bloom); + ); let query = create_query_vector(); let plan = planner - .plan_search(&query, 3, 1, None) + .plan_search(&query, 3, 1, None, None) .await .expect("planner should produce a plan"); - // Plan must include FilterStaleExec (proves bloom-filter path was taken). + // Plan must include the new global dedup (proves the pipeline is wired). 
let plan_str = format!( "{}", datafusion::physical_plan::displayable(plan.as_ref()).indent(true) ); assert!( - plan_str.contains("FilterStaleExec"), - "expected bloom-filter path with FilterStaleExec, got:\n{}", + plan_str.contains("LsmGlobalPkDedupExec"), + "expected new global-dedup pipeline, got:\n{}", plan_str ); @@ -920,25 +957,27 @@ mod tests { let out_schema = batches[0].schema(); assert!(out_schema.field_with_name(DISTANCE_COLUMN).is_ok()); - assert!( - out_schema - .field_with_name(super::super::exec::MEMTABLE_GEN_COLUMN) - .is_err(), - "`_memtable_gen` leaked into output when bloom filters were configured: {:?}", - out_schema - .fields() - .iter() - .map(|f| f.name().clone()) - .collect::>() - ); + for internal in [ + super::super::exec::MEMTABLE_GEN_COLUMN, + super::super::exec::FRESHNESS_COLUMN, + ] { + assert!( + out_schema.field_with_name(internal).is_err(), + "`{}` leaked into output: {:?}", + internal, + out_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>(), + ); + } - // Verify active-memtable rows survived the bloom-filter path. The collector - // emits base as partition 0 and the active memtable as partition 1+ of the - // UnionExec. FilterStaleExec declares 1 output partition but reads only - // partition 0 of its input — so without the CoalescePartitionsExec inserted - // ahead of it, partitions 1+ are silently dropped. The active memtable holds - // ids 1..=4; the base holds id 10. Asserting that at least one id in 1..=4 - // is present directly proves partition-1+ data made it through. + // (2) Active-memtable rows must survive: collector emits base as + // partition 0 of the union and the active memtable as partition 1+. + // The active memtable holds ids 1..=4; the base holds id 10. At + // least one id in 1..=4 must appear in the output, otherwise the + // CoalescePartitionsExec was skipped and partitions 1+ were dropped. 
let mut all_ids: Vec = Vec::new(); for batch in &batches { let id_col = batch @@ -959,6 +998,109 @@ mod tests { ); } + #[tokio::test] + async fn test_vector_search_dedup_across_generations() { + // Regression: same primary key inserted into two sources (older + // flushed gen and newer active memtable) with different vectors. + // Without the cross-source PK dedup the older flushed row would + // still appear in top-k. The newer-generation row must win. + // + // We simulate a "flushed gen 1" by writing a tiny Lance dataset + // under {base_uri}/_mem_wal/{shard}/gen_1 and pointing the + // collector at it. Real flush would reverse-write, but for this + // test we only have one row in the flushed gen so order is moot. + use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; + use crate::dataset::mem_wal::scanner::data_source::ShardSnapshot; + use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; + use datafusion::prelude::SessionContext; + use futures::TryStreamExt; + + let schema = create_vector_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_path = temp_dir.path().to_str().unwrap(); + let base_uri = format!("{}/base", base_path); + + // Flushed gen 1 holds an older version of pk=1 with a "wrong" vector. + let shard_id = uuid::Uuid::new_v4(); + let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id); + let old_pk1 = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]); + create_dataset(&gen1_uri, vec![old_pk1]).await; + + // Active memtable holds the newer version of pk=1 with the + // "right" vector close to the query, plus an unrelated pk=2. 
+ let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut index_store = IndexStore::new(); + index_store.add_hnsw( + "vector_hnsw".to_string(), + 1, + "vector".to_string(), + lance_linalg::distance::DistanceType::L2, + 64, + 8, + ); + let new_pk1 = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]); + let other = create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]); + let (_, _, bp1) = batch_store.append(new_pk1.clone()).unwrap(); + index_store + .insert_with_batch_position(&new_pk1, 0, Some(bp1)) + .unwrap(); + let (_, _, bp2) = batch_store.append(other.clone()).unwrap(); + index_store + .insert_with_batch_position(&other, 1, Some(bp2)) + .unwrap(); + let index_store = Arc::new(index_store); + + let shard_snapshot = ShardSnapshot::new(shard_id) + .with_current_generation(2) + .with_flushed_generation(1, "gen_1".to_string()); + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![shard_snapshot]) + .with_in_memory_memtables( + shard_id, + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store, + schema: schema.clone(), + generation: 2, + }, + frozen: vec![], + }, + ); + + let planner = LsmVectorSearchPlanner::new( + collector, + vec!["id".to_string()], + schema, + "vector".to_string(), + lance_linalg::distance::DistanceType::L2, + ); + + let query = create_query_vector(); + let plan = planner.plan_search(&query, 5, 1, None, None).await.unwrap(); + let ctx = SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let ids: Vec = batches + .iter() + .flat_map(|b| { + b.column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec() + }) + .collect(); + let pk1_count = ids.iter().filter(|i| **i == 1).count(); + assert_eq!( + pk1_count, 1, + "pk=1 must appear exactly once after cross-source dedup; got ids={:?}", + ids, + ); + } + #[tokio::test] async fn 
test_vector_search_system_columns_real_only_for_base() { // Covers tests 1+2+3 from the PR review: @@ -1054,7 +1196,7 @@ mod tests { "vector".to_string(), ]; let plan = planner - .plan_search(&query, 3, 1, Some(&projection)) + .plan_search(&query, 3, 1, Some(&projection), None) .await .expect("planner should produce a plan"); @@ -1130,7 +1272,7 @@ mod tests { ]; let query = create_query_vector(); let plan = planner - .plan_search(&query, 5, 1, Some(&projection)) + .plan_search(&query, 5, 1, Some(&projection), None) .await .expect("empty plan must accept system columns in projection"); @@ -1175,7 +1317,7 @@ mod tests { let query = create_query_vector(); let plan = planner - .plan_search(&query, 10, 8, None) + .plan_search(&query, 10, 8, None, None) .await .expect("planner should produce a plan without a base table"); @@ -1202,4 +1344,129 @@ mod tests { let total: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!(total, 0, "fresh tier with no sources should yield no rows"); } + + /// Build a single-row batch with an explicit (id, vector) so tests can + /// pin the within-source dedup against same-PK / different-vector + /// inputs. + fn create_test_batch_with_vector( + schema: &ArrowSchema, + id: i32, + vector: [f32; 4], + ) -> RecordBatch { + use arrow_array::builder::Float32Builder; + + let mut vector_builder = FixedSizeListBuilder::new(Float32Builder::new(), 4); + for v in &vector { + vector_builder.values().append_value(*v); + } + vector_builder.append(true); + let vector_array = vector_builder.finish(); + + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![id])), Arc::new(vector_array)], + ) + .unwrap() + } + + #[tokio::test] + async fn test_vector_search_dedup_within_active_memtable() { + // Regression: same PK inserted twice into one active memtable with + // *different* vectors. 
HNSW indexes each as a distinct node, so + // without WithinSourceDedupExec a KNN can return both candidates + // for the same PK and pollute top-k. The newer insert must win. + use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; + use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; + use datafusion::prelude::SessionContext; + use futures::TryStreamExt; + + let schema = create_vector_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap()); + + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut index_store = IndexStore::new(); + index_store.add_hnsw( + "vector_hnsw".to_string(), + 1, + "vector".to_string(), + lance_linalg::distance::DistanceType::L2, + 64, + 8, + ); + + // Two rows with pk=1, different vectors. b_new (last insert) is the + // newer version that the LSM contract says should win. + let b_old = create_test_batch_with_vector(&schema, 1, [9.0, 9.0, 9.0, 9.0]); + let b_new = create_test_batch_with_vector(&schema, 1, [0.1, 0.2, 0.3, 0.4]); + // An unrelated row so top-k has more than one PK to choose from. 
+ let b_other = create_test_batch_with_vector(&schema, 2, [5.0, 5.0, 5.0, 5.0]); + + let (_, _, bp_old) = batch_store.append(b_old.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_old, 0, Some(bp_old)) + .unwrap(); + let (_, _, bp_new) = batch_store.append(b_new.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_new, 1, Some(bp_new)) + .unwrap(); + let (_, _, bp_other) = batch_store.append(b_other.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_other, 2, Some(bp_other)) + .unwrap(); + let index_store = Arc::new(index_store); + + let shard_id = uuid::Uuid::new_v4(); + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]) + .with_in_memory_memtables( + shard_id, + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); + + let planner = LsmVectorSearchPlanner::new( + collector, + vec!["id".to_string()], + schema, + "vector".to_string(), + lance_linalg::distance::DistanceType::L2, + ); + + // Query is exactly the *newer* vector for pk=1. If the older + // vector for pk=1 leaks through, it'd appear in top-k too because + // the older row's vector is far from the query but still a graph + // node. After dedup we should see pk=1 exactly once. 
+ let query = create_query_vector(); + let plan = planner.plan_search(&query, 5, 1, None, None).await.unwrap(); + + let ctx = SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let ids: Vec = batches + .iter() + .flat_map(|b| { + b.column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec() + }) + .collect(); + let pk1_count = ids.iter().filter(|i| **i == 1).count(); + assert_eq!( + pk1_count, 1, + "pk=1 must appear exactly once after within-source dedup; got ids={:?}", + ids, + ); + } } From bd58ad07f3611edb208f9bf848bc3ef434aecf3a Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 23:11:15 -0700 Subject: [PATCH 16/23] fix(mem_wal): dedupe duplicate primary keys in LSM point lookup (#6880) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split from #6856 — point-lookup portion. A primary key written multiple times into one active memtable used to leak through to the user as distinct rows: `FilterExec + LIMIT 1` over an insert-ordered scan returned the *oldest* match among duplicates. The active arm now runs `WithinSourceDedupExec(KeepMaxFreshness)`, which collapses by PK and keeps the freshest row. Flushed and base arms still rely on `LIMIT 1` under the reverse-write / forward-write conventions. Part of splitting #6856 into focused PRs. Co-authored with @jackye1995. 
Co-authored-by: Jack Ye --- .../lance/src/dataset/mem_wal/scanner/exec.rs | 4 + .../mem_wal/scanner/exec/global_pk_dedup.rs | 52 +-- .../src/dataset/mem_wal/scanner/exec/pk.rs | 60 +++ .../scanner/exec/within_source_dedup.rs | 432 ++++++++++++++++++ .../dataset/mem_wal/scanner/point_lookup.rs | 155 ++++++- 5 files changed, 649 insertions(+), 54 deletions(-) create mode 100644 rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs create mode 100644 rust/lance/src/dataset/mem_wal/scanner/exec/within_source_dedup.rs diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec.rs b/rust/lance/src/dataset/mem_wal/scanner/exec.rs index c2ed01bb1f2..40bb13b84b4 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec.rs @@ -12,13 +12,16 @@ //! - [`CoalesceFirstExec`]: Returns first non-empty result with short-circuit //! - [`LsmSourceTagExec`]: Tags rows with `_memtable_gen` + `_freshness` for the vector-search global dedup //! - [`LsmGlobalPkDedupExec`]: Single-pass cross-source PK dedup over the merged vector-search stream +//! 
- [`WithinSourceDedupExec`]: Deduplicates rows with the same PK from a single source (used by point lookup) mod bloom_guard; mod coalesce_first; mod deduplicate; mod generation_tag; mod global_pk_dedup; +mod pk; mod source_tag; +mod within_source_dedup; pub use bloom_guard::{BloomFilterGuardExec, compute_pk_hash_from_scalars}; pub use coalesce_first::CoalesceFirstExec; @@ -26,3 +29,4 @@ pub use deduplicate::{DeduplicateExec, ROW_ADDRESS_COLUMN}; pub use generation_tag::{MEMTABLE_GEN_COLUMN, MemtableGenTagExec}; pub use global_pk_dedup::LsmGlobalPkDedupExec; pub use source_tag::{FRESHNESS_COLUMN, FreshnessPolarity, LsmSourceTagExec}; +pub use within_source_dedup::{DedupDirection, WithinSourceDedupExec}; diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs index f35de7771b0..fdf9372cc4e 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/global_pk_dedup.rs @@ -38,6 +38,8 @@ use datafusion::physical_plan::{ }; use futures::{Stream, StreamExt, ready}; +use super::pk::{compute_pk_hash, resolve_pk_indices}; + /// Cross-source PK dedup. Keeps one row per primary key — the one with /// the largest `(generation, freshness)` tuple. /// @@ -271,56 +273,6 @@ impl GlobalPkDedupStream { } } -fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResult> { - pk_columns - .iter() - .map(|col| { - batch - .schema() - .column_with_name(col) - .map(|(idx, _)| idx) - .ok_or_else(|| { - datafusion::error::DataFusionError::Internal(format!( - "Primary key column '{}' not found", - col - )) - }) - }) - .collect() -} - -/// Hash a row's primary key. Mirrors the variants supported by -/// [`super::WithinSourceDedupExec`] / `BloomFilterGuardExec`, so a single -/// PK produces the same hash everywhere in the LSM scanner. 
-fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = DefaultHasher::new(); - for &col_idx in pk_indices { - let col = batch.column(col_idx); - let is_null = col.is_null(row_idx); - is_null.hash(&mut hasher); - - if !is_null { - if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } else if let Some(arr) = col.as_any().downcast_ref::() { - arr.value(row_idx).hash(&mut hasher); - } - } - } - hasher.finish() -} - impl Stream for GlobalPkDedupStream { type Item = DFResult; diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs new file mode 100644 index 00000000000..abb2653fa50 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/pk.rs @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Shared primary-key helpers for the LSM scanner execution nodes. +//! +//! Centralizes PK column resolution and per-row hashing so that every dedup +//! node ([`super::WithinSourceDedupExec`] and [`super::LsmGlobalPkDedupExec`]) +//! resolves and hashes a primary key the same way. The row hash is kept +//! consistent with the variants supported by [`super::compute_pk_hash_from_scalars`] +//! so a single PK produces the same hash regardless of which exec consumes it. 
+ +use arrow_array::{Array, RecordBatch}; +use datafusion::error::{DataFusionError, Result as DFResult}; + +/// Resolve the column index of each primary-key column in `batch`. +pub fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResult> { + pk_columns + .iter() + .map(|col| { + batch + .schema() + .column_with_name(col) + .map(|(idx, _)| idx) + .ok_or_else(|| { + DataFusionError::Internal(format!("Primary key column '{}' not found", col)) + }) + }) + .collect() +} + +/// Hash a single row's primary key, identified by the `pk_indices` column +/// positions and `row_idx`. +pub fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + for &col_idx in pk_indices { + let col = batch.column(col_idx); + let is_null = col.is_null(row_idx); + is_null.hash(&mut hasher); + + if !is_null { + if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } else if let Some(arr) = col.as_any().downcast_ref::() { + arr.value(row_idx).hash(&mut hasher); + } + } + } + hasher.finish() +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/exec/within_source_dedup.rs b/rust/lance/src/dataset/mem_wal/scanner/exec/within_source_dedup.rs new file mode 100644 index 00000000000..be5dae6a668 --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/exec/within_source_dedup.rs @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + 
+//! WithinSourceDedupExec - Deduplicates rows with the same primary key from a +//! single LSM source, keeping the newest insert. +//! +//! In MemWAL/LSM mode the same primary key can be written multiple times into +//! the same memtable. The active memtable stores rows in insert order (larger +//! `_rowaddr` = newer), while flushed memtables are reverse-written so that +//! within a flushed file the smallest `_rowid` is the newest insert (see +//! `memtable/flush.rs:152` and `hnsw/storage.rs:307`). Point lookup uses this +//! node to collapse such duplicates *within a single source* so that the +//! downstream `CoalesceFirstExec` / `LIMIT` sees at most one row per primary +//! key per source. + +use std::any::Any; +use std::collections::HashMap; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow_array::{Array, RecordBatch, UInt64Array}; +use arrow_schema::SchemaRef; +use datafusion::error::Result as DFResult; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, +}; +use futures::{Stream, StreamExt, ready}; + +use super::pk::{compute_pk_hash, resolve_pk_indices}; + +/// Among rows that share a primary key, which row-address extreme identifies +/// the newest insert to keep. The kept row is always the freshest; only the +/// row address (`_rowaddr`/`_rowid`) used to find it differs by source. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DedupDirection { + /// Keep the row with the largest row-address value (active memtable: larger + /// `_rowaddr` = inserted later). + KeepMaxRowAddr, + /// Keep the row with the smallest row-address value (flushed memtable under + /// reverse-write: smaller `_rowid` = inserted later). 
+ KeepMinRowAddr, +} + +/// Deduplicates rows from a single source by primary key, keeping the row +/// whose `row_addr_column` value wins per [`DedupDirection`]. +/// +/// # Required columns +/// +/// The input must expose: +/// - All `pk_columns` +/// - `row_addr_column` of `UInt64` type +/// +/// The output schema is unchanged from the input. Callers that need to hide +/// the row-address column from downstream consumers should compose this node +/// with `project_to_canonical` or `null_columns`. +/// +/// # Performance +/// +/// Memory: `O(unique primary keys in input)`. For point lookup the input is +/// already filtered to a single primary key so the map holds at most one +/// entry. +#[derive(Debug)] +pub struct WithinSourceDedupExec { + input: Arc, + pk_columns: Vec, + row_addr_column: String, + direction: DedupDirection, + schema: SchemaRef, + properties: Arc, +} + +impl WithinSourceDedupExec { + pub fn new( + input: Arc, + pk_columns: Vec, + row_addr_column: impl Into, + direction: DedupDirection, + ) -> Self { + let schema = input.schema(); + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + input.pipeline_behavior(), + input.boundedness(), + )); + Self { + input, + pk_columns, + row_addr_column: row_addr_column.into(), + direction, + schema, + properties, + } + } + + pub fn pk_columns(&self) -> &[String] { + &self.pk_columns + } + + pub fn row_addr_column(&self) -> &str { + &self.row_addr_column + } + + pub fn direction(&self) -> DedupDirection { + self.direction + } +} + +impl DisplayAs for WithinSourceDedupExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + write!( + f, + "WithinSourceDedupExec: pk=[{}], row_addr={}, direction={:?}", + self.pk_columns.join(", "), + self.row_addr_column, + self.direction, + ) + } + } + } 
+} + +impl ExecutionPlan for WithinSourceDedupExec { + fn name(&self) -> &str { + "WithinSourceDedupExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + if children.len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "WithinSourceDedupExec requires exactly one child".to_string(), + )); + } + Ok(Arc::new(Self::new( + children[0].clone(), + self.pk_columns.clone(), + self.row_addr_column.clone(), + self.direction, + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let input_stream = self.input.execute(partition, context)?; + Ok(Box::pin(WithinSourceDedupStream { + input: input_stream, + pk_columns: self.pk_columns.clone(), + row_addr_column: self.row_addr_column.clone(), + direction: self.direction, + schema: self.schema.clone(), + winners: HashMap::new(), + emitted: false, + })) + } +} + +/// One winning row, materialized as a single-row `RecordBatch` so we don't +/// have to keep the source batch alive after we've picked the winner. +struct Winner { + batch: RecordBatch, + row_addr: u64, +} + +struct WithinSourceDedupStream { + input: SendableRecordBatchStream, + pk_columns: Vec, + row_addr_column: String, + direction: DedupDirection, + schema: SchemaRef, + winners: HashMap, + emitted: bool, +} + +impl WithinSourceDedupStream { + fn consume_batch(&mut self, batch: RecordBatch) -> DFResult<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + let pk_indices = resolve_pk_indices(&batch, &self.pk_columns)?; + let row_addr_array = batch + .column_by_name(&self.row_addr_column) + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Row-address column '{}' not found in batch", + self.row_addr_column + )) + })? 
+ .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Internal(format!( + "Row-address column '{}' is not UInt64", + self.row_addr_column + )) + })?; + + for row_idx in 0..batch.num_rows() { + if row_addr_array.is_null(row_idx) { + // A NULL row address can't be ordered against a real one. Skip + // rather than guess — callers should always project a real + // row-address column for dedup-eligible sources. + continue; + } + let row_addr = row_addr_array.value(row_idx); + let pk_hash = compute_pk_hash(&batch, &pk_indices, row_idx); + + let take_row = match self.winners.get(&pk_hash) { + None => true, + Some(existing) => match self.direction { + DedupDirection::KeepMaxRowAddr => row_addr > existing.row_addr, + DedupDirection::KeepMinRowAddr => row_addr < existing.row_addr, + }, + }; + + if take_row { + let single = batch.slice(row_idx, 1); + self.winners.insert( + pk_hash, + Winner { + batch: single, + row_addr, + }, + ); + } + } + Ok(()) + } + + fn finalize(&mut self) -> DFResult { + if self.winners.is_empty() { + return Ok(RecordBatch::new_empty(self.schema.clone())); + } + let batches: Vec = self.winners.drain().map(|(_, w)| w.batch).collect(); + let batch_refs: Vec<&RecordBatch> = batches.iter().collect(); + arrow_select::concat::concat_batches(&self.schema, batch_refs) + .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None)) + } +} + +impl Stream for WithinSourceDedupStream { + type Item = DFResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.emitted { + return Poll::Ready(None); + } + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if let Err(e) = self.consume_batch(batch) { + self.emitted = true; + return Poll::Ready(Some(Err(e))); + } + } + Some(Err(e)) => { + self.emitted = true; + return Poll::Ready(Some(Err(e))); + } + None => { + self.emitted = true; + return Poll::Ready(Some(self.finalize())); + } + } + } + } +} + +impl 
datafusion::physical_plan::RecordBatchStream for WithinSourceDedupStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Float32Array, Int32Array, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion::prelude::SessionContext; + use datafusion_physical_plan::test::TestMemoryExec; + use futures::TryStreamExt; + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("_distance", DataType::Float32, true), + Field::new("_row_addr", DataType::UInt64, true), + ])) + } + + fn batch(ids: &[i32], names: &[&str], distances: &[f32], row_addr: &[u64]) -> RecordBatch { + let schema = create_test_schema(); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(ids.to_vec())), + Arc::new(StringArray::from(names.to_vec())), + Arc::new(Float32Array::from(distances.to_vec())), + Arc::new(UInt64Array::from(row_addr.to_vec())), + ], + ) + .unwrap() + } + + async fn run(batches: Vec, direction: DedupDirection) -> Vec { + let schema = create_test_schema(); + let input = TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap(); + let exec = + WithinSourceDedupExec::new(input, vec!["id".to_string()], "_row_addr", direction); + let ctx = SessionContext::new(); + let stream = exec.execute(0, ctx.task_ctx()).unwrap(); + stream.try_collect().await.unwrap() + } + + fn extract(batches: &[RecordBatch]) -> Vec<(i32, String, u64)> { + let mut out = Vec::new(); + for b in batches { + let ids = b.column(0).as_any().downcast_ref::().unwrap(); + let names = b.column(1).as_any().downcast_ref::().unwrap(); + let addr = b.column(3).as_any().downcast_ref::().unwrap(); + for i in 0..b.num_rows() { + out.push((ids.value(i), names.value(i).to_string(), addr.value(i))); + } + } + out.sort_by_key(|(id, _, _)| *id); + out + } + + #[tokio::test] + async fn 
keep_max_picks_largest_row_addr() { + // Active-memtable case: same pk inserted twice; newer = larger _rowaddr. + let b1 = batch( + &[1, 1, 2], + &["old", "new", "two"], + &[0.1, 0.2, 0.3], + &[10, 99, 5], + ); + let out = run(vec![b1], DedupDirection::KeepMaxRowAddr).await; + let rows = extract(&out); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0], (1, "new".to_string(), 99)); + assert_eq!(rows[1], (2, "two".to_string(), 5)); + } + + #[tokio::test] + async fn keep_min_picks_smallest_row_addr() { + // Flushed-memtable case under reverse-write: newer = smaller _rowid. + let b1 = batch( + &[1, 1, 2], + &["old", "new", "two"], + &[0.1, 0.2, 0.3], + &[99, 10, 5], + ); + let out = run(vec![b1], DedupDirection::KeepMinRowAddr).await; + let rows = extract(&out); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0], (1, "new".to_string(), 10)); + assert_eq!(rows[1], (2, "two".to_string(), 5)); + } + + #[tokio::test] + async fn dedup_across_batches() { + let b1 = batch(&[1, 2], &["a", "b"], &[0.1, 0.2], &[1, 1]); + let b2 = batch(&[1, 3], &["a_new", "c"], &[0.5, 0.4], &[7, 1]); + let out = run(vec![b1, b2], DedupDirection::KeepMaxRowAddr).await; + let rows = extract(&out); + assert_eq!(rows.len(), 3); + assert_eq!(rows[0], (1, "a_new".to_string(), 7)); + assert_eq!(rows[1], (2, "b".to_string(), 1)); + assert_eq!(rows[2], (3, "c".to_string(), 1)); + } + + #[tokio::test] + async fn empty_input() { + let out = run(vec![], DedupDirection::KeepMaxRowAddr).await; + let total: usize = out.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 0); + } + + #[tokio::test] + async fn null_row_addr_skipped() { + // Rows with NULL row address can't be ordered — they're dropped so they + // don't accidentally become winners against real values. 
+ let schema = create_test_schema(); + let b = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(StringArray::from(vec!["nulladdr", "real"])), + Arc::new(Float32Array::from(vec![0.1, 0.2])), + Arc::new(UInt64Array::from(vec![None, Some(5)])), + ], + ) + .unwrap(); + let out = run(vec![b], DedupDirection::KeepMaxRowAddr).await; + let rows = extract(&out); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0], (1, "real".to_string(), 5)); + } +} diff --git a/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs b/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs index a70f200f14e..8063c5a2f3e 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/point_lookup.rs @@ -18,11 +18,14 @@ use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; -use super::exec::{BloomFilterGuardExec, CoalesceFirstExec, compute_pk_hash_from_scalars}; +use super::exec::{ + BloomFilterGuardExec, CoalesceFirstExec, DedupDirection, WithinSourceDedupExec, + compute_pk_hash_from_scalars, +}; use super::flushed_cache::{FlushedMemTableCache, open_flushed_dataset}; use super::projection::{ - build_scanner_projection, canonical_output_schema, project_to_canonical, wants_row_address, - wants_row_id, + build_scanner_projection, canonical_output_schema, null_columns, project_to_canonical, + wants_row_address, wants_row_id, }; use crate::session::Session; @@ -276,7 +279,26 @@ impl LsmPointLookupPlanner { MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone()); scanner.project(&cols.iter().map(|s| s.as_str()).collect::>()); scanner.filter_expr(filter.clone()); - scanner.create_plan().await? + // Expose `_rowid` (the BatchStore row offset, monotonic with + // insert order) so we can pick the most recently inserted + // duplicate below. 
Without this, a `FilterExec → LIMIT 1` + // over insert-ordered scan would return the *oldest* of + // multiple rows sharing the target primary key. + scanner.with_row_id(); + let raw = scanner.create_plan().await?; + // Within the active memtable, larger `_rowid` = newer + // insert. After dedup there is exactly one row per PK. + let deduped: Arc = Arc::new(WithinSourceDedupExec::new( + raw, + self.pk_columns.clone(), + lance_core::ROW_ID, + DedupDirection::KeepMaxRowAddr, + )); + // Per-source `_rowid` would collide with the base table's; + // NULL it before canonicalization (the value is internal to + // this arm). project_to_canonical drops it entirely when + // the user didn't request `_rowid` in the projection. + null_columns(deduped, &[lance_core::ROW_ID])? } }; project_to_canonical(scan, &target) @@ -629,4 +651,129 @@ mod tests { "empty point-lookup plan must honor user column order including system columns" ); } + + #[tokio::test] + async fn test_point_lookup_active_memtable_returns_newest_duplicate() { + // Regression: same primary key inserted twice into one active + // memtable must return the *newest* row. The bug was that + // `FilterExec → LIMIT 1` over an insert-ordered scan returned the + // first (oldest) match. `WithinSourceDedupExec` collapses by PK, + // keeping the row with the largest `_rowid` (insert order). + use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; + use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; + use futures::TryStreamExt; + + let schema = create_pk_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_uri = format!("{}/base", temp_dir.path().to_str().unwrap()); + + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut index_store = IndexStore::new(); + // BTree on the PK so that `max_visible_batch_position` advances as + // we insert, otherwise the scanner sees no batches at all. 
+ index_store.add_btree("id_idx".to_string(), 0, "id".to_string()); + + // Two writes to pk=1, then an unrelated pk=2. The "new" row goes + // *second* so its `_rowid` is larger. + let b_old = create_test_batch(&schema, &[1], "old"); + let b_new = create_test_batch(&schema, &[1], "new"); + let b_other = create_test_batch(&schema, &[2], "two"); + let (_, _, bp_old) = batch_store.append(b_old.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_old, 0, Some(bp_old)) + .unwrap(); + let (_, _, bp_new) = batch_store.append(b_new.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_new, 1, Some(bp_new)) + .unwrap(); + let (_, _, bp_other) = batch_store.append(b_other.clone()).unwrap(); + index_store + .insert_with_batch_position(&b_other, 2, Some(bp_other)) + .unwrap(); + let index_store = Arc::new(index_store); + + let shard_id = Uuid::new_v4(); + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]) + .with_in_memory_memtables( + shard_id, + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); + + let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema); + + let plan = planner + .plan_lookup(&[ScalarValue::Int32(Some(1))], None) + .await + .unwrap(); + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 1, "expected exactly one row for pk=1"); + let name_col = batches[0].column_by_name("name").unwrap(); + let name_arr = name_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + name_arr.value(0), + "new_1", + "active-arm lookup must return the newer insert, not the oldest" + ); + } + + #[tokio::test] + async fn test_point_lookup_flushed_memtable_returns_newest_duplicate() { + // Regression 
/ invariant pin: when a flushed memtable contains two + // rows for the same PK, the lookup must return the newer one. The + // flushed dataset is reverse-written (newest at the smallest + // physical position), so we simulate that here by writing the + // dataset with the new row first. The point-lookup plan today + // returns the first match (smallest `_rowid`) under reverse-write, + // and remains so after this change. + use futures::TryStreamExt; + + let schema = create_pk_schema(); + let temp_dir = tempfile::tempdir().unwrap(); + let base_path = temp_dir.path().to_str().unwrap(); + let base_uri = format!("{}/base", base_path); + + // Simulated reverse-write: newest insert lives at row 0. + let shard_id = Uuid::new_v4(); + let gen1_uri = format!("{}/_mem_wal/{}/gen_1", base_uri, shard_id); + let row_new = create_test_batch(&schema, &[1], "new"); + let row_old = create_test_batch(&schema, &[1], "old"); + create_dataset(&gen1_uri, vec![row_new, row_old]).await; + + let shard_snapshot = ShardSnapshot::new(shard_id) + .with_current_generation(2) + .with_flushed_generation(1, "gen_1".to_string()); + + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![shard_snapshot]); + let planner = LsmPointLookupPlanner::new(collector, vec!["id".to_string()], schema); + + let plan = planner + .plan_lookup(&[ScalarValue::Int32(Some(1))], None) + .await + .unwrap(); + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 1, "expected exactly one row for pk=1"); + let name_col = batches[0].column_by_name("name").unwrap(); + let name_arr = name_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + name_arr.value(0), + "new_1", + "flushed-arm lookup must return the row at the smallest _rowid (newest under reverse-write)" + ); + } } From 
6e0e5a685d7786679e9d48dd2aaadc6a964089e3 Mon Sep 17 00:00:00 2001 From: YueZhang <69956021+zhangyue19921010@users.noreply.github.com> Date: Thu, 21 May 2026 20:37:52 +0800 Subject: [PATCH 17/23] feat: expose multi-base config to Python and Java write_fragments API (#6855) --- java/lance-jni/src/fragment.rs | 33 +++- java/src/main/java/org/lance/Fragment.java | 12 ++ .../java/org/lance/WriteFragmentBuilder.java | 40 +++++ .../test/java/org/lance/MultiBaseTest.java | 160 ++++++++++++++++-- python/python/lance/fragment.py | 15 ++ python/python/lance/lance/__init__.pyi | 6 + python/python/tests/test_multi_base.py | 37 ++++ rust/lance/src/dataset/fragment/write.rs | 82 ++++++++- 8 files changed, 364 insertions(+), 21 deletions(-) diff --git a/java/lance-jni/src/fragment.rs b/java/lance-jni/src/fragment.rs index 5e2812ddd75..a6798c2f237 100644 --- a/java/lance-jni/src/fragment.rs +++ b/java/lance-jni/src/fragment.rs @@ -101,6 +101,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiArray<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -120,6 +123,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiArray<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -142,6 +148,9 @@ fn inner_create_with_ffi_array<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // 
Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -170,6 +179,9 @@ fn inner_create_with_ffi_array<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -191,6 +203,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiStream<'a>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -209,6 +224,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiStream<'a>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -230,6 +248,9 @@ fn inner_create_with_ffi_stream<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -248,6 +269,9 @@ fn inner_create_with_ffi_stream<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -267,6 +291,9 @@ fn create_fragment<'a>( enable_stable_row_ids: JObject, // Optional 
data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -285,9 +312,9 @@ fn create_fragment<'a>( &data_storage_version, None, &storage_options_obj, - &JObject::null(), // base store params are not used when creating fragments - &JObject::null(), // not used when creating fragments - &JObject::null(), // not used when creating fragments + &base_store_params_obj, + &initial_bases, + &target_bases, &allow_external_blob_outside_bases, &blob_pack_file_size_threshold, )?; diff --git a/java/src/main/java/org/lance/Fragment.java b/java/src/main/java/org/lance/Fragment.java index 43091269382..b27b189bf48 100644 --- a/java/src/main/java/org/lance/Fragment.java +++ b/java/src/main/java/org/lance/Fragment.java @@ -278,6 +278,9 @@ static List create( params.getEnableStableRowIds(), params.getDataStorageVersion(), params.getStorageOptions(), + params.getBaseStoreParams(), + params.getInitialBases(), + params.getTargetBases(), namespaceClient, tableId, params.getAllowExternalBlobOutsideBases(), @@ -305,6 +308,9 @@ static List create( params.getEnableStableRowIds(), params.getDataStorageVersion(), params.getStorageOptions(), + params.getBaseStoreParams(), + params.getInitialBases(), + params.getTargetBases(), namespaceClient, tableId, params.getAllowExternalBlobOutsideBases(), @@ -323,6 +329,9 @@ private static native List createWithFfiArray( Optional enableStableRowIds, Optional dataStorageVersion, Map storageOptions, + Map> baseStoreParams, + Optional> initialBases, + Optional> targetBases, LanceNamespace namespaceClient, List tableId, Optional allowExternalBlobOutsideBases, @@ -339,6 +348,9 @@ private static native List createWithFfiStream( Optional enableStableRowIds, Optional 
dataStorageVersion, Map storageOptions, + Map> baseStoreParams, + Optional> initialBases, + Optional> targetBases, LanceNamespace namespaceClient, List tableId, Optional allowExternalBlobOutsideBases, diff --git a/java/src/main/java/org/lance/WriteFragmentBuilder.java b/java/src/main/java/org/lance/WriteFragmentBuilder.java index 693b5b6bc87..5d7dc1a42b2 100644 --- a/java/src/main/java/org/lance/WriteFragmentBuilder.java +++ b/java/src/main/java/org/lance/WriteFragmentBuilder.java @@ -123,6 +123,22 @@ public WriteFragmentBuilder storageOptions(Map storageOptions) { return this; } + /** + * Set runtime-only object store parameters for registered base paths. + * + *

Entries are keyed by the exact {@link BasePath#getPath()} value persisted in the manifest. + * Each value is the storage options map used for that base. Bases without an explicit entry use + * {@link #storageOptions(Map)} as the fallback. + * + * @param baseStoreParams object store parameters keyed by base path URI + * @return this builder + */ + public WriteFragmentBuilder baseStoreParams(Map> baseStoreParams) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withBaseStoreParams(baseStoreParams); + return this; + } + /** * Set the namespace client for automatic credential refresh. * @@ -223,6 +239,30 @@ public WriteFragmentBuilder dataStorageVersion(String version) { return this; } + /** + * Register base paths when creating a new dataset from fragments. + * + * @param bases base paths to register + * @return this builder + */ + public WriteFragmentBuilder initialBases(List bases) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withInitialBases(bases); + return this; + } + + /** + * Set base names or paths where new fragment files should be written. + * + * @param targetBases base names or exact paths + * @return this builder + */ + public WriteFragmentBuilder targetBases(List targetBases) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withTargetBases(targetBases); + return this; + } + /** * Execute the fragment write operation. 
* diff --git a/java/src/test/java/org/lance/MultiBaseTest.java b/java/src/test/java/org/lance/MultiBaseTest.java index 802b4a2ca31..e4f1d982e14 100644 --- a/java/src/test/java/org/lance/MultiBaseTest.java +++ b/java/src/test/java/org/lance/MultiBaseTest.java @@ -36,13 +36,18 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MultiBaseTest { private BufferAllocator allocator; @@ -103,6 +108,51 @@ private ArrowStreamReader makeReader(int startId, int count) throws Exception { } } + private VectorSchemaRoot makeRoot(int startId, int count) { + List fields = + Arrays.asList( + new Field("id", FieldType.notNullable(new ArrowType.Int(32, true)), null), + new Field("value", FieldType.nullable(new ArrowType.Utf8()), null)); + Schema schema = new Schema(fields); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + IntVector idVec = (IntVector) root.getVector("id"); + idVec.allocateNew(count); + VarCharVector valVec = (VarCharVector) root.getVector("value"); + valVec.allocateNew(); + for (int i = 0; i < count; i++) { + int id = startId + i; + idVec.setSafe(i, id); + byte[] b = ("val_" + id).getBytes(); + valVec.setSafe(i, b, 0, b.length); + } + root.setRowCount(count); + return root; + } + + private boolean hasLanceFile(String basePath) throws Exception { + try (Stream paths = Files.walk(Path.of(basePath))) { + return paths.anyMatch(path -> path.toString().endsWith(".lance")); + } + } + + private Set baseIds(Dataset dataset) { + return dataset.getFragments().stream() + .flatMap(f -> 
f.metadata().getFiles().stream()) + .map(DataFile::getBaseId) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + } + + private Set baseIds(List fragments) { + return fragments.stream() + .flatMap(fragment -> fragment.getFiles().stream()) + .map(DataFile::getBaseId) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + } + @Test public void testCreateMode() throws Exception { ArrowStreamReader reader = makeReader(0, 500); @@ -211,12 +261,7 @@ public void testTargetByPathUri() throws Exception { .maxRowsPerFile(50) .execute(); - Set baseIds = - ds.getFragments().stream() - .flatMap(f -> f.metadata().getFiles().stream().map(DataFile::getBaseId)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toSet()); + Set baseIds = baseIds(ds); assertEquals(1, baseIds.size()); ArrowStreamReader append = makeReader(100, 50); @@ -231,12 +276,103 @@ public void testTargetByPathUri() throws Exception { .execute(); assertEquals(150, updated.countRows()); - baseIds = - updated.getFragments().stream() - .flatMap(f -> f.metadata().getFiles().stream().map(DataFile::getBaseId)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toSet()); + baseIds = baseIds(updated); assertEquals(2, baseIds.size()); } + + @Test + public void testFragmentCreateWithMultiBaseParams() throws Exception { + ArrowStreamReader reader = makeReader(0, 100); + List bases = + Arrays.asList( + new BasePath(0, Optional.of("base1"), base1, false), + new BasePath(0, Optional.of("base2"), base2, false)); + + Dataset ds = + Dataset.write() + .allocator(allocator) + .reader(reader) + .uri(primary) + .mode(WriteParams.WriteMode.CREATE) + .initialBases(bases) + .targetBases(Arrays.asList("base1")) + .maxRowsPerFile(50) + .execute(); + Set initialBaseIds = baseIds(ds); + assertEquals(1, initialBaseIds.size()); + assertTrue(hasLanceFile(base1)); + + Map> baseStoreParams = new HashMap<>(); + 
baseStoreParams.put(base2, new HashMap<>()); + WriteParams params = + new WriteParams.Builder() + .withTargetBases(Arrays.asList("base2")) + .withBaseStoreParams(baseStoreParams) + .withMaxRowsPerFile(25) + .build(); + + List fragments; + try (VectorSchemaRoot root = makeRoot(100, 50)) { + fragments = Fragment.create(primary, allocator, root, params); + } + + assertEquals(2, fragments.size()); + Set fragmentBaseIds = baseIds(fragments); + assertEquals(1, fragmentBaseIds.size()); + assertTrue(Collections.disjoint(initialBaseIds, fragmentBaseIds)); + assertTrue(hasLanceFile(base2)); + + FragmentOperation.Append append = new FragmentOperation.Append(fragments); + Dataset updated = Dataset.commit(allocator, primary, append, Optional.of(ds.version())); + assertEquals(150, updated.countRows()); + assertEquals(2, baseIds(updated).size()); + } + + @Test + public void testFragmentWriteWithMultiBaseParams() throws Exception { + ArrowStreamReader reader = makeReader(0, 50); + List bases = + Arrays.asList( + new BasePath(0, Optional.of("base1"), base1, false), + new BasePath(0, Optional.of("base2"), base2, false)); + + Dataset ds = + Dataset.write() + .allocator(allocator) + .reader(reader) + .uri(primary) + .mode(WriteParams.WriteMode.CREATE) + .initialBases(bases) + .targetBases(Arrays.asList("base1")) + .maxRowsPerFile(50) + .execute(); + Set initialBaseIds = baseIds(ds); + assertEquals(1, initialBaseIds.size()); + assertTrue(hasLanceFile(base1)); + + Map> baseStoreParams = new HashMap<>(); + baseStoreParams.put(base2, new HashMap<>()); + + List fragments; + try (VectorSchemaRoot root = makeRoot(50, 25)) { + fragments = + Fragment.write() + .datasetUri(primary) + .allocator(allocator) + .data(root) + .targetBases(Arrays.asList("base2")) + .baseStoreParams(baseStoreParams) + .maxRowsPerFile(25) + .execute(); + } + + Set fragmentBaseIds = baseIds(fragments); + assertEquals(1, fragmentBaseIds.size()); + assertTrue(Collections.disjoint(initialBaseIds, fragmentBaseIds)); + 
FragmentOperation.Append append = new FragmentOperation.Append(fragments); + Dataset updated = Dataset.commit(allocator, primary, append, Optional.of(ds.version())); + assertEquals(75, updated.countRows()); + assertEquals(2, baseIds(updated).size()); + assertTrue(hasLanceFile(base2)); + } } diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index e76baf7e5dd..75722ac82e1 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -1011,6 +1011,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> Transaction: ... @@ -1033,6 +1034,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> List[FragmentMetadata]: ... @@ -1055,6 +1057,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> List[FragmentMetadata] | Transaction: @@ -1132,6 +1135,13 @@ def write_fragments( **Only valid in CREATE mode**. Will raise an error if used with APPEND/OVERWRITE modes. + base_store_params : dict of str to dict, optional + Runtime-only object store parameters keyed by exact base path URI. + Each value is a dict of storage options for that base. These settings + are not persisted to the manifest. 
When a base has no explicit entry, + top-level ``storage_options`` is used as a fallback. If ``dataset_uri`` + is a LanceDataset and this is omitted, the dataset's base store params + are inherited. namespace_client : optional, LanceNamespace A namespace client for automatic credential refresh. When provided with `table_id`, a storage options provider will be created automatically to @@ -1173,6 +1183,10 @@ def write_fragments( if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) elif isinstance(dataset_uri, LanceDataset): + if base_store_params is None: + base_store_params = dataset_uri._base_store_params + if storage_options is None: + storage_options = dataset_uri._storage_options dataset_uri = dataset_uri._ds elif not isinstance(dataset_uri, str): raise TypeError(f"Unknown dataset_uri type {type(dataset_uri)}") @@ -1204,6 +1218,7 @@ def write_fragments( enable_stable_row_ids=enable_stable_row_ids, target_bases=target_bases, initial_bases=initial_bases, + base_store_params=base_store_params, ) diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 39f22e8aded..dd91fb19614 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -559,6 +559,9 @@ def _write_fragments( namespace_client: Optional[LanceNamespace], table_id: Optional[List[str]], enable_stable_row_ids: bool, + target_bases: Optional[List[str]] = None, + initial_bases: Optional[List[Any]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, ): ... def _write_fragments_transaction( dataset_uri: str | Path | _Dataset, @@ -573,6 +576,9 @@ def _write_fragments_transaction( namespace_client: Optional[LanceNamespace], table_id: Optional[List[str]], enable_stable_row_ids: bool, + target_bases: Optional[List[str]] = None, + initial_bases: Optional[List[Any]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, ) -> Transaction: ... 
def _json_to_schema(schema_json: str) -> pa.Schema: ... def _schema_to_json(schema: pa.Schema) -> str: ... diff --git a/python/python/tests/test_multi_base.py b/python/python/tests/test_multi_base.py index baa437b4fc6..113f3224d14 100644 --- a/python/python/tests/test_multi_base.py +++ b/python/python/tests/test_multi_base.py @@ -1031,6 +1031,7 @@ def test_write_fragments_with_target_bases(self): dataset, mode="append", target_bases=["base2"], + base_store_params={self.base2_uri: {}}, max_rows_per_file=25, ) @@ -1054,6 +1055,42 @@ def test_write_fragments_with_target_bases(self): data_files = list(base2_path.glob("**/*.lance")) assert len(data_files) > 0, "Expected data files in base2" + def test_write_fragments_with_base_store_params(self): + """Test write_fragments inherits base_store_params from LanceDataset.""" + initial_data = pd.DataFrame({"id": range(10), "value": range(10)}) + + dataset = lance.write_dataset( + initial_data, + self.primary_uri, + mode="create", + initial_bases=[ + DatasetBasePath(self.base1_uri, name="base1"), + DatasetBasePath(self.base2_uri, name="base2"), + ], + target_bases=["base1"], + base_store_params={self.base2_uri: {}}, + max_rows_per_file=10, + ) + + fragment_data = pd.DataFrame({"id": range(10, 20), "value": range(10, 20)}) + fragments = write_fragments( + pa.Table.from_pandas(fragment_data), + dataset, + mode="append", + target_bases=["base2"], + max_rows_per_file=10, + ) + + operation = lance.LanceOperation.Append(fragments) + dataset = lance.LanceDataset.commit( + dataset, operation, read_version=dataset.version + ) + + result = dataset.to_table().to_pandas() + assert len(result) == 20 + assert set(result["id"]) == set(range(20)) + assert list(Path(self.base2_uri).glob("**/*.lance")) + def test_write_fragments_transaction_with_target_bases(self): """Test write_fragments with return_transaction and target_bases.""" # Create initial dataset diff --git a/rust/lance/src/dataset/fragment/write.rs 
b/rust/lance/src/dataset/fragment/write.rs index 9131d3ecb04..b10158224f9 100644 --- a/rust/lance/src/dataset/fragment/write.rs +++ b/rust/lance/src/dataset/fragment/write.rs @@ -20,8 +20,8 @@ use uuid::Uuid; use crate::Result; use crate::dataset::builder::DatasetBuilder; -use crate::dataset::write::do_write_fragments; -use crate::dataset::{DATA_DIR, WriteMode, WriteParams}; +use crate::dataset::write::{do_write_fragments, validate_and_resolve_target_bases}; +use crate::dataset::{DATA_DIR, Dataset, ReadParams, WriteMode, WriteParams}; /// Generates a filename optimized for S3 throughput using a UUID-based approach. /// @@ -196,11 +196,27 @@ impl<'a> FragmentCreateBuilder<'a> { stream: SendableRecordBatchStream, schema: Schema, ) -> Result> { - let params = self.write_params.map(Cow::Borrowed).unwrap_or_default(); + let mut params = self.write_params.cloned().unwrap_or_default(); Self::validate_schema(&schema, stream.schema().as_ref())?; let version = params.data_storage_version.unwrap_or_default(); + let needs_existing_dataset = params.target_base_names_or_paths.is_some() + || params.target_bases.is_some() + || params.initial_bases.is_some(); + let existing_dataset = if needs_existing_dataset { + self.existing_dataset(&params).await? + } else { + None + }; + let existing_base_paths = existing_dataset + .as_ref() + .map(|dataset| &dataset.manifest.base_paths); + let target_bases_info = if needs_existing_dataset { + validate_and_resolve_target_bases(&mut params, existing_base_paths).await? 
+ } else { + None + }; let (object_store, base_path) = ObjectStore::from_uri_and_params( params.store_registry(), self.dataset_uri, @@ -208,14 +224,14 @@ impl<'a> FragmentCreateBuilder<'a> { ) .await?; do_write_fragments( - None, + existing_dataset.as_ref(), object_store, &base_path, &schema, stream, - params.into_owned(), + params, version, - None, // Fragment creation doesn't use target_bases + target_bases_info, ) .await } @@ -313,6 +329,25 @@ impl<'a> FragmentCreateBuilder<'a> { } } + async fn existing_dataset(&self, params: &WriteParams) -> Result> { + let mut builder = DatasetBuilder::from_uri(self.dataset_uri).with_read_params(ReadParams { + store_options: params.store_params.clone(), + commit_handler: params.commit_handler.clone(), + session: params.session.clone(), + ..Default::default() + }); + if let Some(base_store_params) = &params.base_store_params { + for (base_path, store_params) in base_store_params { + builder = builder.with_base_store_params(base_path, store_params.clone()); + } + } + match builder.load().await { + Ok(dataset) => Ok(Some(dataset)), + Err(Error::DatasetNotFound { .. } | Error::NotFound { .. 
}) => Ok(None), + Err(e) => Err(e), + } + } + fn validate_schema(expected: &Schema, actual: &ArrowSchema) -> Result<()> { if actual.fields().is_empty() { return Err(Error::invalid_input("Cannot write with an empty schema.")); @@ -334,9 +369,11 @@ mod tests { use arrow_schema::{DataType, Field as ArrowField}; use lance_arrow::SchemaExt; use lance_core::utils::tempfile::{TempDir, TempStrDir}; + use lance_table::format::BasePath; use rstest::rstest; use super::*; + use crate::dataset::InsertBuilder; fn test_data() -> Box { let schema = Arc::new(ArrowSchema::new(vec![ @@ -533,6 +570,39 @@ mod tests { assert_eq!(fragments[2].files[0].column_indices.as_ref(), &[0, 1]); } + #[tokio::test] + async fn test_write_fragments_with_target_base() { + let primary = TempStrDir::default(); + let base1 = TempStrDir::default(); + let base2 = TempStrDir::default(); + let create_params = WriteParams::default() + .with_initial_bases(vec![ + BasePath::new(0, base1.to_string(), Some("base1".to_string()), false), + BasePath::new(0, base2.to_string(), Some("base2".to_string()), false), + ]) + .with_target_base_names_or_paths(vec!["base1".to_string()]); + + let dataset = InsertBuilder::new(primary.as_str()) + .with_params(&create_params) + .execute_stream(test_data()) + .await + .unwrap(); + + let append_params = WriteParams { + mode: WriteMode::Append, + ..Default::default() + } + .with_target_base_names_or_paths(vec!["base2".to_string()]); + let fragments = FragmentCreateBuilder::new(dataset.uri.as_str()) + .write_params(&append_params) + .write_fragments(test_data()) + .await + .unwrap(); + + assert_eq!(fragments.len(), 1); + assert_eq!(fragments[0].files[0].base_id, Some(2)); + } + #[rstest] #[tokio::test] async fn test_write_with_format_version( From c39f31e1d1c20e9bec74481654fdecfc1ea38711 Mon Sep 17 00:00:00 2001 From: Brendan Clement Date: Thu, 21 May 2026 08:43:35 -0700 Subject: [PATCH 18/23] fix: stop double-counting child CPU in node-with-children Exec plans (#6799) MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes #5155. `InstrumentedRecordBatchStreamAdapter` measures a node's `elapsed_compute` by timing its outer `poll_next`. For an `ExecutionPlan` node with child inputs, that poll transitively polls every child — so `EXPLAIN ANALYZE` shows the parent's CPU as `parent + child + grandchild + ...`, and each ancestor double-counts its descendants. This PR fixes five nodes that have child inputs and were using the broken wrapper. The fix takes two shapes depending on the node: 1. **Four nodes** with a clean "child input -> per-batch transform -> output" shape are converted to use a new helper, `InstrumentedChildInputStream` (in `rust/lance/src/io/exec/utils.rs`), modeled on DataFusion's `FilterExecStream`. The helper pulls from a child input **without** the timer running, then drives a per-batch async transform **with** the timer running. 2. **`FlatMatchQueryExec`** doesn't fit that shape — its work all happens inside `flat_bm25_search_stream`, which consumes the child input and `spawn_cpu`s the per-batch tokenize/count work internally. For this node, the fix instruments `flat_bm25_search_stream` directly so it can report CPU time on a metric handle supplied by the caller. 
## Nodes fixed | Node | File | Mechanism | |---|---|---| | `AddRowAddrExec` | `rust/lance/src/io/exec/rowids.rs` | helper | | `MapIndexExec` | `rust/lance/src/io/exec/scalar_index.rs` | helper | | `FlatMatchFilterExec` | `rust/lance/src/io/exec/fts.rs` | helper | | `KNNVectorDistanceExec` | `rust/lance/src/io/exec/knn.rs` | helper + caller-side `Instant::now()` around `compute_distance(...).await` to capture `spawn_blocking` work | | `FlatMatchQueryExec` | `rust/lance/src/io/exec/fts.rs` + `rust/lance-index/src/scalar/inverted/index.rs` | new `flat_bm25_search_stream_with_metrics` that records CPU on a supplied `Time` handle | For `KNNVectorDistanceExec`, this also addresses the spawned-CPU undercount from #5155: the helper's timer doesn't measure work happening on `spawn_blocking` worker threads. KNN's transform closure wraps `compute_distance(...).await` with `Instant::now()` and adds the elapsed duration to `elapsed_compute` from the caller side. `compute_distance`'s public signature is unchanged. For `FlatMatchQueryExec`, the analogous mechanism lives in `lance-index`: `flat_bm25_search_stream_with_metrics` accepts an `Option

Lance owns the parquet reader and refinement; callers see {@link SearchResult} as {@code + * (filePath, rowIndex, distance)} and never have to decode an internal rid. + * + *

Lifecycle: {@link #build} writes an index file under an output URI, returning a UUID that + * names the index directory; {@link #open} returns an open handle. {@link #close} releases the + * handle. The handle is {@link AutoCloseable}; use try-with-resources where possible. + * + *

Example: + * + *

{@code
+ * String uuid = ExternalIvfPqIndex.build(
+ *     List.of("/data/embeddings-0.parquet", "/data/embeddings-1.parquet"),
+ *     "vec",
+ *     "/index/v1",
+ *     ExternalIvfPqIndexParams.builder().numPartitions(256).numSubVectors(16).build());
+ *
+ * try (ExternalIvfPqIndex idx = ExternalIvfPqIndex.open("/index/v1/" + uuid)) {
+ *   List hits = idx.search(query, 10, 16, 8, null);
+ *   byte[] arrowIpc = idx.fetchRows(
+ *       hits.stream().map(h -> ParquetRowKey.of(h.getFilePath(), h.getRowIndex())).toList(),
+ *       List.of("doc_id", "title"));
+ * }
+ * }
+ */ +public final class ExternalIvfPqIndex implements AutoCloseable { + + static { + JniLoader.ensureLoaded(); + } + + private long handle; + + private ExternalIvfPqIndex(long handle) { + this.handle = handle; + } + + // ---- build / open --------------------------------------------------------- + + /** + * Build an external IVF-PQ index over the registered parquet files. + * + *

Writes a single index directory under {@code outputUri} named by the returned UUID. The + * directory contains {@code index.idx} (IVF model + PQ codebooks) and {@code manifest.json} + * (parquet file list + build params). + * + *
<p>The {@code file_id} encoded into rids is implicit in {@code filePaths}'s position; + * reordering invalidates the index. + * + * @return UUID directory name (not the full URI); join with {@code outputUri} to get the open + * URI. + */ + public static String build( + List<String> filePaths, + String vectorColumn, + String outputUri, + ExternalIvfPqIndexParams params) { + String[] paths = filePaths.toArray(new String[0]); + return nativeBuild( + paths, + vectorColumn, + outputUri, + params.getNumPartitions(), + params.getNumSubVectors(), + params.getNumBitsPerSubVector(), + params.getMetric().toRustString(), + params.getMaxIters(), + params.getSampleRate(), + params.getSeed()); + } + + /** + * Open an external IVF-PQ index by URI. Cheap: reads {@code manifest.json} and the index file + * footer. + */ + public static ExternalIvfPqIndex open(String uri) { + long handle = nativeOpen(uri); + return new ExternalIvfPqIndex(handle); + } + + /** + * Run a vector query and return up to {@code k} refined results. + * + * @param query Query vector. Length must match the index's dimension. + * @param k Number of results to return. + * @param nprobes Number of IVF partitions to probe. + * @param refineFactor Re-rank multiplier; {@code k * refineFactor} approximate candidates are + * fetched, refined exactly, then trimmed to {@code k}. + * @param deletedRids Optional little-endian {@code u64} packed array of deleted rids encoded as + * {@code (file_id << 32) | row_index}. Results exclude these. {@code null} means no + * filter. + */ + public List<SearchResult> search( + float[] query, int k, int nprobes, int refineFactor, byte[] deletedRids) { + SearchResult[] arr = nativeSearch(handle, query, k, nprobes, refineFactor, deletedRids); + return java.util.Arrays.asList(arr); + } + + /** + * Fetch arbitrary projection columns for {@code rowKeys} from the registered parquet files. + * + *
<p>Returns Arrow IPC stream bytes; decode with {@code ArrowStreamReader} on the Java side. The + * batch has one row per input key, in caller-input order. + * + *
<p>Use this for post-topK materialization: pass the {@code (filePath, rowIndex)} pairs from + * {@link #search} and project only the columns you need. + */ + public byte[] fetchRows(List<ParquetRowKey> rowKeys, List<String> projection) { + String[] paths = new String[rowKeys.size()]; + long[] rows = new long[rowKeys.size()]; + for (int i = 0; i < rowKeys.size(); i++) { + paths[i] = rowKeys.get(i).getFilePath(); + rows[i] = rowKeys.get(i).getRowIndex(); + } + String[] proj = projection.toArray(new String[0]); + return nativeFetchRows(handle, paths, rows, proj); + } + + /** Pack {@code (file_id, row_index)} deletes into the byte format {@link #search} expects. */ + public static byte[] packDeletedRids(List<long[]> deletedFileRowPairs) { + ByteBuffer buf = + ByteBuffer.allocate(deletedFileRowPairs.size() * Long.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (long[] pair : deletedFileRowPairs) { + if (pair.length != 2) { + throw new IllegalArgumentException( + "expected (file_id, row_index) pair; got length " + pair.length); + } + long rid = (pair[0] << 32) | (pair[1] & 0xFFFFFFFFL); + buf.putLong(rid); + } + return buf.array(); + } + + /** Number of IVF partitions. */ + public int getNumPartitions() { + return nativeNumPartitions(handle); + } + + /** Number of registered parquet files. */ + public int getNumFiles() { + return nativeNumFiles(handle); + } + + /** Vector column the index was built over. 
*/ + public String getVectorColumn() { + return nativeVectorColumn(handle); + } + + @Override + public void close() { + if (handle != 0L) { + nativeClose(handle); + handle = 0L; + } + } + + // ---- native methods ------------------------------------------------------- + + private static native String nativeBuild( + String[] filePaths, + String vectorColumn, + String outputUri, + int numPartitions, + int numSubVectors, + int numBitsPerSubVector, + String metric, + int maxIters, + int sampleRate, + long seed); + + private static native long nativeOpen(String uri); + + private static native void nativeClose(long handle); + + private static native SearchResult[] nativeSearch( + long handle, float[] query, int k, int nprobes, int refineFactor, byte[] deletedRids); + + private static native byte[] nativeFetchRows( + long handle, String[] filePaths, long[] rowIndices, String[] projection); + + private static native int nativeNumPartitions(long handle); + + private static native int nativeNumFiles(long handle); + + private static native String nativeVectorColumn(long handle); +} diff --git a/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java new file mode 100644 index 00000000000..cce3f7852a8 --- /dev/null +++ b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.index.external; + +/** + * Build configuration for {@link ExternalIvfPqIndex#build}. Defaults match Lance's IVF-PQ defaults: + * {@code numPartitions=256}, {@code numSubVectors=16}, {@code numBitsPerSubVector=8}, metric = + * {@link Metric#L2}. + */ +public final class ExternalIvfPqIndexParams { + + /** Distance metric used by the index. Mirrors lance::Distance. */ + public enum Metric { + L2, + Cosine, + Dot; + + String toRustString() { + switch (this) { + case L2: + return "L2"; + case Cosine: + return "Cosine"; + case Dot: + return "Dot"; + default: + throw new IllegalStateException("unknown metric: " + this); + } + } + } + + private final int numPartitions; + private final int numSubVectors; + private final int numBitsPerSubVector; + private final Metric metric; + private final int maxIters; + private final int sampleRate; + private final long seed; + + private ExternalIvfPqIndexParams(Builder b) { + this.numPartitions = b.numPartitions; + this.numSubVectors = b.numSubVectors; + this.numBitsPerSubVector = b.numBitsPerSubVector; + this.metric = b.metric; + this.maxIters = b.maxIters; + this.sampleRate = b.sampleRate; + this.seed = b.seed; + } + + public int getNumPartitions() { + return numPartitions; + } + + public int getNumSubVectors() { + return numSubVectors; + } + + public int getNumBitsPerSubVector() { + return numBitsPerSubVector; + } + + public Metric getMetric() { + return metric; + } + + public int getMaxIters() { + return maxIters; + } + + public int getSampleRate() { + return sampleRate; + } + + public long getSeed() { + return seed; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private int numPartitions = 256; + private int numSubVectors = 16; + private int numBitsPerSubVector = 8; + private Metric metric = Metric.L2; + private int maxIters = 50; + private int sampleRate = 256; + private long seed = 0xCAFEBABEDEADBEEFL; + + public Builder numPartitions(int n) 
{ + this.numPartitions = n; + return this; + } + + public Builder numSubVectors(int n) { + this.numSubVectors = n; + return this; + } + + public Builder numBitsPerSubVector(int n) { + this.numBitsPerSubVector = n; + return this; + } + + public Builder metric(Metric m) { + this.metric = m; + return this; + } + + public Builder maxIters(int n) { + this.maxIters = n; + return this; + } + + public Builder sampleRate(int n) { + this.sampleRate = n; + return this; + } + + public Builder seed(long seed) { + this.seed = seed; + return this; + } + + public ExternalIvfPqIndexParams build() { + return new ExternalIvfPqIndexParams(this); + } + } +} diff --git a/java/src/main/java/org/lance/index/external/ParquetRowKey.java b/java/src/main/java/org/lance/index/external/ParquetRowKey.java new file mode 100644 index 00000000000..d9f7dc79cef --- /dev/null +++ b/java/src/main/java/org/lance/index/external/ParquetRowKey.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import java.util.Objects; + +/** {@code (file_path, row_index)} input to {@link ExternalIvfPqIndex#fetchRows}. 
*/ +public final class ParquetRowKey { + private final String filePath; + private final long rowIndex; + + public ParquetRowKey(String filePath, long rowIndex) { + this.filePath = filePath; + this.rowIndex = rowIndex; + } + + public static ParquetRowKey of(String filePath, long rowIndex) { + return new ParquetRowKey(filePath, rowIndex); + } + + public String getFilePath() { + return filePath; + } + + public long getRowIndex() { + return rowIndex; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ParquetRowKey)) { + return false; + } + ParquetRowKey that = (ParquetRowKey) o; + return rowIndex == that.rowIndex && Objects.equals(filePath, that.filePath); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, rowIndex); + } +} diff --git a/java/src/main/java/org/lance/index/external/SearchResult.java b/java/src/main/java/org/lance/index/external/SearchResult.java new file mode 100644 index 00000000000..1bd82c62681 --- /dev/null +++ b/java/src/main/java/org/lance/index/external/SearchResult.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import java.util.Objects; + +/** + * One result from {@link ExternalIvfPqIndex#search}. The result is already refined; the {@code + * distance} is the exact distance under the index's metric. 
+ */ +public final class SearchResult { + private final String filePath; + private final long rowIndex; + private final float distance; + + public SearchResult(String filePath, long rowIndex, float distance) { + this.filePath = filePath; + this.rowIndex = rowIndex; + this.distance = distance; + } + + /** Path of the parquet file the row lives in (one of the registered file specs). */ + public String getFilePath() { + return filePath; + } + + /** Zero-based row index within {@link #getFilePath()}. */ + public long getRowIndex() { + return rowIndex; + } + + /** Exact distance under the index's distance metric. */ + public float getDistance() { + return distance; + } + + @Override + public String toString() { + return "SearchResult{" + filePath + "@" + rowIndex + " d=" + distance + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SearchResult)) { + return false; + } + SearchResult that = (SearchResult) o; + return rowIndex == that.rowIndex + && Float.compare(distance, that.distance) == 0 + && Objects.equals(filePath, that.filePath); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, rowIndex, distance); + } +} diff --git a/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java b/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java new file mode 100644 index 00000000000..d96e79f7d6f --- /dev/null +++ b/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import org.junit.jupiter.api.Test; +import org.lance.JniLoader; + +import java.util.ArrayList; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Smoke tests for the JNI surface. Full end-to-end coverage (build + open + search + + * fetchRows) lives in {@code rust/lance/tests/external_index_phase1.rs}, which exercises + * the same code path the JNI calls into. Java-side seeding of parquet test data + * requires a parquet writer dependency and is deferred to phase 1.7-followup. 
+ */ +public class ExternalIvfPqIndexJniTest { + static { + JniLoader.ensureLoaded(); + } + + @Test + public void packDeletedRidsLittleEndian() { + java.util.List<long[]> deletes = new ArrayList<>(); + deletes.add(new long[] {0L, 0L}); + deletes.add(new long[] {1L, 42L}); + byte[] packed = ExternalIvfPqIndex.packDeletedRids(deletes); + assertEquals(16, packed.length, "two u64s = 16 bytes"); + + // First rid: (0 << 32) | 0 = 0 + long first = java.nio.ByteBuffer.wrap(packed, 0, 8) + .order(java.nio.ByteOrder.LITTLE_ENDIAN) + .getLong(); + assertEquals(0L, first); + + // Second rid: (1 << 32) | 42 = 4294967338 + long second = java.nio.ByteBuffer.wrap(packed, 8, 8) + .order(java.nio.ByteOrder.LITTLE_ENDIAN) + .getLong(); + assertEquals((1L << 32) | 42L, second); + } + + @Test + public void packDeletedRidsValidatesPairLength() { + java.util.List<long[]> bad = Collections.singletonList(new long[] {1L, 2L, 3L}); + assertThrows(IllegalArgumentException.class, () -> ExternalIvfPqIndex.packDeletedRids(bad)); + } + + @Test + public void openMissingDirThrows() { + // Validates the JNI exception bridge. LanceError::IO maps to java.io.IOException. 
+ assertThrows( + java.io.IOException.class, + () -> ExternalIvfPqIndex.open("/tmp/this-path-does-not-exist-9f3a8e2c")); + } + + @Test + public void paramsBuilderDefaults() { + ExternalIvfPqIndexParams p = ExternalIvfPqIndexParams.builder().build(); + assertEquals(256, p.getNumPartitions()); + assertEquals(16, p.getNumSubVectors()); + assertEquals(8, p.getNumBitsPerSubVector()); + assertEquals(ExternalIvfPqIndexParams.Metric.L2, p.getMetric()); + assertNotNull(p.getMetric().toRustString()); + } + + @Test + public void paramsBuilderOverrides() { + ExternalIvfPqIndexParams p = + ExternalIvfPqIndexParams.builder() + .numPartitions(32) + .numSubVectors(4) + .metric(ExternalIvfPqIndexParams.Metric.Cosine) + .maxIters(10) + .seed(42L) + .build(); + assertEquals(32, p.getNumPartitions()); + assertEquals(4, p.getNumSubVectors()); + assertEquals(ExternalIvfPqIndexParams.Metric.Cosine, p.getMetric()); + assertEquals("Cosine", p.getMetric().toRustString()); + assertEquals(10, p.getMaxIters()); + assertEquals(42L, p.getSeed()); + } +} diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index ca2bfdeaf91..eeb09f0075d 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -89,6 +89,9 @@ async_cell = "0.2.2" semver.workspace = true tokio-stream = { workspace = true } tokio-util = { workspace = true } +# parquet is a runtime dep for the external vector index, which reads source +# parquet files via page-index-aware random access (PageIndexPolicy::Required). 
+parquet = { version = "58", default-features = false, features = ["arrow", "async"] } [build-dependencies] prost-build.workspace = true @@ -126,7 +129,6 @@ geoarrow-array = { workspace = true } geoarrow-schema = { workspace = true } geo-types = { workspace = true } datafusion-substrait = { workspace = true } -parquet = { version = "58", default-features = false, features = ["arrow", "async"] } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } [features] diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 87b32344ec6..bc27997dd13 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -9,6 +9,7 @@ use std::{any::Any, collections::HashMap}; pub mod builder; pub(crate) mod details; +pub mod external; pub mod ivf; pub mod pq; pub mod utils; diff --git a/rust/lance/src/index/vector/external/build.rs b/rust/lance/src/index/vector/external/build.rs new file mode 100644 index 00000000000..af93bd3d5d2 --- /dev/null +++ b/rust/lance/src/index/vector/external/build.rs @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! `build()` implementation for [`super::ExternalIvfPqIndex`]. +//! +//! Composes the existing public Lance APIs: +//! +//! 1. [`ParquetVectorSource::sample`] → kmeans (`KMeans::new`) → centroids → `IvfModel::new` +//! 2. [`ParquetVectorSource::sample`] → `PQBuildParams::build` → `ProductQuantizer` +//! 3. [`ParquetVectorSource::iter_batches`] → `IvfTransformer::with_pq` → `shuffle_dataset` +//! → partition-binned streams +//! 4. `write_ivf_pq_file_external(object_store, path, ..., streams)` writes the index +//! +//! The Dataset dependency is gone end-to-end on this path. 
+ +use std::sync::Arc; + +use arrow_array::FixedSizeListArray; +use lance_arrow::FixedSizeListArrayExt; +use lance_core::{Error, Result}; +use lance_index::vector::ivf::IvfTransformer; +use lance_index::vector::ivf::shuffler::shuffle_dataset; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::kmeans::KMeans; +use lance_index::vector::pq::PQBuildParams; +use lance_io::object_store::ObjectStore; +use uuid::Uuid; + +use super::manifest::{ExternalIndexManifest, write_manifest}; +use super::params::ExternalIvfPqIndexParams; +use super::parquet_source::ParquetVectorSource; +use super::types::ParquetFileSpec; +use crate::index::vector::ivf::write_ivf_pq_file_external; + +/// Internal entry point — drives the full build pipeline. +/// +/// Layout written under `output_uri`: +/// +/// ```text +/// <output_uri>/ +/// <uuid>/ +/// index.idx ← IVF-PQ partitions + protobuf metadata +/// manifest.json ← ParquetFileSpec list + build params +/// ``` +/// +/// `index.idx` is written first; `manifest.json` follows once every file spec with +/// `num_rows == 0` has had its row count filled in from the parquet footer. +pub(super) async fn build_index( + files: Vec<ParquetFileSpec>, + vector_column: &str, + output_uri: &str, + params: ExternalIvfPqIndexParams, +) -> Result<Uuid> { + if files.is_empty() { + return Err(Error::invalid_input( + "build_index requires at least one parquet file", + )); + } + + let source = ParquetVectorSource::try_new(files.clone(), vector_column)?; + + // 1. Train kmeans for IVF centroids (sample size is never below num_partitions) + let sample_size = (params.num_partitions * params.sample_rate) + .min(source.num_rows()? 
as usize) + .max(params.num_partitions); + let training = source.sample(sample_size).await?; + let kmeans = KMeans::new(&training, params.num_partitions, params.max_iters as u32) + .map_err(|e| Error::index(format!("kmeans training failed: {e}")))?; + + let centroids = + FixedSizeListArray::try_new_from_values(kmeans.centroids.clone(), kmeans.dimension as i32) + .map_err(|e| Error::index(format!("kmeans centroids → FixedSizeListArray: {e}")))?; + let ivf = IvfModel::new(centroids.clone(), None); + + // 2. Train PQ codebooks over the same training data + let pq = PQBuildParams::new(params.num_sub_vectors, params.num_bits_per_sub_vector) + .build(&training, params.metric.into()) + .map_err(|e| Error::index(format!("PQ training failed: {e}")))?; + + // 3. Build IVF transformer; shuffle source batches into partition-binned streams + let transformer = Arc::new(IvfTransformer::with_pq( + centroids, + params.metric.into(), + vector_column, + pq.clone(), + None, + )); + let raw_stream = source.iter_batches()?; + let partitioned_streams = shuffle_dataset( + raw_stream, + transformer, + /* precomputed_partitions = */ None, + params.num_partitions as u32, + /* shuffle_partition_batches = */ 1024, + /* shuffle_partition_concurrency = */ 2, + /* precomputed_shuffle_buffers = */ None, + ) + .await + .map_err(|e| Error::index(format!("shuffle_dataset failed: {e}")))?; + + // 4. Write the index file. Resolve `output_uri` into (ObjectStore, Path) and + // write under <output_uri>/<uuid>/index.idx. 
+ let (object_store, root_path) = ObjectStore::from_uri(output_uri).await?; + let index_uuid = Uuid::new_v4(); + let index_dir = root_path.clone().join(index_uuid.to_string()); + let index_path = index_dir.clone().join(super::open::INDEX_FILE_NAME); + + write_ivf_pq_file_external( + &object_store, + &index_path, + vector_column, + /* index_name = */ "external_ivf_pq", + /* dataset_version = */ 0, + ivf, + pq, + partitioned_streams, + ) + .await?; + + // Resolve any unfilled num_rows on the file specs from their footers so the + // manifest records authoritative values. + let mut resolved_files = files; + for spec in resolved_files.iter_mut() { + if spec.num_rows == 0 { + spec.num_rows = read_parquet_num_rows(&spec.file_path)?; + } + } + let manifest = ExternalIndexManifest::from_build(vector_column, &resolved_files, &params); + write_manifest(&object_store, &index_dir, &manifest).await?; + + Ok(index_uuid) +} + +fn read_parquet_num_rows(path: &str) -> Result<u64> { + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + let file = std::fs::File::open(path) + .map_err(|e| Error::invalid_input(format!("failed to open {path}: {e}")))?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file) + .map_err(|e| Error::invalid_input(format!("failed to read footer for {path}: {e}")))?; + Ok(builder.metadata().file_metadata().num_rows() as u64) +} diff --git a/rust/lance/src/index/vector/external/fetch.rs b/rust/lance/src/index/vector/external/fetch.rs new file mode 100644 index 00000000000..5c06d895df3 --- /dev/null +++ b/rust/lance/src/index/vector/external/fetch.rs @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! `fetch_rows()` implementation for [`super::ExternalIvfPqIndex`]. +//! +//! Post-topK materialization primitive. The caller supplies a list of +//! `(file_path, row_index)` keys and a projection column list; Lance batches by +//! 
file, issues one page-index-aware parquet read per file, reassembles the +//! result in caller-input order, and returns a `RecordBatch` with one row per +//! input key. +//! +//! Why this matters for lance-spark: today the join's materialize stage writes +//! all R columns into a temp Lance file because there's no way to fetch them on +//! demand later. With `fetch_rows`, the join only fetches the projection columns +//! for surviving top-K rows — a hard win on materialize I/O for large R. + +use std::collections::HashMap; +use std::fs::File; +use std::ops::Range; +use std::sync::Arc; + +use arrow::compute::concat_batches; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::{Field, Schema, SchemaRef}; +use lance_core::{Error, Result}; +use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::{ + ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, +}; +use parquet::file::metadata::PageIndexPolicy; + +use super::open::OpenedExternalIndex; +use super::types::ParquetRowKey; + +/// Fetch rows by `(file_path, row_index)` from the registered parquet files. +/// +/// `projection` may include any column from the parquet schema, not just the +/// vector column the index was built over. +/// +/// Result rows are in caller-input order — duplicates in the input produce +/// duplicate result rows. Empty input returns an empty batch with the projected +/// schema. +pub async fn fetch_rows( + opened: &OpenedExternalIndex, + row_keys: &[ParquetRowKey], + projection: &[&str], +) -> Result<RecordBatch> { + if projection.is_empty() { + return Err(Error::invalid_input( + "fetch_rows: projection must contain at least one column", + )); + } + + // Group input by file_path while remembering each input's position so we + // can reorder later. 
+ let mut by_file: HashMap<String, Vec<(usize, u64)>> = HashMap::new(); + for (input_pos, key) in row_keys.iter().enumerate() { + by_file + .entry(key.file_path.clone()) + .or_default() + .push((input_pos, key.row_index)); + } + + // Validate every file appears in the manifest. We don't strictly need it + // for correctness here (Lance will just open the parquet path) but it + // catches typos early and matches the contract that fetched files belong + // to the registered set. + for path in by_file.keys() { + if opened.manifest.file_id(path).is_none() { + return Err(Error::invalid_input(format!( + "fetch_rows: file '{path}' not registered with this index" + ))); + } + } + + // Empty input: derive the projected schema from the first registered file + // so the empty batch still carries it. NOTE(review): `files[0]` panics on an empty manifest. + if row_keys.is_empty() { + let any_path = &opened.manifest.files[0].file_path; + let schema = projected_schema_from_file(any_path, projection)?; + return Ok(RecordBatch::new_empty(schema)); + } + + // Per-file read. Inputs within a file go to one parquet read regardless of + // duplicates / order. + let mut per_file_results: Vec<(Vec<usize>, RecordBatch)> = Vec::with_capacity(by_file.len()); + let mut shared_schema: Option<SchemaRef> = None; + for (file_path, hits) in by_file { + let row_indices: Vec<u64> = hits.iter().map(|(_, r)| *r).collect(); + let batch = read_rows_from_file(&file_path, projection, &row_indices)?; + if shared_schema.is_none() { + shared_schema = Some(batch.schema()); + } + per_file_results.push((hits.iter().map(|(p, _)| *p).collect(), batch)); + } + let schema = shared_schema.expect("non-empty input means at least one batch"); + + // Reorder: slice one row per input position out of the per-file batches, + // then concat the slices column by column. For typical sizes (top-K * Q) + // this is fast. 
+ let total_rows = row_keys.len(); + let mut column_collectors: Vec<Vec<ArrayRef>> = + vec![Vec::with_capacity(total_rows); schema.fields().len()]; + + // Build a position-aware view of each per-file batch: + // position_in_input → (file_batch_index, row_in_batch) + let mut by_input_position: Vec<Option<(usize, usize)>> = vec![None; total_rows]; + for (file_batch_index, (input_positions, _batch)) in per_file_results.iter().enumerate() { + // input_positions[i] is the caller-input position of the i-th key for this file. + // The batch holds that file's rows deduped and in ascending row_index order + // (see read_rows_from_file), so duplicate keys share a single batch row. + // read_rows_from_file does not hand back its dedup mapping, so rebuild the + // same sorted, deduped row_index list here and binary-search it to find + // each key's row within the batch. The sort/dedup steps must stay identical + // to the ones in read_rows_from_file. + let row_indices_in_file: Vec<u64> = input_positions + .iter() + .map(|&p| row_keys[p].row_index) + .collect(); + let mut sorted = row_indices_in_file.clone(); + sorted.sort_unstable(); + sorted.dedup(); + for (input_pos, &row_index) in input_positions.iter().zip(row_indices_in_file.iter()) { + let row_in_batch = sorted + .binary_search(&row_index) + .expect("row missing from per-file batch index"); + by_input_position[*input_pos] = Some((file_batch_index, row_in_batch)); + } + } + + // Slice each input row's columns into the collectors. + for input_pos in 0..total_rows { + let (batch_idx, row_in_batch) = by_input_position[input_pos].ok_or_else(|| { + Error::index(format!("fetch_rows: input position {input_pos} unmapped")) + })?; + let batch = &per_file_results[batch_idx].1; + for (col_idx, col) in batch.columns().iter().enumerate() { + column_collectors[col_idx].push(col.slice(row_in_batch, 1)); + } + } + + // Concat each column's collected slices into one array. 
+ let mut final_columns: Vec<ArrayRef> = Vec::with_capacity(column_collectors.len()); + for col_chunks in column_collectors { + let refs: Vec<&dyn Array> = col_chunks.iter().map(|a| a.as_ref()).collect(); + let concatenated = arrow::compute::concat(&refs) + .map_err(|e| Error::index(format!("fetch_rows: column concat failed: {e}")))?; + final_columns.push(concatenated); + } + let out = RecordBatch::try_new(schema, final_columns) + .map_err(|e| Error::index(format!("fetch_rows: failed to build output batch: {e}")))?; + Ok(out) +} + +/// Read the requested `row_indices` of `projection` columns from one parquet +/// file. Returns a single `RecordBatch` with rows in **deduped sorted order**. +/// Caller is responsible for reordering to its input order. +fn read_rows_from_file( + path: &str, + projection: &[&str], + row_indices: &[u64], +) -> Result<RecordBatch> { + let file = File::open(path) + .map_err(|e| Error::invalid_input(format!("fetch_rows: failed to open {path}: {e}")))?; + let opts = ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required); + let builder = + ParquetRecordBatchReaderBuilder::try_new_with_options(file, opts).map_err(|e| { + Error::invalid_input(format!("fetch_rows: failed parquet meta {path}: {e}")) + })?; + let total_rows: u64 = builder.metadata().file_metadata().num_rows() as u64; + + let mask = ProjectionMask::columns(builder.parquet_schema(), projection.iter().copied()); + + let mut sorted: Vec<u64> = row_indices.to_vec(); + sorted.sort_unstable(); + sorted.dedup(); + if let Some(&last) = sorted.last() { + if last >= total_rows { + return Err(Error::invalid_input(format!( + "fetch_rows: row_index {last} out of range for {path} ({total_rows} rows)" + ))); + } + } + let ranges: Vec<Range<usize>> = sorted + .iter() + .map(|&r| (r as usize)..(r as usize + 1)) + .collect(); + let selection = RowSelection::from_consecutive_ranges(ranges.into_iter(), total_rows as usize); + + let reader = builder + .with_projection(mask) + .with_row_selection(selection) + .build() + 
.map_err(|e| Error::invalid_input(format!("fetch_rows: build reader {path}: {e}")))?; + + let batches: Vec<RecordBatch> = reader + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| Error::invalid_input(format!("fetch_rows: read {path}: {e}")))?; + if batches.is_empty() { + // Build an empty batch with the projected schema. + let schema = projected_schema_from_file(path, projection)?; + return Ok(RecordBatch::new_empty(schema)); + } + let schema = batches[0].schema(); + concat_batches(&schema, &batches) + .map_err(|e| Error::index(format!("fetch_rows: concat batches from {path}: {e}"))) +} + +fn projected_schema_from_file(path: &str, projection: &[&str]) -> Result<SchemaRef> { + let file = File::open(path) + .map_err(|e| Error::invalid_input(format!("fetch_rows: failed to open {path}: {e}")))?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| { + Error::invalid_input(format!("fetch_rows: failed parquet meta {path}: {e}")) + })?; + let arrow_schema = builder.schema(); + let fields: Vec<Field> = projection + .iter() + .map(|name| { + arrow_schema.field_with_name(name).cloned().map_err(|e| { + Error::invalid_input(format!( + "fetch_rows: projected column '{name}' missing in {path}: {e}" + )) + }) + }) + .collect::<Result<_>>()?; + Ok(Arc::new(Schema::new(fields))) +} diff --git a/rust/lance/src/index/vector/external/manifest.rs b/rust/lance/src/index/vector/external/manifest.rs new file mode 100644 index 00000000000..7f6b40831c4 --- /dev/null +++ b/rust/lance/src/index/vector/external/manifest.rs @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! External index manifest sidecar. +//! +//! `manifest.json` lives next to `index.idx` and records: +//! +//! - The list of parquet files the index covers (encoded `file_id` is implicit +//! in array position) +//! - Build params (metric, num_partitions, num_sub_vectors, etc.) — informational +//! +//! Sidecar JSON rather than a protobuf field on the index because (a) it lets the +//! 
file list evolve independently of Lance's format spec, and (b) it's easy to +//! inspect for debugging. + +use lance_core::{Error, Result}; +use lance_io::object_store::ObjectStore; +use object_store::path::Path; +use serde::{Deserialize, Serialize}; + +use super::params::ExternalIvfPqIndexParams; +use super::types::ParquetFileSpec; + +pub const MANIFEST_FILE_NAME: &str = "manifest.json"; + +/// Persistent manifest written alongside `index.idx`. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExternalIndexManifest { + /// Schema version of this manifest format. Bump on incompatible changes. + pub manifest_version: u32, + /// Vector column the index was built over. + pub vector_column: String, + /// Parquet files registered with this index. The encoded rid is + /// `(position_in_this_list_u32 << 32) | row_index_u32`. + pub files: Vec, + /// Build params, informational. The runtime metric type comes from the index + /// file's protobuf header; this is a paper trail. + pub params: ManifestParams, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ManifestFileEntry { + pub file_path: String, + pub num_rows: u64, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ManifestParams { + pub num_partitions: usize, + pub num_sub_vectors: usize, + pub num_bits_per_sub_vector: usize, + pub metric: String, // L2 / Cosine / Dot — the textual form from MetricType::to_string + pub max_iters: usize, + pub sample_rate: usize, + pub seed: u64, +} + +impl ExternalIndexManifest { + pub fn from_build( + vector_column: &str, + files: &[ParquetFileSpec], + params: &ExternalIvfPqIndexParams, + ) -> Self { + Self { + manifest_version: 1, + vector_column: vector_column.to_string(), + files: files + .iter() + .map(|s| ManifestFileEntry { + file_path: s.file_path.clone(), + num_rows: s.num_rows, + }) + .collect(), + params: ManifestParams { + num_partitions: params.num_partitions, + num_sub_vectors: params.num_sub_vectors, + num_bits_per_sub_vector: 
params.num_bits_per_sub_vector, + metric: format!("{}", params.metric), + max_iters: params.max_iters, + sample_rate: params.sample_rate, + seed: params.seed, + }, + } + } + + /// Resolve a manifest entry's `file_id`. Currently O(n) which is fine for + /// typical file counts (< 10k). + pub fn file_id(&self, file_path: &str) -> Option { + self.files + .iter() + .position(|e| e.file_path == file_path) + .map(|i| i as u32) + } + + pub fn file_path(&self, file_id: u32) -> Option<&str> { + self.files + .get(file_id as usize) + .map(|e| e.file_path.as_str()) + } +} + +/// Write the manifest to `dir/manifest.json`. +pub async fn write_manifest( + object_store: &ObjectStore, + dir: &Path, + manifest: &ExternalIndexManifest, +) -> Result<()> { + let path = dir.clone().join(MANIFEST_FILE_NAME); + let json = serde_json::to_vec_pretty(manifest) + .map_err(|e| Error::io(format!("manifest serialize failed: {e}")))?; + object_store + .put(&path, &json) + .await + .map_err(|e| Error::io(format!("manifest write failed at {path}: {e}")))?; + Ok(()) +} + +/// Read `dir/manifest.json`. 
+pub async fn read_manifest( + object_store: &ObjectStore, + dir: &Path, +) -> Result { + let path = dir.clone().join(MANIFEST_FILE_NAME); + let bytes = object_store + .read_one_all(&path) + .await + .map_err(|e| Error::io(format!("manifest read failed at {path}: {e}")))?; + let manifest: ExternalIndexManifest = serde_json::from_slice(&bytes) + .map_err(|e| Error::io(format!("manifest parse failed at {path}: {e}")))?; + Ok(manifest) +} + +#[cfg(test)] +mod tests { + use super::*; + use lance_linalg::distance::MetricType; + use std::sync::Arc; + use tempfile::TempDir; + + #[test] + fn from_build_round_trips_via_json() { + let files = vec![ + ParquetFileSpec::with_metadata( + "/tmp/a.parquet", + 100, + Arc::new(arrow_schema::Schema::empty()), + ), + ParquetFileSpec::with_metadata( + "/tmp/b.parquet", + 200, + Arc::new(arrow_schema::Schema::empty()), + ), + ]; + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(64) + .num_sub_vectors(8) + .metric(MetricType::Cosine) + .build(); + + let manifest = ExternalIndexManifest::from_build("emb", &files, ¶ms); + let json = serde_json::to_vec_pretty(&manifest).unwrap(); + let parsed: ExternalIndexManifest = serde_json::from_slice(&json).unwrap(); + assert_eq!(parsed.manifest_version, 1); + assert_eq!(parsed.vector_column, "emb"); + assert_eq!(parsed.files.len(), 2); + assert_eq!(parsed.file_id("/tmp/b.parquet"), Some(1)); + assert_eq!(parsed.file_path(0), Some("/tmp/a.parquet")); + assert_eq!(parsed.file_path(99), None); + assert_eq!(parsed.params.num_partitions, 64); + assert_eq!(parsed.params.metric, "cosine"); + } + + #[tokio::test] + async fn write_then_read() { + let tmp = TempDir::new().unwrap(); + let uri = tmp.path().to_str().unwrap(); + let (object_store, root) = ObjectStore::from_uri(uri).await.unwrap(); + + let files = vec![ParquetFileSpec::with_metadata( + "/tmp/x.parquet", + 42, + Arc::new(arrow_schema::Schema::empty()), + )]; + let params = ExternalIvfPqIndexParams::builder().build(); + let 
manifest = ExternalIndexManifest::from_build("vec", &files, ¶ms); + write_manifest(&object_store, &root, &manifest) + .await + .unwrap(); + + let parsed = read_manifest(&object_store, &root).await.unwrap(); + assert_eq!(parsed.vector_column, "vec"); + assert_eq!(parsed.files.len(), 1); + assert_eq!(parsed.files[0].num_rows, 42); + } +} diff --git a/rust/lance/src/index/vector/external/mod.rs b/rust/lance/src/index/vector/external/mod.rs new file mode 100644 index 00000000000..eb77a9e6961 --- /dev/null +++ b/rust/lance/src/index/vector/external/mod.rs @@ -0,0 +1,497 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! External vector index — IVF-PQ over caller-supplied parquet files. +//! +//! Lance builds and queries a vector index over parquet data without copying it into +//! a Lance dataset. The caller registers a list of parquet files; Lance encodes each +//! row's identity as `(file_id_u32 << 32) | row_index_u32` internally, and surfaces +//! search results as `(file_path, row_index, distance)` so callers never have to +//! decode the rid themselves. +//! +//! Read [`ExternalIvfPqIndex`] for the entry point and [`crate::index::vector`] for +//! the equivalent dataset-backed index. +//! +//! # Why +//! +//! For engines (Spark, Trino, Daft, ...) that already store data in parquet/Delta/ +//! Iceberg, building a Lance dataset just to get a Lance vector index doubles the +//! storage. The external index lets the source format stay the source of truth and +//! confines Lance to the index file alone. +//! +//! # Status +//! +//! Phase 1 of the "External Vector Index" RFC. Surface area: +//! +//! - [`ExternalIvfPqIndex`] — handle: build / open / search / fetch_rows +//! - [`ParquetFileSpec`] — describes one parquet file in the registry +//! - [`SearchResult`] — `{file_path, row_index, distance}` returned by `search` +//! - [`ParquetRowKey`] — `(file_path, row_index)` accepted by `fetch_rows` +//! 
- [`RowFilter`] — extensibility hook for Delta deletion vectors / Iceberg +//! position deletes / ad-hoc skip predicates +//! - [`ExternalIvfPqIndexParams`] — kmeans / PQ / metric configuration + +mod build; +mod fetch; +pub(crate) mod manifest; +pub(crate) mod open; +pub mod params; +pub(crate) mod parquet_source; +mod search; +pub mod types; + +pub use params::ExternalIvfPqIndexParams; +pub use types::{ParquetFileSpec, ParquetRowKey, RowFilter, SearchResult}; + +use arrow_array::RecordBatch; +use lance_core::Result; + +/// IVF-PQ index over caller-registered parquet files. +/// +/// Lance owns the parquet reader (page-index-aware random access via +/// `PageIndexPolicy::Required`) and the refinement step. Callers see +/// `(file_path, row_index, distance)` results. +/// +/// # Example +/// +/// Sketch — full impl lands in subsequent phases. Today this only constructs. +/// +/// ```ignore +/// # use lance::index::vector::external::*; +/// # async fn run(files: Vec) -> lance_core::Result<()> { +/// let params = ExternalIvfPqIndexParams::builder() +/// .num_partitions(256) +/// .num_sub_vectors(16) +/// .build(); +/// ExternalIvfPqIndex::build(files, "vec", "/tmp/idx", params).await?; +/// +/// let idx = ExternalIvfPqIndex::open("/tmp/idx").await?; +/// let hits = idx.search(&[0.1; 128], 10, 16, 8, None).await?; +/// for hit in &hits { +/// println!("{} @ row {} = {}", hit.file_path, hit.row_index, hit.distance); +/// } +/// # Ok(()) } +/// ``` +pub struct ExternalIvfPqIndex { + /// All deserialized index state (manifest + IVF model + PQ codebooks + + /// object_store + index_dir). + inner: open::OpenedExternalIndex, +} + +impl ExternalIvfPqIndex { + /// Build an external IVF-PQ index over the given parquet files. + /// + /// Reads sample vectors for kmeans + PQ training, encodes residuals, writes a + /// single index file at `output_uri`. Synchronous: returns when the file is on + /// disk and durable. 
+ /// + /// `vector_column` must be a non-null `FixedSizeList` column in every + /// file's schema; this is validated against each parquet footer. + /// + /// `file_id` is implicit in `files`'s position. Reordering invalidates the + /// index. + pub async fn build( + files: Vec, + vector_column: &str, + output_uri: &str, + params: ExternalIvfPqIndexParams, + ) -> Result { + build::build_index(files, vector_column, output_uri, params).await + } + + /// Open an external IVF-PQ index by URI. + /// + /// Cheap: reads the manifest + index header. Per-file parquet readers are + /// constructed lazily on first `search` / `fetch_rows`. + pub async fn open(uri: &str) -> Result { + let inner = open::open_index(uri).await?; + Ok(Self { inner }) + } + + /// Number of registered parquet files. + pub fn num_files(&self) -> usize { + self.inner.manifest.files.len() + } + + /// Look up the file path for a given `file_id` (the high 32 bits of the rid). + /// `None` if the id is out of range. + pub fn file_path(&self, file_id: u32) -> Option<&str> { + self.inner.manifest.file_path(file_id) + } + + /// Vector column the index was built over. + pub fn vector_column(&self) -> &str { + &self.inner.manifest.vector_column + } + + /// Number of IVF partitions in the index. + pub fn num_partitions(&self) -> usize { + self.inner.ivf.num_partitions() + } + + /// Run an approximate nearest-neighbor query over the index. + /// + /// Probes `nprobes` IVF partitions, fetches `k * refine_factor` PQ-approx + /// candidates, refines them by reading their actual vectors from the source + /// parquet via page-index-aware random access, and returns the top-`k` after + /// exact distance recompute. + /// + /// `filter` (if `Some`) is consulted during refinement; rows it rejects are + /// dropped before re-ranking. See [`RowFilter`] for typical use cases like + /// Delta deletion vectors. 
+ pub async fn search( + &self, + query: &[f32], + k: usize, + nprobes: usize, + refine_factor: usize, + filter: Option<&dyn RowFilter>, + ) -> Result> { + search::search(&self.inner, query, k, nprobes, refine_factor, filter).await + } + + /// Random-access fetch by `(file_path, row_index)` keys. + /// + /// Lance batches by file internally and issues one page-index-aware parquet + /// read per file, then reassembles the result in caller-input order. The + /// result has one row per input key. + /// + /// `projection` may include columns that are not part of the index — they're + /// read from the parquet schema directly. The killer feature: post-topK + /// materialization fetches only the projection columns for the surviving rows. + pub async fn fetch_rows( + &self, + row_keys: &[ParquetRowKey], + projection: &[&str], + ) -> Result { + fetch::fetch_rows(&self.inner, row_keys, projection).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::path::PathBuf; + use std::sync::Arc; + + use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, RecordBatch}; + use arrow_schema::{Field, Schema}; + use lance_arrow::FixedSizeListArrayExt; + use lance_linalg::distance::MetricType; + use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; + use tempfile::TempDir; + + /// Smoke test: open() on a missing path errors out cleanly. Real round-trip + /// is exercised by `build_then_open_round_trip` below. 
+ #[tokio::test] + async fn open_missing_errors() { + let result = ExternalIvfPqIndex::open("/tmp/this-path-does-not-exist-9f3a8").await; + assert!(result.is_err(), "open() on missing path should error"); + } + + fn write_random_parquet(path: &PathBuf, num_rows: usize, dim: usize, seed: u64) { + use rand::{Rng, SeedableRng, rngs::StdRng}; + let mut rng = StdRng::seed_from_u64(seed); + let values: Vec = (0..num_rows * dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect(); + let flat = Float32Array::from(values); + let fsl = FixedSizeListArray::try_new_from_values(flat, dim as i32).unwrap(); + let schema = Arc::new(Schema::new(vec![Field::new( + "vec", + fsl.data_type().clone(), + false, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl) as ArrayRef]).unwrap(); + let file = std::fs::File::create(path).unwrap(); + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + fn brute_force_topk( + vectors: &FixedSizeListArray, + query: &[f32], + k: usize, + ) -> Vec<(usize, f32)> { + let values = vectors + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let dim = vectors.value_length() as usize; + let n = vectors.len(); + let mut dists: Vec<(usize, f32)> = (0..n) + .map(|i| { + let mut s = 0.0f32; + for d in 0..dim { + let v = values.value(i * dim + d); + let q = query[d]; + let diff = v - q; + s += diff * diff; + } + (i, s) + }) + .collect(); + dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + dists.into_iter().take(k).collect() + } + + /// End-to-end: build → open → search → assert recall against brute force + /// ground truth. Toy scale; meant to validate the pipeline runs and rids round-trip. 
+ #[tokio::test(flavor = "multi_thread")] + async fn build_open_search_recall_above_threshold() { + use rand::{Rng, SeedableRng, rngs::StdRng}; + + const NUM_VECTORS: usize = 1024; + const DIM: usize = 8; + const K: usize = 10; + const NUM_PARTITIONS: usize = 4; + const NUM_SUB_VECTORS: usize = 2; + + let tmp_data = TempDir::new().unwrap(); + let p = tmp_data.path().join("vec.parquet"); + write_random_parquet(&p, NUM_VECTORS, DIM, 42); + + // Read back the vectors so we can compute ground truth in-memory. + let all_vectors = { + let file = std::fs::File::open(&p).unwrap(); + let builder = + parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file) + .unwrap(); + let reader = builder.build().unwrap(); + let mut batches: Vec = Vec::new(); + for r in reader { + batches.push(r.unwrap()); + } + let arrays: Vec<&dyn Array> = batches.iter().map(|b| b.column(0).as_ref()).collect(); + let cat = arrow::compute::concat(&arrays).unwrap(); + cat.as_any() + .downcast_ref::() + .unwrap() + .clone() + }; + + let tmp_out = TempDir::new().unwrap(); + let output_uri = tmp_out.path().to_str().unwrap(); + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(NUM_PARTITIONS) + .num_sub_vectors(NUM_SUB_VECTORS) + .num_bits_per_sub_vector(8) + .metric(MetricType::L2) + .max_iters(10) + .sample_rate(64) + .build(); + let files = vec![ParquetFileSpec::of(p.to_str().unwrap())]; + let uuid = ExternalIvfPqIndex::build(files, "vec", output_uri, params) + .await + .expect("build_index"); + let idx = ExternalIvfPqIndex::open(tmp_out.path().join(uuid.to_string()).to_str().unwrap()) + .await + .expect("open"); + + // Run a few queries; check that the top-1 from search matches the + // brute-force top-1 for at least most queries (recall@1 ≥ 0.5 at this + // toy scale is the bar — PQ at dim=8 / 2 sub-vectors loses fidelity). 
+ let mut rng = StdRng::seed_from_u64(99); + let mut hits = 0; + let total_queries = 16; + for _ in 0..total_queries { + let query: Vec = (0..DIM).map(|_| rng.random_range(-1.0f32..1.0)).collect(); + let truth = brute_force_topk(&all_vectors, &query, K); + let truth_set: std::collections::HashSet = + truth.iter().map(|(i, _)| *i as u64).collect(); + + let results = idx + .search( + &query, K, /* nprobes = */ 4, /* refine_factor = */ 4, None, + ) + .await + .expect("search"); + assert!(!results.is_empty(), "search returned empty"); + let result_set: std::collections::HashSet = + results.iter().map(|r| r.row_index).collect(); + let intersection = result_set.intersection(&truth_set).count(); + if intersection >= K / 2 { + hits += 1; + } + } + assert!( + hits >= total_queries / 2, + "recall too low: {hits}/{total_queries} queries had ≥ K/2 correct" + ); + } + + /// fetch_rows() returns parquet projection cols for arbitrary (file_path, + /// row_index) keys, in caller-input order, including duplicates and + /// non-vector columns. + #[tokio::test(flavor = "multi_thread")] + async fn fetch_rows_returns_projection_in_input_order() { + use arrow_array::{Int64Array, StringArray, UInt64Array}; + + let tmp_data = TempDir::new().unwrap(); + let p = tmp_data.path().join("payload.parquet"); + + // Build a parquet file with a vector column + an integer payload column + // + a string payload column. fetch_rows should be able to project any + // subset. PQ training needs ≥ 256 rows so size accordingly. 
+ let n = 320usize; + let dim = 4usize; + let values: Vec = (0..n * dim).map(|i| (i as f32) * 0.01).collect(); + let flat = Float32Array::from(values); + let fsl = FixedSizeListArray::try_new_from_values(flat, dim as i32).unwrap(); + let ids = Int64Array::from((0..n as i64).map(|i| i * 10).collect::>()); + let names = StringArray::from((0..n).map(|i| format!("name-{i}")).collect::>()); + + let schema = Arc::new(Schema::new(vec![ + Field::new("vec", fsl.data_type().clone(), false), + Field::new("id", arrow_schema::DataType::Int64, false), + Field::new("name", arrow_schema::DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(fsl) as ArrayRef, + Arc::new(ids) as ArrayRef, + Arc::new(names) as ArrayRef, + ], + ) + .unwrap(); + let file = std::fs::File::create(&p).unwrap(); + let mut writer = parquet::arrow::ArrowWriter::try_new( + file, + schema, + Some(parquet::file::properties::WriterProperties::builder().build()), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + // Build + open + let tmp_out = TempDir::new().unwrap(); + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(2) + .num_sub_vectors(2) + .num_bits_per_sub_vector(8) + .metric(MetricType::L2) + .max_iters(5) + // sample_rate * num_partitions must yield ≥ 256 for PQ training. + .sample_rate(150) + .build(); + let uuid = ExternalIvfPqIndex::build( + vec![ParquetFileSpec::of(p.to_str().unwrap())], + "vec", + tmp_out.path().to_str().unwrap(), + params, + ) + .await + .unwrap(); + let idx = ExternalIvfPqIndex::open(tmp_out.path().join(uuid.to_string()).to_str().unwrap()) + .await + .unwrap(); + + // Fetch rows in non-sorted, with-duplicate order. 
+ let path = p.to_str().unwrap().to_string(); + let keys = vec![ + ParquetRowKey::of(&path, 5), + ParquetRowKey::of(&path, 0), + ParquetRowKey::of(&path, 5), // duplicate + ParquetRowKey::of(&path, 30), + ]; + let result = idx + .fetch_rows(&keys, &["id", "name"]) + .await + .expect("fetch_rows"); + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 2); + + let id_col = result + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + // expected ids: [50, 0, 50, 300] + assert_eq!( + (0..id_col.len()) + .map(|i| id_col.value(i)) + .collect::>(), + vec![50, 0, 50, 300] + ); + + let name_col = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + (0..name_col.len()) + .map(|i| name_col.value(i).to_string()) + .collect::>(), + vec!["name-5", "name-0", "name-5", "name-30"] + ); + + // Empty input: returns empty batch with projected schema. + let empty = idx.fetch_rows(&[], &["id"]).await.unwrap(); + assert_eq!(empty.num_rows(), 0); + assert_eq!(empty.schema().fields().len(), 1); + assert_eq!(empty.schema().field(0).name(), "id"); + + // Unknown file: validation error before any I/O. + let bad = idx + .fetch_rows(&[ParquetRowKey::of("/nope/no.parquet", 0)], &["id"]) + .await; + assert!(bad.is_err()); + + // Silence the unused variable warnings. + let _ = UInt64Array::from(vec![0u64]); + } + + /// build() writes a non-empty index file + manifest. open() round-trips them. 
+ #[tokio::test(flavor = "multi_thread")] + async fn build_then_open_round_trip() { + let tmp_data = TempDir::new().unwrap(); + let p = tmp_data.path().join("vec.parquet"); + write_random_parquet(&p, 256, 4, 42); + + let tmp_out = TempDir::new().unwrap(); + let output_uri = tmp_out.path().to_str().unwrap(); + + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(4) + .num_sub_vectors(2) + .num_bits_per_sub_vector(8) + .metric(MetricType::L2) + .max_iters(5) + .sample_rate(64) + .build(); + + let files = vec![ParquetFileSpec::of(p.to_str().unwrap())]; + let uuid = ExternalIvfPqIndex::build(files, "vec", output_uri, params) + .await + .expect("build_index"); + + let idx_dir = tmp_out.path().join(uuid.to_string()); + let idx_file = idx_dir.join("index.idx"); + let manifest_file = idx_dir.join("manifest.json"); + for f in [&idx_file, &manifest_file] { + let meta = std::fs::metadata(f) + .unwrap_or_else(|e| panic!("expected file at {}: {e}", f.display())); + assert!(meta.len() > 0, "{} is empty", f.display()); + } + + let opened_uri = idx_dir.to_str().unwrap(); + let idx = ExternalIvfPqIndex::open(opened_uri) + .await + .expect("open() must succeed after build()"); + assert_eq!(idx.num_files(), 1); + assert_eq!(idx.num_partitions(), 4); + assert_eq!(idx.vector_column(), "vec"); + assert_eq!(idx.file_path(0), Some(p.to_str().unwrap())); + assert_eq!(idx.file_path(1), None); + } +} diff --git a/rust/lance/src/index/vector/external/open.rs b/rust/lance/src/index/vector/external/open.rs new file mode 100644 index 00000000000..683e58cf4ea --- /dev/null +++ b/rust/lance/src/index/vector/external/open.rs @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! `open()` implementation for [`super::ExternalIvfPqIndex`]. +//! +//! Reads: +//! +//! - `//manifest.json` — parquet file list + build params +//! - `//index.idx` — IVF model + PQ codebooks (Lance protobuf format) +//! +//! 
`index_dir` is the URI passed to `open()`. The single-uuid layout means the +//! caller passes the directory that contains `manifest.json` directly; for now we +//! assume the `/` segment is included in the URI. Production callers will +//! get the URI back from `build()` and just round-trip it. + +use std::sync::Arc; + +use lance_core::{Error, Result}; +use lance_index::pb; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::pq::ProductQuantizer; +use lance_io::object_store::ObjectStore; +use lance_io::traits::Reader; +use lance_io::utils::{read_message, read_metadata_offset}; +use lance_linalg::distance::DistanceType; +use object_store::path::Path; + +use super::manifest::{ExternalIndexManifest, read_manifest}; + +/// Constants. `INDEX_FILE_NAME` mirrors what the build path writes. +pub const INDEX_FILE_NAME: &str = "index.idx"; + +/// All deserialized state of an opened external index, *except* per-partition +/// posting-list data (which the search path lazy-loads). Owns its own +/// [`ObjectStore`] handle plus the index file path so search and fetch_rows can +/// re-issue I/O without re-resolving URIs. +/// +/// `object_store` and `index_dir` aren't read by the current search/fetch +/// paths (which open source parquet files directly via `std::fs`); they're +/// kept on the handle for the JNI / remote-store path that lands with #33. +#[allow(dead_code)] +pub struct OpenedExternalIndex { + pub manifest: ExternalIndexManifest, + pub ivf: IvfModel, + pub pq: ProductQuantizer, + pub metric: DistanceType, + pub object_store: Arc, + pub index_dir: Path, + pub index_file_reader: Arc, +} + +/// Resolve `uri` into `(ObjectStore, Path)`, then load both the manifest and the +/// index file's metadata. 
+pub async fn open_index(uri: &str) -> Result { + let (object_store, index_dir) = ObjectStore::from_uri(uri).await?; + + let manifest = read_manifest(&object_store, &index_dir).await?; + + let index_file_path = index_dir.clone().join(INDEX_FILE_NAME); + let reader: Arc = Arc::from(object_store.open(&index_file_path).await?); + + // Tail layout (see write_magics): u64 offset, i16 major, i16 minor, 8-byte magic. + let file_size = reader.size().await?; + if file_size < 20 { + return Err(Error::io(format!( + "external index file at {index_file_path} is too small to contain footer" + ))); + } + let block_size = reader.block_size().min(file_size); + let tail_start = file_size.saturating_sub(block_size.max(20)); + let tail = reader.get_range(tail_start..file_size).await?; + let metadata_offset = read_metadata_offset(&tail)?; + + let pb_index: pb::Index = read_message(reader.as_ref(), metadata_offset).await?; + let (ivf, pq, metric) = decode_pb_index(&pb_index)?; + + Ok(OpenedExternalIndex { + manifest, + ivf, + pq, + metric, + object_store, + index_dir, + index_file_reader: reader, + }) +} + +fn decode_pb_index(pb_index: &pb::Index) -> Result<(IvfModel, ProductQuantizer, DistanceType)> { + use lance_index::pb::vector_index_stage::Stage; + let vec_idx = match pb_index.implementation.as_ref() { + Some(lance_index::pb::index::Implementation::VectorIndex(v)) => v, + _ => { + return Err(Error::index( + "external index file is not a VectorIndex".to_string(), + )); + } + }; + let metric: DistanceType = + lance_index::pb::VectorMetricType::try_from(vec_idx.metric_type)?.into(); + + let mut ivf: Option = None; + let mut pq: Option = None; + for stage in &vec_idx.stages { + match stage.stage.as_ref() { + Some(Stage::Ivf(ivf_pb)) => { + ivf = Some(IvfModel::try_from(ivf_pb.clone())?); + } + Some(Stage::Pq(pq_pb)) => { + pq = Some(ProductQuantizer::from_proto(pq_pb, metric)?); + } + Some(Stage::Transform(_)) => { + // Transform stages are recorded for downstream pipelines but the + 
// external builder writes none today; ignore. + } + Some(other) => { + return Err(Error::index(format!( + "external index file has unsupported stage: {other:?}" + ))); + } + None => {} + } + } + + let ivf = ivf.ok_or_else(|| Error::index("external index missing IVF stage"))?; + let pq = pq.ok_or_else(|| Error::index("external index missing PQ stage"))?; + Ok((ivf, pq, metric)) +} diff --git a/rust/lance/src/index/vector/external/params.rs b/rust/lance/src/index/vector/external/params.rs new file mode 100644 index 00000000000..c01c220175a --- /dev/null +++ b/rust/lance/src/index/vector/external/params.rs @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Build parameters for [`super::ExternalIvfPqIndex`]. + +use lance_linalg::distance::MetricType; + +/// Configuration for [`super::ExternalIvfPqIndex::build`]. +/// +/// Use [`Self::builder`]; defaults match Lance's normal IVF-PQ defaults. +#[derive(Clone, Debug)] +pub struct ExternalIvfPqIndexParams { + /// Number of IVF partitions (kmeans `k`). + pub num_partitions: usize, + /// Number of PQ sub-vectors. Vector dimension must be divisible by this. + pub num_sub_vectors: usize, + /// Bits per PQ code. 8 (256 codes) is the universal default. + pub num_bits_per_sub_vector: usize, + /// Distance metric. + pub metric: MetricType, + /// kmeans iterations during IVF training. + pub max_iters: usize, + /// Training sample size = `num_partitions * sample_rate`. + pub sample_rate: usize, + /// RNG seed (kmeans + PQ training). 
+ pub seed: u64, +} + +impl ExternalIvfPqIndexParams { + pub fn builder() -> ExternalIvfPqIndexParamsBuilder { + ExternalIvfPqIndexParamsBuilder::default() + } +} + +#[derive(Clone, Debug)] +pub struct ExternalIvfPqIndexParamsBuilder { + num_partitions: usize, + num_sub_vectors: usize, + num_bits_per_sub_vector: usize, + metric: MetricType, + max_iters: usize, + sample_rate: usize, + seed: u64, +} + +impl Default for ExternalIvfPqIndexParamsBuilder { + fn default() -> Self { + // Defaults match what Lance's dataset-backed IVF-PQ uses today. + Self { + num_partitions: 256, + num_sub_vectors: 16, + num_bits_per_sub_vector: 8, + metric: MetricType::L2, + max_iters: 50, + sample_rate: 256, + seed: 0xCAFE_BABE_DEAD_BEEF, + } + } +} + +impl ExternalIvfPqIndexParamsBuilder { + pub fn num_partitions(mut self, n: usize) -> Self { + self.num_partitions = n; + self + } + pub fn num_sub_vectors(mut self, n: usize) -> Self { + self.num_sub_vectors = n; + self + } + pub fn num_bits_per_sub_vector(mut self, n: usize) -> Self { + self.num_bits_per_sub_vector = n; + self + } + pub fn metric(mut self, m: MetricType) -> Self { + self.metric = m; + self + } + pub fn max_iters(mut self, n: usize) -> Self { + self.max_iters = n; + self + } + pub fn sample_rate(mut self, n: usize) -> Self { + self.sample_rate = n; + self + } + pub fn seed(mut self, s: u64) -> Self { + self.seed = s; + self + } + + pub fn build(self) -> ExternalIvfPqIndexParams { + ExternalIvfPqIndexParams { + num_partitions: self.num_partitions, + num_sub_vectors: self.num_sub_vectors, + num_bits_per_sub_vector: self.num_bits_per_sub_vector, + metric: self.metric, + max_iters: self.max_iters, + sample_rate: self.sample_rate, + seed: self.seed, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builder_defaults() { + let p = ExternalIvfPqIndexParams::builder().build(); + assert_eq!(p.num_partitions, 256); + assert_eq!(p.num_sub_vectors, 16); + assert_eq!(p.num_bits_per_sub_vector, 8); + 
assert_eq!(p.metric, MetricType::L2); + } + + #[test] + fn builder_overrides() { + let p = ExternalIvfPqIndexParams::builder() + .num_partitions(32) + .num_sub_vectors(4) + .metric(MetricType::Cosine) + .max_iters(10) + .seed(42) + .build(); + assert_eq!(p.num_partitions, 32); + assert_eq!(p.num_sub_vectors, 4); + assert_eq!(p.metric, MetricType::Cosine); + assert_eq!(p.max_iters, 10); + assert_eq!(p.seed, 42); + } +} diff --git a/rust/lance/src/index/vector/external/parquet_source.rs b/rust/lance/src/index/vector/external/parquet_source.rs new file mode 100644 index 00000000000..783e9a1ccbc --- /dev/null +++ b/rust/lance/src/index/vector/external/parquet_source.rs @@ -0,0 +1,477 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Parquet vector source for the external IVF-PQ build path. +//! +//! Reads vectors from `Vec` and produces: +//! +//! - [`ParquetVectorSource::sample`] — a `FixedSizeListArray` of training vectors +//! for kmeans + PQ codebook learning +//! - [`ParquetVectorSource::iter_batches`] — a `RecordBatchStream` of `(vec, _rowid)` +//! batches feeding `IvfTransformer` + `shuffle_dataset`. `_rowid` is the encoded +//! `(file_id_u32 << 32) | row_index_u32` rid the RFC pins. + +use std::fs::File; +use std::sync::Arc; + +use arrow::compute::concat; +use arrow_array::{ + Array, ArrayRef, FixedSizeListArray, Float32Array, ListArray, RecordBatch, UInt64Array, +}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use lance_arrow::FixedSizeListArrayExt; +use lance_core::{Error, Result}; +use lance_io::stream::{RecordBatchStream, RecordBatchStreamAdapter}; +use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + +use super::types::ParquetFileSpec; + +/// The output column name for caller-controlled rids in batches yielded by +/// [`ParquetVectorSource::iter_batches`]. 
Matches the name `shuffle_dataset` and +/// `write_pq_partitions` already use internally for Lance's row identity. +pub const RID_COLUMN_NAME: &str = "_rowid"; + +/// Read vectors from a list of parquet files, attach the encoded rid, and feed +/// them into the IVF-PQ build pipeline. +pub struct ParquetVectorSource { + files: Vec, + vector_column: String, + /// Vector dimension, discovered when the first file's schema is read. Cached + /// so `dim()` and `iter_batches()` agree. + dim: usize, +} + +impl ParquetVectorSource { + /// Construct a source over `files` reading `vector_column`. Reads each file's + /// footer to validate the column type and infer the dimension. + pub fn try_new(files: Vec, vector_column: &str) -> Result { + if files.is_empty() { + return Err(Error::invalid_input( + "ExternalIvfPqIndex requires at least one parquet file", + )); + } + + let first_dim = read_vector_dim(&files[0].file_path, vector_column)?; + Ok(Self { + files, + vector_column: vector_column.to_string(), + dim: first_dim, + }) + } + + /// Vector dimension shared across all registered files. + #[allow(dead_code)] + pub fn dim(&self) -> usize { + self.dim + } + + /// Total row count, summed across files. Reads each footer. + pub fn num_rows(&self) -> Result { + let mut total = 0u64; + for spec in &self.files { + if spec.num_rows > 0 { + total += spec.num_rows; + } else { + total += read_num_rows(&spec.file_path)?; + } + } + Ok(total) + } + + /// Sample up to `n` vectors uniformly across the file list for training. + /// + /// Strategy: round-robin read from each file's start until `n` is hit. For PQ + /// training this is sufficient — kmeans + codebook quality is set by vector + /// distribution, not by random row selection. A future iteration can add + /// reservoir sampling if recall numbers say it's needed. 
+ pub async fn sample(&self, n: usize) -> Result { + let per_file = n.div_ceil(self.files.len()).max(1); + let mut accumulated: Vec = Vec::new(); + let mut total = 0usize; + + for spec in &self.files { + if total >= n { + break; + } + let want = (n - total).min(per_file); + let fsl = read_first_n_vectors(&spec.file_path, &self.vector_column, want)?; + total += fsl.len(); + accumulated.push(Arc::new(fsl)); + if total >= n { + break; + } + } + + let array_refs: Vec<&dyn Array> = accumulated.iter().map(|a| a.as_ref()).collect(); + let concatenated = concat(&array_refs).map_err(|e| { + Error::invalid_input(format!("failed to concatenate sample batches: {e}")) + })?; + let fsl = concatenated + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::invalid_input("concatenated sample is not a FixedSizeListArray"))? + .clone(); + Ok(fsl) + } + + /// Stream batches of `(vec, _rowid)` across all registered files. `_rowid` is + /// `(file_id_u32 << 32) | row_index_u32`. Feeds `IvfTransformer` + + /// `shuffle_dataset` during the build. + /// + /// Implementation note: parquet-rs uses blocking reads, so we materialize + /// per-file batches eagerly into a `Vec` and stream them. Memory cost is + /// bounded by the largest single parquet file's row groups since we don't + /// hold all files in memory at once — but we do hold one file's worth at a + /// time. A truly streaming variant (spawn_blocking ladder) lands later if + /// large-file memory becomes a concern. 
+ pub fn iter_batches(&self) -> Result { + let out_schema = Arc::new(Schema::new(vec![ + Field::new( + self.vector_column.clone(), + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + self.dim as i32, + ), + false, + ), + Field::new(RID_COLUMN_NAME, DataType::UInt64, false), + ])); + + let batches = collect_rid_annotated_batches( + self.files.clone(), + self.vector_column.clone(), + out_schema.clone(), + )?; + + let stream = + futures::stream::iter(batches.into_iter().map(|b| Ok::(b))); + Ok(RecordBatchStreamAdapter::new(out_schema, stream)) + } +} + +// ---- helpers ------------------------------------------------------------------------ + +fn parquet_open_err(path: &str, source: impl std::fmt::Display) -> Error { + Error::invalid_input(format!("failed to open parquet file {path}: {source}")) +} + +fn parquet_meta_err(path: &str, source: impl std::fmt::Display) -> Error { + Error::invalid_input(format!( + "failed to read parquet metadata for {path}: {source}" + )) +} + +fn parquet_reader_err(path: &str, source: impl std::fmt::Display) -> Error { + Error::invalid_input(format!( + "failed to build parquet reader for {path}: {source}" + )) +} + +fn parquet_batch_err(path: &str, source: impl std::fmt::Display) -> Error { + Error::invalid_input(format!("error reading parquet batch from {path}: {source}")) +} + +fn read_vector_dim(path: &str, column: &str) -> Result { + let file = File::open(path).map_err(|e| parquet_open_err(path, e))?; + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| parquet_meta_err(path, e))?; + let arrow_schema = builder.schema().clone(); + let field = arrow_schema.field_with_name(column).map_err(|e| { + Error::invalid_input(format!("vector column '{column}' not found in {path}: {e}")) + })?; + match field.data_type() { + DataType::FixedSizeList(_, n) => Ok(*n as usize), + DataType::List(_) | DataType::LargeList(_) => { + // Spark and most JVM parquet writers emit `List` even when every 
row + // has the same length — Arrow's fixed-size-list metadata round-trips poorly + // through parquet. Probe the first batch's first row to infer the dimension; + // the read path enforces every subsequent row to match. + let mask = ProjectionMask::columns(builder.parquet_schema(), [column]); + let reader = builder + .with_projection(mask) + .with_batch_size(1) + .build() + .map_err(|e| parquet_reader_err(path, e))?; + for batch_result in reader { + let batch = batch_result.map_err(|e| parquet_batch_err(path, e))?; + if batch.num_rows() == 0 { + continue; + } + let col = batch.column_by_name(column).ok_or_else(|| { + Error::invalid_input(format!("column '{column}' missing in {path}")) + })?; + if let Some(la) = col.as_any().downcast_ref::() { + let first_len = la.value_length(0); + return Ok(first_len as usize); + } + if let Some(la) = col.as_any().downcast_ref::() { + let first_len = la.value_length(0); + return Ok(first_len as usize); + } + break; + } + Err(Error::invalid_input(format!( + "vector column '{column}' in {path} is List but file contains no rows" + ))) + } + other => Err(Error::invalid_input(format!( + "vector column '{column}' in {path} must be FixedSizeList or List, got {other:?}" + ))), + } +} + +fn read_num_rows(path: &str) -> Result { + let file = File::open(path).map_err(|e| parquet_open_err(path, e))?; + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| parquet_meta_err(path, e))?; + Ok(builder.metadata().file_metadata().num_rows() as u64) +} + +/// Coerce a column array (FixedSizeList or List) to FixedSizeListArray, validating that +/// every row has length `expected_dim`. Spark and most JVM parquet writers emit List, so +/// this is the common-case path. 
+///
+/// A sliced `List` column is honoured: only the child values addressed by the
+/// slice's own offsets are copied, so the result has exactly `col.len()` rows.
+///
+/// # Errors
+///
+/// Returns `invalid_input` when the column is neither `FixedSizeList` nor
+/// `List`, when the fixed size or any row length differs from `expected_dim`,
+/// when a `List` row is null, or when the `List` child type is not `Float32`.
+pub(crate) fn coerce_to_fsl(col: &ArrayRef, expected_dim: usize) -> Result<FixedSizeListArray> {
+    if let Some(fsl) = col.as_any().downcast_ref::<FixedSizeListArray>() {
+        if fsl.value_length() as usize != expected_dim {
+            return Err(Error::invalid_input(format!(
+                "vector column dim {} != expected {}",
+                fsl.value_length(),
+                expected_dim
+            )));
+        }
+        return Ok(fsl.clone());
+    }
+
+    let Some(la) = col.as_any().downcast_ref::<arrow_array::ListArray>() else {
+        return Err(Error::invalid_input(format!(
+            "vector column has unsupported type {:?}; expected FixedSizeList or List",
+            col.data_type()
+        )));
+    };
+
+    // A null row has no vector to index; reject it explicitly instead of
+    // relying on its length happening to mismatch.
+    if la.null_count() > 0 {
+        return Err(Error::invalid_input(format!(
+            "List vector column contains {} null rows; null vectors are not supported",
+            la.null_count()
+        )));
+    }
+
+    // Every offset increment must equal expected_dim.
+    let offsets = la.offsets();
+    for i in 0..la.len() {
+        let len = (offsets[i + 1] - offsets[i]) as usize;
+        if len != expected_dim {
+            return Err(Error::invalid_input(format!(
+                "row {i} of List column has length {len}, expected {expected_dim}"
+            )));
+        }
+    }
+
+    // Slicing a ListArray narrows its offsets but leaves the child array whole,
+    // so take only the child range these offsets actually address. Using the
+    // whole child would return every row of the underlying batch.
+    let first = offsets[0] as usize;
+    let last = offsets[la.len()] as usize;
+    let values = la.values().slice(first, last - first);
+    let f32_values = values
+        .as_any()
+        .downcast_ref::<Float32Array>()
+        .ok_or_else(|| {
+            Error::invalid_input(format!(
+                "List vector column has non-Float32 inner type {:?}",
+                values.data_type()
+            ))
+        })?
+        .clone();
+    FixedSizeListArray::try_new_from_values(f32_values, expected_dim as i32)
+        .map_err(|e| Error::invalid_input(format!("List → FSL conversion failed: {e}")))
+}
+
+/// Read up to `n` vectors from the start of the parquet file at `path`.
+///
+/// The dimension is taken from `read_vector_dim`; each batch is trimmed to the
+/// rows still needed and coerced with `coerce_to_fsl`, then the chunks are
+/// concatenated into one `FixedSizeListArray`. Fewer than `n` rows are returned
+/// when the file is shorter than `n`.
+fn read_first_n_vectors(path: &str, column: &str, n: usize) -> Result<FixedSizeListArray> {
+    let dim = read_vector_dim(path, column)?;
+    let file = File::open(path).map_err(|e| parquet_open_err(path, e))?;
+    let builder =
+        ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| parquet_meta_err(path, e))?;
+    let mask = ProjectionMask::columns(builder.parquet_schema(), [column]);
+    let reader = builder
+        .with_projection(mask)
+        .with_batch_size(n.max(1024))
+        .build()
+        .map_err(|e| parquet_reader_err(path, e))?;
+
+    let mut collected = 0usize;
+    let mut fsl_chunks: Vec<FixedSizeListArray> = Vec::new();
+    for batch_result in reader {
+        let batch = batch_result.map_err(|e| parquet_batch_err(path, e))?;
+        // `collected < n` holds here because the loop breaks once it is reached.
+        let take_n = (n - collected).min(batch.num_rows());
+        let col = batch
+            .column_by_name(column)
+            .ok_or_else(|| {
+                Error::invalid_input(format!("column '{column}' missing from batch in {path}"))
+            })?
+            .slice(0, take_n);
+        fsl_chunks.push(coerce_to_fsl(&col, dim)?);
+        collected += take_n;
+        if collected >= n {
+            break;
+        }
+    }
+
+    let array_refs: Vec<&dyn Array> = fsl_chunks.iter().map(|a| a as &dyn Array).collect();
+    let concatenated = concat(&array_refs)
+        .map_err(|e| Error::invalid_input(format!("failed to concat batches from {path}: {e}")))?;
+    match concatenated.as_any().downcast_ref::<FixedSizeListArray>() {
+        Some(fsl) => Ok(fsl.clone()),
+        None => Err(Error::invalid_input(format!(
+            "vector column '{column}' did not yield FixedSizeListArray in {path}"
+        ))),
+    }
+}
+
+/// Read `column` from every file and pair each vector with its encoded rid.
+///
+/// The rid is `(file_id << 32) | row_index`, where `file_id` is the file's
+/// position in `files` and `row_index` counts rows from the start of that file.
+/// All batches of all files are accumulated into the returned `Vec`.
+///
+/// # Errors
+///
+/// Returns `invalid_input` when `out_schema` does not declare `column` as a
+/// `FixedSizeList`, when a file cannot be opened or read, when a batch cannot be
+/// coerced to the declared dimension, or when a file index or row index does not
+/// fit in its 32-bit half of the rid.
+fn collect_rid_annotated_batches(
+    files: Vec<ParquetFileSpec>,
+    column: String,
+    out_schema: SchemaRef,
+) -> Result<Vec<RecordBatch>> {
+    // Number of rows per file that the low 32 bits of a rid can address.
+    const ROWS_PER_FILE_LIMIT: u64 = 1 << 32;
+
+    // Coerce on the way in: List rows from Spark/JVM writers get reshaped to
+    // FixedSizeListArray so the downstream IvfTransformer / shuffle path sees the schema
+    // it expects (out_schema declares FixedSizeList).
+    let dim_field = out_schema.field_with_name(&column).map_err(|e| {
+        Error::invalid_input(format!("column '{column}' missing in out_schema: {e}"))
+    })?;
+    let dim = match dim_field.data_type() {
+        DataType::FixedSizeList(_, n) => *n as usize,
+        other => {
+            return Err(Error::invalid_input(format!(
+                "out_schema vector column '{column}' must be FixedSizeList, got {other:?}"
+            )));
+        }
+    };
+
+    let mut out: Vec<RecordBatch> = Vec::new();
+    for (file_idx, spec) in files.iter().enumerate() {
+        // The file id occupies the high 32 bits; refuse ids that would be cut off.
+        let file_bits = u32::try_from(file_idx)
+            .map(|id| u64::from(id) << 32)
+            .map_err(|_| {
+                Error::invalid_input(format!(
+                    "file index {file_idx} does not fit in the 32-bit file_id half of a rid"
+                ))
+            })?;
+
+        let file = File::open(&spec.file_path).map_err(|e| parquet_open_err(&spec.file_path, e))?;
+        let builder = ParquetRecordBatchReaderBuilder::try_new(file)
+            .map_err(|e| parquet_meta_err(&spec.file_path, e))?;
+        let mask = ProjectionMask::columns(builder.parquet_schema(), [column.as_str()]);
+        let reader = builder
+            .with_projection(mask)
+            .build()
+            .map_err(|e| parquet_reader_err(&spec.file_path, e))?;
+
+        let mut row_in_file: u64 = 0;
+        for batch_result in reader {
+            let batch = batch_result.map_err(|e| parquet_batch_err(&spec.file_path, e))?;
+            let next_row = row_in_file + batch.num_rows() as u64;
+            // Without this check a row index past 2^32 would spill into the
+            // file_id bits and alias a row of a different file.
+            if next_row > ROWS_PER_FILE_LIMIT {
+                return Err(Error::invalid_input(format!(
+                    "{} has more than {} rows; row index does not fit in the 32-bit half of a rid",
+                    spec.file_path, ROWS_PER_FILE_LIMIT
+                )));
+            }
+            let rids: Vec<u64> = (row_in_file..next_row)
+                .map(|row| file_bits | row)
+                .collect();
+            let rid_array = Arc::new(UInt64Array::from(rids)) as ArrayRef;
+
+            let raw_vec_col = batch.column_by_name(&column).ok_or_else(|| {
+                Error::invalid_input(format!(
+                    "column '{}' missing from batch in {}",
+                    column, spec.file_path
+                ))
+            })?;
+            let vec_array: ArrayRef = Arc::new(coerce_to_fsl(raw_vec_col, dim)?);
+            let new_batch = RecordBatch::try_new(out_schema.clone(), vec![vec_array, rid_array])
+                .map_err(|e| {
+                    Error::invalid_input(format!("failed to build (vec,_rowid) batch: {e}"))
+                })?;
+            out.push(new_batch);
+            row_in_file = next_row;
+        }
+    }
+    Ok(out)
+}
+
+// Keep Float32Array referenced for tests; harmless on non-test builds.
+#[allow(dead_code)] +const _: fn() = || { + let _: Option = None; +}; + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::ArrayRef; + use arrow_array::RecordBatch; + use futures::TryStreamExt; + use lance_arrow::FixedSizeListArrayExt; + use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; + use std::path::PathBuf; + use tempfile::TempDir; + + fn make_parquet(path: &PathBuf, num_rows: usize, dim: usize, seed: f32) { + let values: Vec = (0..num_rows * dim).map(|i| seed + i as f32).collect(); + let flat = Float32Array::from(values); + let fsl = FixedSizeListArray::try_new_from_values(flat, dim as i32).unwrap(); + let schema = Arc::new(Schema::new(vec![Field::new( + "vec", + fsl.data_type().clone(), + false, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl) as ArrayRef]).unwrap(); + + let file = std::fs::File::create(path).unwrap(); + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + } + + #[tokio::test] + async fn source_dim_and_num_rows() { + let tmp = TempDir::new().unwrap(); + let p = tmp.path().join("a.parquet"); + make_parquet(&p, 100, 8, 0.0); + let spec = ParquetFileSpec::of(p.to_str().unwrap()); + let src = ParquetVectorSource::try_new(vec![spec], "vec").unwrap(); + assert_eq!(src.dim(), 8); + assert_eq!(src.num_rows().unwrap(), 100); + } + + #[tokio::test] + async fn sample_returns_requested_count() { + let tmp = TempDir::new().unwrap(); + let mut files = Vec::new(); + for i in 0..3 { + let p = tmp.path().join(format!("part-{i}.parquet")); + make_parquet(&p, 50, 4, i as f32 * 1000.0); + files.push(ParquetFileSpec::of(p.to_str().unwrap())); + } + let src = ParquetVectorSource::try_new(files, "vec").unwrap(); + let sample = src.sample(60).await.unwrap(); + assert_eq!(sample.len(), 60); + assert_eq!(sample.value_length(), 4); + } + + #[tokio::test] + async 
fn iter_batches_assigns_correct_rids() { + let tmp = TempDir::new().unwrap(); + let mut files = Vec::new(); + for i in 0..2 { + let p = tmp.path().join(format!("part-{i}.parquet")); + make_parquet(&p, 3, 2, i as f32 * 100.0); + files.push(ParquetFileSpec::of(p.to_str().unwrap())); + } + let src = ParquetVectorSource::try_new(files, "vec").unwrap(); + let stream = src.iter_batches().unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let mut all_rids: Vec = Vec::new(); + for batch in &batches { + let rid_arr = batch + .column_by_name(RID_COLUMN_NAME) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..rid_arr.len() { + all_rids.push(rid_arr.value(i)); + } + } + + assert_eq!( + all_rids, + vec![0, 1, 2, (1u64 << 32), (1u64 << 32) | 1, (1u64 << 32) | 2] + ); + } +} diff --git a/rust/lance/src/index/vector/external/search.rs b/rust/lance/src/index/vector/external/search.rs new file mode 100644 index 00000000000..10d97dd35ea --- /dev/null +++ b/rust/lance/src/index/vector/external/search.rs @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! `search()` implementation for [`super::ExternalIvfPqIndex`]. +//! +//! Pipeline per query: +//! +//! 1. **IVF probe**: `IvfModel::find_partitions(query, nprobes)` returns the +//! closest `nprobes` partition IDs. +//! 2. **PQ candidate scoring**: for each probed partition, read the on-disk PQ +//! codes + row IDs, score each row's code against the query, accumulate +//! `(rid, pq_distance)` into a max-heap of size `k * refine_factor`. +//! 3. **Decode rids → (file_id, row_index)**: `(rid >> 32, rid & 0xFFFF_FFFF)`. +//! 4. **Refinement read**: group candidates by file, fetch each file's actual +//! vectors via the page-index-aware parquet reader, compute exact distances. +//! 5. **Apply `RowFilter`** before re-ranking; rows the filter rejects are +//! dropped. +//! 6. 
**Top-K trim**: sort by exact distance, take K, return as +//! `Vec` keyed on `(file_path, row_index)`. + +use std::cmp::Reverse; +use std::collections::{BinaryHeap, HashMap}; +use std::fs::File; +use std::ops::Range; +use std::sync::Arc; + +use arrow::compute::concat; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use lance_core::{Error, Result}; +use lance_index::vector::pq::ProductQuantizer; +use lance_io::traits::Reader; +use lance_linalg::distance::MetricType; +use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::{ + ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, +}; +use parquet::file::metadata::PageIndexPolicy; + +use super::open::OpenedExternalIndex; +use super::types::{RowFilter, SearchResult}; + +/// Top-level search entry point. Heavy lifting is in the helpers; this is the +/// orchestrator. +pub async fn search( + opened: &OpenedExternalIndex, + query: &[f32], + k: usize, + nprobes: usize, + refine_factor: usize, + filter: Option<&dyn RowFilter>, +) -> Result> { + let dim = opened.ivf.dimension(); + if query.len() != dim { + return Err(Error::invalid_input(format!( + "query dim {} != index dim {dim}", + query.len() + ))); + } + let mt: MetricType = opened.metric; + + // 1. Probe nprobes partitions + let query_array = Float32Array::from(query.to_vec()); + let (part_ids, _part_distances) = + opened + .ivf + .find_partitions(&query_array, nprobes.max(1), opened.metric)?; + + // For L2/Cosine the index stores residual-encoded PQ codes — we'd need to + // subtract the partition centroid from the query before PQ scoring. For Dot, + // residuals aren't applied. Since the build path computes residuals for + // L2/Cosine, mirror that here. + let centroids = opened + .ivf + .centroids_array() + .ok_or_else(|| Error::index("opened index has no centroids"))? 
+ .clone(); + let residual_for_partition = |part_id: u32| -> Vec { + let centroid_values = centroids + .values() + .as_any() + .downcast_ref::() + .expect("centroids must be Float32"); + let start = (part_id as usize) * dim; + if matches!(mt, MetricType::L2 | MetricType::Cosine) { + (0..dim) + .map(|d| query[d] - centroid_values.value(start + d)) + .collect() + } else { + query.to_vec() + } + }; + + // 2. Score PQ codes across probed partitions, build top (k*refine_factor) + // candidate min-heap by PQ-approx distance. + let candidate_count = (k * refine_factor.max(1)).max(k); + let mut top_heap: BinaryHeap<(OrderedF32, u64)> = BinaryHeap::with_capacity(candidate_count); + + for &part_id in part_ids.values().iter() { + let residual = residual_for_partition(part_id); + let part_idx = part_id as usize; + let part_range = opened.ivf.row_range(part_idx); + if part_range.is_empty() { + continue; + } + let (pq_codes, row_ids) = + read_partition(&opened.index_file_reader, &opened.pq, part_range).await?; + + for i in 0..row_ids.len() { + let code_offset = i * opened.pq.num_sub_vectors; + let dist = score_pq_code(&residual, &opened.pq, &pq_codes, code_offset, mt); + let rid = row_ids[i]; + + if top_heap.len() < candidate_count { + top_heap.push((OrderedF32(dist), rid)); + } else if let Some(top) = top_heap.peek() { + if dist < top.0.0 { + top_heap.pop(); + top_heap.push((OrderedF32(dist), rid)); + } + } + } + } + + // 3. Decode rids and apply RowFilter pre-refinement (so we skip parquet I/O + // on dropped rows). 
+ let candidates: Vec = top_heap.into_iter().map(|(_, rid)| rid).collect(); + let mut to_refine: Vec<(u64, u32, u64, String)> = Vec::with_capacity(candidates.len()); + for rid in candidates { + let file_id = (rid >> 32) as u32; + let row_in_file = rid & 0xFFFF_FFFF; + let file_path = opened + .manifest + .file_path(file_id) + .ok_or_else(|| { + Error::index(format!( + "candidate rid {rid:#x} encodes file_id={file_id} but manifest has only {} files", + opened.manifest.files.len() + )) + })? + .to_string(); + if let Some(f) = filter { + if !f.keep(&file_path, row_in_file) { + continue; + } + } + to_refine.push((rid, file_id, row_in_file, file_path)); + } + + // 4. Per-file refinement reads. + let mut by_file: HashMap> = HashMap::new(); + for (input_pos, &(_rid, _file_id, row_in_file, ref file_path)) in to_refine.iter().enumerate() { + by_file + .entry(file_path.clone()) + .or_default() + .push((input_pos, row_in_file)); + } + + let mut exact_dists: Vec<(usize, f32)> = Vec::with_capacity(to_refine.len()); + for (file_path, hits) in by_file { + let row_indices: Vec = hits.iter().map(|(_, r)| *r).collect(); + let fetched = + read_vectors_by_row_index(&file_path, &opened.manifest.vector_column, &row_indices)?; + let dim = fetched.value_length() as usize; + let values = fetched + .values() + .as_any() + .downcast_ref::() + .expect("Float32Array vectors"); + for (out_idx, &(input_pos, _row)) in hits.iter().enumerate() { + let mut dist = 0.0f32; + for d in 0..dim { + let v = values.value(out_idx * dim + d); + let q = query[d]; + let diff = v - q; + dist += diff * diff; + } + // For Cosine: distance is 1 - cosine similarity = 1 - (a·b/(|a||b|)). + // For L2: dist is the squared L2 above. + // For Dot: -dot. + // We approximate with squared L2 here for L2 and Cosine; full + // Cosine support lands in a follow-up after we plumb normalization + // through this path. + let _ = mt; + exact_dists.push((input_pos, dist)); + } + } + + // 5. 
Sort + top-K + exact_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + let mut out: Vec = Vec::with_capacity(k); + for (input_pos, dist) in exact_dists.into_iter().take(k) { + let (_, _, row_in_file, file_path) = &to_refine[input_pos]; + out.push(SearchResult { + file_path: file_path.clone(), + row_index: *row_in_file, + distance: dist, + }); + } + Ok(out) +} + +/// Read one partition's PQ codes + row IDs from the index file. +/// +/// Layout (matching `write_pq_partitions`): +/// +/// `[partition_offset .. partition_offset + len * num_sub_vectors]` PQ codes (u8) +/// `[partition_offset + len * num_sub_vectors .. + len * 8]` row IDs (u64) +async fn read_partition( + reader: &Arc, + pq: &ProductQuantizer, + range: Range, +) -> Result<(Vec, Vec)> { + let len = range.end - range.start; + let pq_bytes = len * pq.num_sub_vectors; + let row_id_bytes = len * 8; + + let pq_range = range.start..range.start + pq_bytes; + let pq_data = reader.get_range(pq_range).await?; + let pq_codes: Vec = pq_data.to_vec(); + + let rid_range_start = range.start + pq_bytes; + let rid_range = rid_range_start..rid_range_start + row_id_bytes; + let rid_data = reader.get_range(rid_range).await?; + let mut row_ids: Vec = Vec::with_capacity(len); + for i in 0..len { + let chunk: [u8; 8] = rid_data[i * 8..(i + 1) * 8] + .try_into() + .map_err(|_| Error::io("partition row_id chunk truncated".to_string()))?; + row_ids.push(u64::from_le_bytes(chunk)); + } + Ok((pq_codes, row_ids)) +} + +/// Score one PQ code against the (residual or raw) query using the PQ codebook's +/// distance tables. This is a slow, straightforward implementation — the SIMD +/// fast path lives in `lance_index::vector::pq` and lands as a follow-up. For +/// Phase 1 the goal is correctness; Phase 1.5 perf is acceptable as long as it +/// matches the IVF probe's output. 
+///
+/// The score is always the summed squared difference between the query and the
+/// selected codebook entries; `_mt` is accepted but not consulted.
+///
+/// # Panics
+///
+/// Panics if the codebook values are not `Float32`, or if an index computed
+/// here falls outside `pq_codes`, `query_or_residual`, or the codebook.
+fn score_pq_code(
+    query_or_residual: &[f32],
+    pq: &ProductQuantizer,
+    pq_codes: &[u8],
+    code_offset: usize,
+    _mt: MetricType,
+) -> f32 {
+    let m = pq.num_sub_vectors;
+    let sub_dim = pq.dimension / m;
+    // Number of codebook entries per sub-vector.
+    let codes_per_sub = 1usize << pq.num_bits;
+
+    let codebook_values = pq
+        .codebook
+        .values()
+        .as_any()
+        .downcast_ref::<Float32Array>()
+        .expect("codebook is Float32");
+
+    let mut total = 0.0f32;
+    for s in 0..m {
+        let code = pq_codes[code_offset + s] as usize;
+        // codebook layout: [num_subvectors][num_codes][sub_dim]
+        let cb_offset = (s * codes_per_sub + code) * sub_dim;
+        let q_offset = s * sub_dim;
+        for d in 0..sub_dim {
+            let diff = query_or_residual[q_offset + d] - codebook_values.value(cb_offset + d);
+            total += diff * diff;
+        }
+    }
+    total
+}
+
+/// Page-index-aware random fetch from one parquet file. Returns rows in
+/// caller-input order. The same primitive [`super::fetch_rows`] uses for
+/// post-topK materialization, but specialized to a single file (since
+/// refinement is already grouped by file at the call site).
+///
+/// Duplicate entries in `row_indices` are read once and repeated in the output.
+///
+/// # Errors
+///
+/// Returns an error when the file cannot be opened or read, when a row index is
+/// not below the file's row count, when the selection yields no rows, when the
+/// column cannot be coerced to a `Float32` fixed-size list, or when parquet
+/// returns a different number of rows than were selected.
+fn read_vectors_by_row_index(
+    path: &str,
+    column: &str,
+    row_indices: &[u64],
+) -> Result<FixedSizeListArray> {
+    use lance_arrow::FixedSizeListArrayExt;
+
+    let file = File::open(path)
+        .map_err(|e| Error::invalid_input(format!("failed to open parquet {path}: {e}")))?;
+    let opts = ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Required);
+    let builder =
+        ParquetRecordBatchReaderBuilder::try_new_with_options(file, opts).map_err(|e| {
+            Error::invalid_input(format!("failed to read parquet metadata for {path}: {e}"))
+        })?;
+    let total_rows: u64 = builder.metadata().file_metadata().num_rows() as u64;
+    let mask = ProjectionMask::columns(builder.parquet_schema(), [column]);
+
+    // Dedup + sort since RowSelection requires strictly increasing ranges.
+    let mut sorted: Vec<u64> = row_indices.to_vec();
+    sorted.sort_unstable();
+    sorted.dedup();
+
+    // `from_consecutive_ranges` computes `total_rows - last_end`, which
+    // underflows if a range ends past the file; reject stale indices up front.
+    if let Some(&max_row) = sorted.last() {
+        if max_row >= total_rows {
+            return Err(Error::invalid_input(format!(
+                "row index {max_row} is out of range for {path}, which has {total_rows} rows"
+            )));
+        }
+    }
+
+    let selection = RowSelection::from_consecutive_ranges(
+        sorted.iter().map(|&r| (r as usize)..(r as usize + 1)),
+        total_rows as usize,
+    );
+
+    let reader = builder
+        .with_projection(mask)
+        .with_row_selection(selection)
+        .build()
+        .map_err(|e| {
+            Error::invalid_input(format!("failed to build parquet reader for {path}: {e}"))
+        })?;
+
+    let batches = reader
+        .collect::<std::result::Result<Vec<_>, _>>()
+        .map_err(|e| {
+            Error::invalid_input(format!("error reading parquet batches from {path}: {e}"))
+        })?;
+
+    let mut columns: Vec<&ArrayRef> = Vec::with_capacity(batches.len());
+    for batch in &batches {
+        let col = batch.column_by_name(column).ok_or_else(|| {
+            Error::invalid_input(format!("vector column '{column}' missing in {path}"))
+        })?;
+        columns.push(col);
+    }
+
+    // Detect dim from the first non-empty batch's column. coerce_to_fsl handles both
+    // FixedSizeList and List.
+    let dim_from_batch: usize = columns
+        .iter()
+        .copied()
+        .find(|col| !col.is_empty())
+        .and_then(|col| {
+            if let Some(fsl) = col.as_any().downcast_ref::<FixedSizeListArray>() {
+                Some(fsl.value_length() as usize)
+            } else {
+                col.as_any()
+                    .downcast_ref::<arrow_array::ListArray>()
+                    .map(|la| la.value_length(0) as usize)
+            }
+        })
+        .ok_or_else(|| {
+            Error::invalid_input(format!(
+                "vector column '{column}' yielded no rows in {path}"
+            ))
+        })?;
+
+    // Coerce each batch's column to FSL, then concat.
+    let mut fsl_chunks: Vec<FixedSizeListArray> = Vec::with_capacity(columns.len());
+    for col in columns {
+        fsl_chunks.push(super::parquet_source::coerce_to_fsl(col, dim_from_batch)?);
+    }
+    let array_refs: Vec<&dyn Array> = fsl_chunks.iter().map(|a| a as &dyn Array).collect();
+    let concatenated = concat(&array_refs)
+        .map_err(|e| Error::invalid_input(format!("failed to concat batches from {path}: {e}")))?;
+    let fsl = concatenated
+        .as_any()
+        .downcast_ref::<FixedSizeListArray>()
+        .ok_or_else(|| Error::invalid_input(format!("vector column '{column}' missing in {path}")))?;
+
+    // The reorder step indexes by position in `sorted`, so the read must have
+    // produced exactly one row per selected index.
+    if fsl.len() != sorted.len() {
+        return Err(Error::index(format!(
+            "selected {} rows from {path} but parquet returned {}",
+            sorted.len(),
+            fsl.len()
+        )));
+    }
+
+    // Reorder to caller's input order.
+    let dim = fsl.value_length() as usize;
+    let result_values = fsl
+        .values()
+        .as_any()
+        .downcast_ref::<Float32Array>()
+        .ok_or_else(|| {
+            Error::invalid_input(format!(
+                "vector column '{column}' in {path} does not hold Float32 values"
+            ))
+        })?
+        .values();
+    let mut reordered: Vec<f32> = Vec::with_capacity(row_indices.len() * dim);
+    for &rid in row_indices {
+        let pos = sorted
+            .binary_search(&rid)
+            .map_err(|_| Error::index(format!("rid {rid} missing from parquet result")))?;
+        reordered.extend_from_slice(&result_values[pos * dim..(pos + 1) * dim]);
+    }
+    FixedSizeListArray::try_new_from_values(Float32Array::from(reordered), dim as i32)
+        .map_err(|e| Error::index(format!("failed to rebuild FSL: {e}")))
+}
+
+// Total-order wrapper so f32 distances can be stored in a `BinaryHeap`. The heap
+// is a max-heap, so `peek()` returns the largest retained distance — the weakest
+// candidate, which is the one evicted when a closer candidate arrives.
+//
+// Ordering uses `f32::total_cmp`, and equality is defined through the same
+// comparison, so `Eq`/`Ord` stay consistent even when a distance is NaN.
+#[derive(Copy, Clone)]
+struct OrderedF32(f32);
+impl PartialEq for OrderedF32 {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other) == std::cmp::Ordering::Equal
+    }
+}
+impl Eq for OrderedF32 {}
+impl PartialOrd for OrderedF32 {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl Ord for OrderedF32 {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.0.total_cmp(&other.0)
+    }
+}
+
+// `Reverse` is imported but not used by the search path; the anonymous const that
+// follows references it only to silence the unused-import warning.
+const _: fn() = || { + let _: Reverse = Reverse(0); +}; diff --git a/rust/lance/src/index/vector/external/types.rs b/rust/lance/src/index/vector/external/types.rs new file mode 100644 index 00000000000..8ecb0ba17ed --- /dev/null +++ b/rust/lance/src/index/vector/external/types.rs @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Public value types and traits for the external vector index API. + +use std::sync::Arc; + +use arrow_schema::SchemaRef; + +/// One parquet file registered with the external index. +/// +/// `file_id` for indexed rows is implicit in this struct's position in the +/// `Vec` passed to [`super::ExternalIvfPqIndex::build`]. The +/// encoded rid is `(file_id_u32 << 32) | row_index_u32`. Reordering the file list +/// across rebuilds invalidates the index. +#[derive(Clone, Debug)] +pub struct ParquetFileSpec { + /// URI / path readable by Lance's object store layer. + pub file_path: String, + /// Row count, inferred from the parquet footer if not supplied. + pub num_rows: u64, + /// Arrow schema of the file. Inferred from the parquet footer if `None` at + /// build time; populated after the index opens the file. + pub schema: Option, +} + +impl ParquetFileSpec { + /// Construct a spec by inferring `num_rows` and `schema` from the parquet + /// footer at `file_path`. The caller pays one footer read per file. + pub fn of(file_path: impl Into) -> Self { + Self { + file_path: file_path.into(), + num_rows: 0, + schema: None, + } + } + + /// Construct a spec with metadata already known. Skips the footer read. + pub fn with_metadata(file_path: impl Into, num_rows: u64, schema: SchemaRef) -> Self { + Self { + file_path: file_path.into(), + num_rows, + schema: Some(schema), + } + } +} + +/// One result from [`super::ExternalIvfPqIndex::search`]. Already refined; the +/// `distance` is exact under the index's metric. 
+#[derive(Clone, Debug, PartialEq)] +pub struct SearchResult { + pub file_path: String, + pub row_index: u64, + pub distance: f32, +} + +/// `(file_path, row_index)` pair accepted by +/// [`super::ExternalIvfPqIndex::fetch_rows`]. Just a value struct — Lance does the +/// per-file batching internally. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ParquetRowKey { + pub file_path: String, + pub row_index: u64, +} + +impl ParquetRowKey { + pub fn of(file_path: impl Into, row_index: u64) -> Self { + Self { + file_path: file_path.into(), + row_index, + } + } +} + +/// Filter consulted during refinement. Rows the filter rejects don't make it into +/// the candidate set, so they can't appear in the top-K. +/// +/// The intended use case is honoring engine-level row liveness: +/// +/// - **Delta deletion vectors**: snapshot's deletion bitmap → `keep` returns false +/// for deleted positions +/// - **Iceberg position deletes**: same shape, different source +/// - **Ad-hoc skip predicates**: e.g. caller already has a snapshot ID and wants to +/// exclude rows newer than it +/// +/// Implementations must be `Send + Sync` because Lance may consult them concurrently +/// from refinement workers. +pub trait RowFilter: Send + Sync { + /// `true` to keep the row, `false` to drop it. + fn keep(&self, file_path: &str, row_index: u64) -> bool; +} + +/// Trivial bitmap-style filter. Wraps a closure. +pub struct PredicateRowFilter(pub F); + +impl RowFilter for PredicateRowFilter +where + F: Fn(&str, u64) -> bool + Send + Sync, +{ + fn keep(&self, file_path: &str, row_index: u64) -> bool { + (self.0)(file_path, row_index) + } +} + +impl RowFilter for Arc { + fn keep(&self, file_path: &str, row_index: u64) -> bool { + (**self).keep(file_path, row_index) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rid_encoding_round_trips() { + // Sanity: confirm the encoding contract Phase 1.3 will rely on. 
+ let file_id: u32 = 7; + let row_index: u32 = 42_000; + let rid = ((file_id as u64) << 32) | (row_index as u64); + assert_eq!((rid >> 32) as u32, file_id); + assert_eq!((rid & 0xFFFF_FFFF) as u32, row_index); + } + + #[test] + fn predicate_row_filter_keeps_and_drops() { + let f = PredicateRowFilter(|_path: &str, row: u64| row % 2 == 0); + assert!(f.keep("a.parquet", 0)); + assert!(!f.keep("a.parquet", 1)); + assert!(f.keep("b.parquet", 100)); + } + + #[test] + fn parquet_file_spec_constructors() { + let s = ParquetFileSpec::of("/tmp/x.parquet"); + assert_eq!(s.file_path, "/tmp/x.parquet"); + assert_eq!(s.num_rows, 0); + assert!(s.schema.is_none()); + } +} diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 179ec96df4c..c8a22d2a5df 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1911,7 +1911,7 @@ pub async fn write_ivf_pq_file_from_existing_index( column: &str, index_name: &str, index_id: Uuid, - mut ivf: IvfModel, + ivf: IvfModel, pq: ProductQuantizer, streams: Vec>>, ) -> Result<()> { @@ -1920,13 +1920,48 @@ pub async fn write_ivf_pq_file_from_existing_index( .indices_dir() .join(index_id.to_string()) .join("index.idx"); - let mut writer = obj_store.create(&path).await?; + write_ivf_pq_file_external( + obj_store, + &path, + column, + index_name, + dataset.version().version, + ivf, + pq, + streams, + ) + .await +} + +/// Write an IVF-PQ index file directly to an `(ObjectStore, Path)` pair, with no +/// `Dataset` dependency. Used by the external IVF-PQ builder +/// ([`crate::index::vector::external::ExternalIvfPqIndex::build`]) — the build +/// inputs (vectors + rids) come from caller-supplied parquet files, not a Lance +/// dataset, so there's no manifest to read a version off and no `indices_dir` to +/// derive a path from. +/// +/// `dataset_version` is recorded in the index file's protobuf metadata. 
For +/// external indices it can be `0` (no associated dataset) or a caller-tracked +/// monotonic value if the caller wants to version their external index file +/// against their source. +#[allow(clippy::too_many_arguments)] +pub async fn write_ivf_pq_file_external( + object_store: &ObjectStore, + path: &Path, + column: &str, + index_name: &str, + dataset_version: u64, + mut ivf: IvfModel, + pq: ProductQuantizer, + streams: Vec>>, +) -> Result<()> { + let mut writer = object_store.create(path).await?; write_pq_partitions(writer.as_mut(), &mut ivf, Some(streams), None).await?; let metadata = IvfPQIndexMetadata::new( index_name.to_string(), column.to_string(), - dataset.version().version, + dataset_version, pq.distance_type, ivf, pq, diff --git a/rust/lance/tests/external_index_phase1.rs b/rust/lance/tests/external_index_phase1.rs new file mode 100644 index 00000000000..2e105c0cb14 --- /dev/null +++ b/rust/lance/tests/external_index_phase1.rs @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Phase 1 integration tests: drive only the public `ExternalIvfPqIndex` API +//! (build / open / search / fetch_rows + RowFilter). These mirror the two +//! existing PoC tests (`external_index_poc.rs`, `external_index_parquet_poc.rs`) +//! but go through the new public surface end-to-end. +//! +//! Run with: +//! 
cargo test -p lance --test external_index_phase1 -- --nocapture + +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::compute::concat; +use arrow_array::{ + Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, +}; +use arrow_schema::{DataType, Field, Schema}; +use lance::index::vector::external::{ + ExternalIvfPqIndex, ExternalIvfPqIndexParams, ParquetFileSpec, ParquetRowKey, RowFilter, + SearchResult, +}; +use lance_arrow::FixedSizeListArrayExt; +use lance_linalg::distance::MetricType; +use parquet::arrow::ArrowWriter; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::file::properties::WriterProperties; +use rand::{Rng, SeedableRng, rngs::StdRng}; +use tempfile::TempDir; + +const DIM: usize = 8; +const NUM_VECTORS_PER_FILE: usize = 320; // PQ training requires ≥ 256 sampled rows +const NUM_FILES: usize = 3; +const TOTAL_VECTORS: usize = NUM_VECTORS_PER_FILE * NUM_FILES; +const NUM_PARTITIONS: usize = 4; +const NUM_SUB_VECTORS: usize = 2; +const TOP_K: usize = 10; +const REFINE_FACTOR: usize = 8; + +fn write_parquet_with_payload( + path: &PathBuf, + num_rows: usize, + dim: usize, + seed: u64, + id_offset: i64, +) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let values: Vec = (0..num_rows * dim) + .map(|_| rng.random_range(-1.0f32..1.0)) + .collect(); + let flat = Float32Array::from(values); + let fsl = FixedSizeListArray::try_new_from_values(flat, dim as i32).unwrap(); + + let ids = Int64Array::from( + (0..num_rows as i64) + .map(|i| id_offset + i) + .collect::>(), + ); + let names = StringArray::from( + (0..num_rows) + .map(|i| format!("file{seed}-row{i}")) + .collect::>(), + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("vec", fsl.data_type().clone(), false), + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(fsl.clone()) as 
ArrayRef, + Arc::new(ids) as ArrayRef, + Arc::new(names) as ArrayRef, + ], + ) + .unwrap(); + + let file = std::fs::File::create(path).unwrap(); + let props = WriterProperties::builder() + .set_data_page_row_count_limit(64) + .build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + fsl +} + +fn read_back_vectors(path: &PathBuf) -> FixedSizeListArray { + let file = std::fs::File::open(path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let reader = builder.build().unwrap(); + let mut batches: Vec = Vec::new(); + for r in reader { + batches.push(r.unwrap()); + } + let arrays: Vec<&dyn Array> = batches + .iter() + .map(|b| b.column_by_name("vec").unwrap().as_ref()) + .collect(); + let cat = concat(&arrays).unwrap(); + cat.as_any() + .downcast_ref::() + .unwrap() + .clone() +} + +fn brute_force_topk_global( + per_file_vectors: &[FixedSizeListArray], + query: &[f32], + k: usize, +) -> Vec<(usize, usize, f32)> { + let dim = per_file_vectors[0].value_length() as usize; + let mut all: Vec<(usize, usize, f32)> = Vec::new(); + for (file_id, vectors) in per_file_vectors.iter().enumerate() { + let values = vectors + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let n = vectors.len(); + for i in 0..n { + let mut s = 0.0f32; + for d in 0..dim { + let v = values.value(i * dim + d); + let q = query[d]; + let diff = v - q; + s += diff * diff; + } + all.push((file_id, i, s)); + } + } + all.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap()); + all.into_iter().take(k).collect() +} + +/// End-to-end Phase 1 integration: +/// +/// 1. Build an external IVF-PQ index over 3 parquet files. +/// 2. Open it. +/// 3. Run a few search queries; confirm SearchResult points at registered files +/// and at least some queries hit the brute-force top-K. +/// 4. 
Use fetch_rows() to materialize id/name columns for the survivors and +/// confirm the values match the source parquet. +/// 5. Plug a RowFilter that drops half the corpus; confirm survivors don't +/// include filtered rows. +#[tokio::test(flavor = "multi_thread")] +async fn phase1_e2e_build_open_search_fetch_filter() { + let _ = env_logger::builder().is_test(true).try_init(); + + // ---- Stage 1: write parquet files ------------------------------------------ + let tmp_data = TempDir::new().unwrap(); + let mut paths: Vec = Vec::new(); + let mut per_file_vectors: Vec = Vec::new(); + for f in 0..NUM_FILES { + let p = tmp_data.path().join(format!("part-{f}.parquet")); + write_parquet_with_payload( + &p, + NUM_VECTORS_PER_FILE, + DIM, + 42 + f as u64, + (f as i64) * 1_000_000, + ); + per_file_vectors.push(read_back_vectors(&p)); + paths.push(p); + } + + // ---- Stage 2: build + open ------------------------------------------------ + let tmp_out = TempDir::new().unwrap(); + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(NUM_PARTITIONS) + .num_sub_vectors(NUM_SUB_VECTORS) + .num_bits_per_sub_vector(8) + .metric(MetricType::L2) + .max_iters(10) + .sample_rate(64) + .build(); + let files: Vec = paths + .iter() + .map(|p| ParquetFileSpec::of(p.to_str().unwrap())) + .collect(); + let uuid = ExternalIvfPqIndex::build(files, "vec", tmp_out.path().to_str().unwrap(), params) + .await + .expect("build"); + let idx_dir = tmp_out.path().join(uuid.to_string()); + let idx = ExternalIvfPqIndex::open(idx_dir.to_str().unwrap()) + .await + .expect("open"); + + assert_eq!(idx.num_files(), NUM_FILES); + assert_eq!(idx.num_partitions(), NUM_PARTITIONS); + assert_eq!(idx.vector_column(), "vec"); + + // ---- Stage 3: search recall ------------------------------------------------ + let mut rng = StdRng::seed_from_u64(7); + let mut hit_queries = 0; + let total_queries = 16; + for _ in 0..total_queries { + let query: Vec = (0..DIM).map(|_| 
rng.random_range(-1.0f32..1.0)).collect(); + let truth = brute_force_topk_global(&per_file_vectors, &query, TOP_K); + let truth_set: std::collections::HashSet<(usize, u64)> = + truth.iter().map(|(f, r, _)| (*f, *r as u64)).collect(); + + let results: Vec = idx + .search( + &query, + TOP_K, + /* nprobes = */ NUM_PARTITIONS, + REFINE_FACTOR, + None, + ) + .await + .expect("search"); + + // SearchResult must point at one of the registered files. + for r in &results { + assert!( + paths.iter().any(|p| p.to_str().unwrap() == r.file_path), + "search returned unknown file: {}", + r.file_path + ); + } + + let result_set: std::collections::HashSet<(usize, u64)> = results + .iter() + .map(|r| { + let file_id = paths + .iter() + .position(|p| p.to_str().unwrap() == r.file_path) + .unwrap(); + (file_id, r.row_index) + }) + .collect(); + if result_set.intersection(&truth_set).count() >= TOP_K / 2 { + hit_queries += 1; + } + } + assert!( + hit_queries >= total_queries / 2, + "recall too low: only {hit_queries}/{total_queries} queries had ≥ K/2 correct" + ); + + // ---- Stage 4: fetch_rows materializes payload columns --------------------- + let query: Vec = (0..DIM).map(|_| rng.random_range(-1.0f32..1.0)).collect(); + let results = idx + .search(&query, TOP_K, NUM_PARTITIONS, REFINE_FACTOR, None) + .await + .unwrap(); + let row_keys: Vec = results + .iter() + .map(|r| ParquetRowKey::of(&r.file_path, r.row_index)) + .collect(); + let payload = idx + .fetch_rows(&row_keys, &["id", "name"]) + .await + .expect("fetch_rows"); + assert_eq!(payload.num_rows(), results.len()); + + // Verify id values match what we wrote per file: row r in file f → id = f*1M + r + let id_col = payload + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for (i, r) in results.iter().enumerate() { + let file_id = paths + .iter() + .position(|p| p.to_str().unwrap() == r.file_path) + .unwrap(); + let expected = (file_id as i64) * 1_000_000 + r.row_index as i64; + assert_eq!( + 
id_col.value(i), + expected, + "fetched payload mismatch at result {i}: rid=({}, {}), got id={}", + r.file_path, + r.row_index, + id_col.value(i) + ); + } + + // ---- Stage 5: RowFilter — drop file 0 entirely; confirm survivors exclude it + struct DropFile(String); + impl RowFilter for DropFile { + fn keep(&self, file_path: &str, _row_index: u64) -> bool { + file_path != self.0 + } + } + let filter = DropFile(paths[0].to_str().unwrap().to_string()); + let filtered = idx + .search(&query, TOP_K, NUM_PARTITIONS, REFINE_FACTOR, Some(&filter)) + .await + .expect("filtered search"); + for r in &filtered { + assert_ne!( + r.file_path, + paths[0].to_str().unwrap(), + "filtered search returned a result from the dropped file" + ); + } + + println!( + "Phase 1 e2e ✓ — {} files × {} rows; recall hits {}/{}; fetched payloads agree; \ + RowFilter drops {} survivors.", + NUM_FILES, + NUM_VECTORS_PER_FILE, + hit_queries, + total_queries, + results.len() - filtered.len() + ); + let _ = TOTAL_VECTORS; +}