diff --git a/.bumpversion.toml b/.bumpversion.toml index 43ddd048a98..4cac71e4e51 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "7.0.0-beta.15" +current_version = "7.1.0-beta.1" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)(-(?P(beta|rc))\\.(?P\\d+))?" serialize = [ "{major}.{minor}.{patch}-{prerelease}.{prerelease_num}", diff --git a/AGENTS.md b/AGENTS.md index df823a8ac5e..2003d6dba10 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,6 +130,11 @@ AWS_DEFAULT_REGION=us-east-1 pytest --run-integration python/tests/test_s3_ddb.p - Indent content under MkDocs admonition directives (`!!! note`, etc.) with 4 spaces. - Proofread comments and docs for typos before committing. +## Pull Requests + +- PR titles must follow the Conventional Commits specification because `.github/workflows/pr-title.yml` validates the PR title and body with commitlint. Use prefixes like `feat:`, `fix:`, `docs:`, `perf:`, `ci:`, `test:`, `build:`, `style:`, or `chore:`; add a scope when useful. +- Before creating or updating a PR, run the lint checks for every touched language surface, even when they are expensive. For Rust changes, run `cargo fmt --all` and `cargo clippy --all --tests --benches -- -D warnings`. For Python changes, follow the environment workflow in `python/AGENTS.md` and run `uv run make lint` from `python/`. If a required lint check cannot be run, state the blocker explicitly in the PR summary. + ## Review Guidelines Contributor and maintainer attention is the most valuable resource. Less is more. diff --git a/Cargo.lock b/Cargo.lock index 64655ee60d8..6fb213092ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2848,9 +2848,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encode_unicode" @@ -3093,7 +3093,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3189,9 +3189,9 @@ checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" +checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968" [[package]] name = "futures-util" @@ -4321,7 +4321,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "all_asserts", "approx", @@ -4386,6 +4386,7 @@ dependencies = [ "log", "lzma-sys", "mock_instant", + "moka", "object_store", "parquet", "permutation", @@ -4421,7 +4422,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4469,7 +4470,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrayref", "paste", @@ -4478,7 +4479,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4486,6 +4487,7 @@ dependencies = [ "async-trait", "byteorder", "bytes", + "criterion", "datafusion-common", "datafusion-sql", "deepsize", @@ -4515,7 +4517,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4548,7 +4550,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4568,7 +4570,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -4613,7 +4615,7 @@ dependencies = [ [[package]] name = "lance-examples" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "all_asserts", "arrow", @@ -4639,7 +4641,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -4679,7 +4681,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "datafusion", "geo-traits", @@ -4693,7 +4695,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "approx", "arc-swap", @@ -4770,7 +4772,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-arith", @@ -4819,7 +4821,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "approx", "arrow-array", @@ -4840,7 +4842,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "async-trait", @@ -4852,7 +4854,7 @@ dependencies = [ [[package]] name = "lance-namespace-datafusion" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-schema", @@ -4868,7 +4870,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-ipc", @@ -4910,9 +4912,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4924,7 +4926,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4970,7 +4972,7 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "proc-macro2", "quote", @@ -4979,7 +4981,7 @@ dependencies = [ [[package]] name = "lance-testing" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-schema", @@ -4990,7 +4992,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "jieba-rs", "lindera", @@ -5001,7 +5003,7 @@ dependencies = [ [[package]] name = "lance-tools" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "clap", "lance-core", diff --git a/Cargo.toml b/Cargo.toml index 08c5f4e024a..fdb44b1342b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ resolver = "3" [workspace.package] -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" @@ -55,25 +55,25 @@ rust-version = "1.91.0" [workspace.dependencies] arc-swap = "1.7" libc = "0.2.176" -lance = { version = "=7.0.0-beta.15", path = "./rust/lance", default-features = false } -lance-arrow = { version = "=7.0.0-beta.15", path = "./rust/lance-arrow" } -lance-core = { version = "=7.0.0-beta.15", path = "./rust/lance-core" } -lance-datafusion = { version = "=7.0.0-beta.15", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=7.0.0-beta.15", path = "./rust/lance-datagen" } -lance-encoding = { version = "=7.0.0-beta.15", path = "./rust/lance-encoding" } -lance-file = { version = "=7.0.0-beta.15", path = "./rust/lance-file" } -lance-geo = { version = "=7.0.0-beta.15", path = "./rust/lance-geo" } -lance-index = { version = "=7.0.0-beta.15", path = "./rust/lance-index" } -lance-io = { version = "=7.0.0-beta.15", path = "./rust/lance-io", default-features = false } -lance-linalg = { version = "=7.0.0-beta.15", path = "./rust/lance-linalg" } -lance-namespace = { version = "=7.0.0-beta.15", path = "./rust/lance-namespace" } -lance-namespace-impls = { version = "=7.0.0-beta.15", path = "./rust/lance-namespace-impls" } +lance = { version = "=7.1.0-beta.1", path = "./rust/lance", default-features = false } +lance-arrow = { version = "=7.1.0-beta.1", path = "./rust/lance-arrow" } +lance-core = { version = "=7.1.0-beta.1", path = "./rust/lance-core" } +lance-datafusion = { version = "=7.1.0-beta.1", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=7.1.0-beta.1", path = "./rust/lance-datagen" } +lance-encoding = { version = "=7.1.0-beta.1", path = "./rust/lance-encoding" } +lance-file = { version = "=7.1.0-beta.1", path = "./rust/lance-file" } +lance-geo = { version = "=7.1.0-beta.1", path = "./rust/lance-geo" } +lance-index = { version = "=7.1.0-beta.1", path = "./rust/lance-index" } +lance-io = { version = "=7.1.0-beta.1", path = "./rust/lance-io", default-features = false } +lance-linalg = { version = "=7.1.0-beta.1", path = "./rust/lance-linalg" } +lance-namespace = { version = "=7.1.0-beta.1", path = "./rust/lance-namespace" } +lance-namespace-impls = { version = "=7.1.0-beta.1", path = "./rust/lance-namespace-impls" } lance-namespace-datafusion = { version = "=7.0.0-beta.9", path = "./rust/lance-namespace-datafusion" } -lance-namespace-reqwest-client = "0.7.5" -lance-tokenizer = { version = "=7.0.0-beta.15", path = "./rust/lance-tokenizer" } -lance-table = { version = "=7.0.0-beta.15", path = "./rust/lance-table" } -lance-test-macros = { version = "=7.0.0-beta.15", path = "./rust/lance-test-macros" } -lance-testing = { version = "=7.0.0-beta.15", path = "./rust/lance-testing" } +lance-namespace-reqwest-client = "0.7.7" +lance-tokenizer = { version = "=7.1.0-beta.1", path = "./rust/lance-tokenizer" } +lance-table = { version = "=7.1.0-beta.1", path = "./rust/lance-table" } +lance-test-macros = { version = "=7.1.0-beta.1", path = "./rust/lance-test-macros" } +lance-testing = { version = "=7.1.0-beta.1", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "58.0.0", optional = false, features = ["prettyprint"] } @@ -99,7 +99,7 @@ half = { "version" = "2.1", default-features = false, features = [ "num-traits", "std", ] } -lance-bitpacking = { version = "=7.0.0-beta.15", path = "./rust/compression/bitpacking" } +lance-bitpacking = { version = "=7.1.0-beta.1", path = "./rust/compression/bitpacking" } bitpacking = "0.9" bitvec = "1" bytes = "1.11.1" @@ -139,7 +139,7 @@ deepsize = "0.2.0" dirs = "6.0.0" either = "1.0" fst = { version = "0.4.7", features = ["levenshtein"] } -fsst = { version = "=7.0.0-beta.15", path = "./rust/compression/fsst" } +fsst = { version = "=7.1.0-beta.1", path = "./rust/compression/fsst" } futures = "0.3" geoarrow-array = "0.8" geoarrow-schema = "0.8" diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 7ef4e59666c..4112230aec5 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -3,7 +3,7 @@ name = "lance-docs" version = "0.1.0" description = "Documentation for Lance - Modern columnar data format for ML and LLMs" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10,<3.11" dependencies = [ "mkdocs>=1.5.0", "mkdocs-material>=9.4.0", diff --git a/docs/src/format/index/system/mem_wal.md b/docs/src/format/index/system/mem_wal.md index ac8696aa5cc..44515f0af28 100644 --- a/docs/src/format/index/system/mem_wal.md +++ b/docs/src/format/index/system/mem_wal.md @@ -4,6 +4,8 @@ The MemWAL Index is a system index that serves as the centralized structure for It stores configuration (shard specs, indexes to maintain), merge progress, and shard state snapshots. A table has at most one MemWAL index. +The table may be a primary-key table or an append-only table without primary-key metadata. +Primary-key-dependent lookup and deduplication semantics only apply when a primary key is defined. For the complete specification, see: diff --git a/docs/src/format/table/layout.md b/docs/src/format/table/layout.md index 46efa56a908..7b08ce0a0dd 100644 --- a/docs/src/format/table/layout.md +++ b/docs/src/format/table/layout.md @@ -20,7 +20,8 @@ A Lance dataset in its basic form stores all files within the dataset root direc data/ *.lance -- Data files containing column data _versions/ - *.manifest -- Manifest files (one per version) + *.manifest -- Manifest files (one per version) + latest_version_hint.json -- Optional hint of the latest version (see below) _transactions/ *.txn -- Transaction files for commit coordination _deletions/ @@ -201,3 +202,15 @@ Manifest files are stored in the `_versions/` directory with naming schemes that See [Manifest Naming Schemes](transaction.md#manifest-naming-schemes) for details on the V1 and V2 patterns and their implications for version discovery. +### Version Hint + +The optional file `_versions/latest_version_hint.json` records the latest committed version as JSON: + +```json +{"version": 42} +``` + +It exists to accelerate latest-version discovery on stores where listing `_versions/` is expensive: a reader can read the hint and probe higher versions with HEAD requests instead of listing the whole directory, falling back to a full listing if the hint is missing or stale. + +The hint is purely an optimization. It is always safe to delete, never affects correctness, and can be ignored by readers that don't understand it. Writers may choose not to write it. + diff --git a/docs/src/format/table/mem_wal.md b/docs/src/format/table/mem_wal.md index 892b38d1d01..8a228721123 100644 --- a/docs/src/format/table/mem_wal.md +++ b/docs/src/format/table/mem_wal.md @@ -8,7 +8,9 @@ scan, point lookup, vector search and full-text search. ![MemWAL Overview](../../images/mem_wal_overview.png) A Lance table is called a **base table** under the context of the MemWAL spec. -It must have an [unenforced primary key](index.md#unenforced-primary-key) defined in the table schema. +It may have an [unenforced primary key](index.md#unenforced-primary-key) defined in the table schema. +Primary keys are required for primary-key lookups and last-write-wins upsert semantics, +but append-only MemWAL tables may omit them. On top of the base table, the MemWAL spec defines a set of shards. Writers write to shards, and data in each shard is merged into the base table asynchronously. @@ -22,7 +24,7 @@ Each shard has exactly one active writer at any time. Writers claim a shard and then write data to that shard. Data in each shard is expected to be merged into the base table asynchronously. -Rows of the same primary key must be written to one and only one shard. +For tables with a primary key, rows of the same primary key must be written to one and only one shard. If two shards contain rows with the same primary key, the following scenario can cause data corruption: 1. Shard A receives a write with primary key `pk=1` at time T1 @@ -34,6 +36,8 @@ If two shards contain rows with the same primary key, the following scenario can This violates the expected "last write wins" semantics. By ensuring each primary key is assigned to exactly one shard via the sharding spec, merge order between shards becomes irrelevant for correctness. +Append-only tables without a primary key do not rely on last-write-wins conflict resolution +and may shard by any deterministic append key or partitioning column. See [MemWAL Shard Architecture](#shard-architecture) for the complete shard architecture. @@ -162,8 +166,9 @@ The content within the generation directory follows the [Lance table storage lay Generation numbers determine merge order of flushed MemTable into base table: lower numbers represent older data and must be merged to the base table first to preserve correct upsert semantics. -Within a single flushed MemTable, if there are multiple rows of the same primary key, -the row that is last inserted wins. +Within a single flushed MemTable for a primary-key table, +if there are multiple rows of the same primary key, the row that is last inserted wins. +Append-only tables without a primary key retain all inserted rows. ### Shard Manifest @@ -479,7 +484,8 @@ The garbage collector removes obsolete data from shard directories. Flushed MemT ### LSM Tree Merging Read -Readers **MUST** merge results from multiple data sources (base table, flushed MemTables, in-memory MemTables) by primary key to ensure correctness. +For tables with a primary key, readers **MUST** merge results from multiple data sources +(base table, flushed MemTables, in-memory MemTables) by primary key to ensure correctness. When the same primary key exists in multiple sources, the reader must keep only the newest version based on: @@ -496,6 +502,10 @@ This deduplication is essential because: Without proper merging, queries would return duplicate or stale rows. +Append-only tables without a primary key do not perform primary-key deduplication. +Readers should include the relevant base table, flushed MemTables, and in-memory MemTables +according to the requested consistency level; duplicate values are treated as distinct appended rows. + ### Reader Consistency Reader consistency depends on two factors: @@ -526,7 +536,9 @@ Datasets come from: Each dataset is tagged with a generation number: 0 for the base table, and positive integers for MemTable generations. Within a shard, the generation number determines data freshness, with higher numbers representing newer data. -Rows from different shards do not need deduplication since each primary key maps to exactly one shard. +For primary-key tables, rows from different shards do not need deduplication +since each primary key maps to exactly one shard. +Append-only tables without a primary key do not require cross-shard primary-key deduplication. The planner also collects bloom filters from each generation for staleness detection during search queries. diff --git a/docs/src/guide/performance.md b/docs/src/guide/performance.md index a25e6a1584d..5590e3ff571 100644 --- a/docs/src/guide/performance.md +++ b/docs/src/guide/performance.md @@ -21,9 +21,14 @@ The Python/Java logger can be configured with several environment variables: ## Trace Events -Lance uses tracing to log events. If you are running `pylance` then these events will be emitted to +Lance uses tracing to log events. If you are running `pylance` then these events will be emitted as log messages. For Rust connections you can use the `tracing` crate to capture these events. +Rust tracing targets are listed below. In `pylance` logs, trace events are emitted under a +`lance::events::` prefix so they can be filtered separately from normal log records. For example, +`LANCE_LOG="warn,lance::events::object_store::throttle=info"` shows storage throttling events +without enabling other Lance event logs. + ### File Audit File audit events are emitted when significant files are created or deleted. @@ -33,6 +38,31 @@ File audit events are emitted when significant files are created or deleted. | `lance::file_audit` | `mode` | The mode of I/O operation (create, delete, delete_unverified) | | `lance::file_audit` | `type` | The type of file affected (manifest, data file, index file, deletion file) | +### Dataset Events + +Dataset events are emitted when datasets are loaded, written, committed, deleted, compacted, or cleaned. + +| Event | Parameter | Description | +| ----------------------- | ----------- | ------------------------------------------------------------------------- | +| `lance::dataset_events` | `event` | The dataset event type (loading, writing, committed, deleting, and others) | +| `lance::dataset_events` | `uri` | The dataset URI | +| `lance::dataset_events` | `mode` | The write mode | +| `lance::dataset_events` | `operation` | The committed transaction operation | +| `lance::dataset_events` | `predicate` | The delete predicate | +| `lance::dataset_events` | `columns` | The removed columns | + +### Object Store Throttle Events + +Object store throttle events are emitted when Lance observes cloud storage throttle responses and +reduces or retries request rates. + +| Event | Parameter | Description | +| -------------------------------- | --------------- | ---------------------------------------- | +| `lance::object_store::throttle` | `previous_rate` | The request rate before AIMD adjustment | +| `lance::object_store::throttle` | `new_rate` | The request rate after AIMD adjustment | +| `lance::object_store::throttle` | `attempt` | The retry attempt for retry debug events | +| `lance::object_store::throttle` | `error` | The underlying object store throttle error | + ### I/O Events I/O events are emitted when significant I/O operations are performed, particularly diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index 61d7f1d7987..0f40cdeb3d7 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -2344,9 +2344,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encoding_rs" @@ -2509,7 +2509,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3616,7 +3616,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arc-swap", "arrow", @@ -3663,7 +3663,9 @@ dependencies = [ "lance-table", "lance-tokenizer", "log", + "moka", "object_store", + "parquet", "permutation", "pin-project", "prost", @@ -3686,7 +3688,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -3706,7 +3708,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrayref", "paste", @@ -3715,7 +3717,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -3750,7 +3752,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -3782,7 +3784,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -3800,7 +3802,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -3835,7 +3837,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -3866,7 +3868,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "datafusion", "geo-traits", @@ -3880,7 +3882,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arc-swap", "arrow", @@ -3947,7 +3949,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-arith", @@ -3989,7 +3991,7 @@ dependencies = [ [[package]] name = "lance-jni" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4025,7 +4027,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4041,7 +4043,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "async-trait", @@ -4053,7 +4055,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-ipc", @@ -4083,9 +4085,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4097,7 +4099,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4134,7 +4136,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "rust-stemmers", "serde", diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 1cbee8d10e1..434159a57ef 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lance-jni" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" edition = "2024" authors = ["Lance Devs "] rust-version = "1.91" diff --git a/java/lance-jni/src/external_index.rs b/java/lance-jni/src/external_index.rs new file mode 100644 index 00000000000..97cc011f81f --- /dev/null +++ b/java/lance-jni/src/external_index.rs @@ -0,0 +1,437 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! JNI bindings for [`lance::index::vector::external::ExternalIvfPqIndex`]. +//! +//! Java surface (corresponds to `org.lance.index.external.ExternalIvfPqIndex`): +//! +//! - `nativeBuild(...) -> String` (returns the index UUID directory name) +//! - `nativeOpen(uri) -> long` (opaque handle) +//! - `nativeClose(handle)` +//! - `nativeSearch(handle, query, k, nprobes, refineFactor, deletedBitmap) -> SearchResult[]` +//! - `nativeFetchRows(handle, filePaths, rowIndices, projection) -> Arrow IPC bytes` +//! +//! `SearchResult[]` is returned via Java object construction. `nativeFetchRows` +//! returns Arrow IPC stream bytes so the caller can decode with their preferred +//! Arrow Java reader without us needing to bridge `RecordBatch` directly. +//! +//! The RowFilter API is exposed as an optional `byte[] deletedBitmap` (LE +//! little-endian Roaring bitmap). Rows whose `(file_id << 32) | row_index` is +//! present are dropped during refinement. This sidesteps cross-language +//! callbacks for Phase 1; the trait-based `RowFilter` stays available for Rust +//! callers. + +use std::sync::Arc; + +use arrow::ipc::writer::StreamWriter; +use jni::JNIEnv; +use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString, JValue}; +use jni::sys::{jbyteArray, jfloat, jint, jlong, jlongArray}; +use lance::index::vector::external::{ + ExternalIvfPqIndex, ExternalIvfPqIndexParams, ParquetFileSpec, ParquetRowKey, RowFilter, + SearchResult, +}; +use lance_linalg::distance::MetricType; + +use crate::error::{Error, Result}; +use crate::traits::FromJString; +use crate::RT; + +/// RowFilter implementation that holds a sorted list of deleted rids. Caller +/// passes the rids as a packed `(file_id << 32) | row_index` u64 array. We +/// resolve the file_id via the index manifest at search time. +struct DeletedRidFilter { + /// Sorted, deduped deleted rids. Binary search per refinement candidate. + deleted: Vec, + /// Path → file_id index for fast lookup. + file_id_by_path: std::collections::HashMap, +} + +impl RowFilter for DeletedRidFilter { + fn keep(&self, file_path: &str, row_index: u64) -> bool { + let Some(&file_id) = self.file_id_by_path.get(file_path) else { + // unknown file → keep (search would have errored before refinement + // on an unknown file anyway) + return true; + }; + let rid = ((file_id as u64) << 32) | row_index; + self.deleted.binary_search(&rid).is_err() + } +} + +/// Build a non-empty filter from a Java `byte[]` of u64-LE deleted rids. +/// Returns `Ok(None)` when the byte array is null or empty. +fn build_filter_from_bytes( + env: &mut JNIEnv, + deleted_bytes: &JByteArray, + file_id_by_path: std::collections::HashMap, +) -> Result> { + if deleted_bytes.is_null() { + return Ok(None); + } + let bytes = env.convert_byte_array(deleted_bytes)?; + if bytes.is_empty() { + return Ok(None); + } + if bytes.len() % 8 != 0 { + return Err(Error::input_error(format!( + "deletedBitmap byte length {} is not a multiple of 8 (u64 LE)", + bytes.len() + ))); + } + let mut deleted: Vec = Vec::with_capacity(bytes.len() / 8); + for chunk in bytes.chunks_exact(8) { + let arr: [u8; 8] = chunk.try_into().expect("chunks_exact 8"); + deleted.push(u64::from_le_bytes(arr)); + } + deleted.sort_unstable(); + deleted.dedup(); + Ok(Some(DeletedRidFilter { + deleted, + file_id_by_path, + })) +} + +fn parse_metric(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "l2" => Ok(MetricType::L2), + "cosine" => Ok(MetricType::Cosine), + "dot" => Ok(MetricType::Dot), + other => Err(Error::input_error(format!( + "unsupported metric '{other}'; expected one of L2, Cosine, Dot" + ))), + } +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeBuild<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + file_paths: JObjectArray<'local>, + vector_column: JString<'local>, + output_uri: JString<'local>, + num_partitions: jint, + num_sub_vectors: jint, + num_bits_per_sub_vector: jint, + metric: JString<'local>, + max_iters: jint, + sample_rate: jint, + seed: jlong, +) -> JObject<'local> { + ok_or_throw!( + env, + inner_build( + &mut env, + file_paths, + vector_column, + output_uri, + num_partitions, + num_sub_vectors, + num_bits_per_sub_vector, + metric, + max_iters, + sample_rate, + seed, + ) + ) +} + +#[allow(clippy::too_many_arguments)] +fn inner_build<'local>( + env: &mut JNIEnv<'local>, + file_paths: JObjectArray<'local>, + vector_column: JString<'local>, + output_uri: JString<'local>, + num_partitions: jint, + num_sub_vectors: jint, + num_bits_per_sub_vector: jint, + metric: JString<'local>, + max_iters: jint, + sample_rate: jint, + seed: jlong, +) -> Result> { + let n = env.get_array_length(&file_paths)?; + let mut files: Vec = Vec::with_capacity(n as usize); + for i in 0..n { + let elem: JString = env.get_object_array_element(&file_paths, i)?.into(); + let path: String = elem.extract(env)?; + files.push(ParquetFileSpec::of(path)); + } + + let vector_column_str: String = vector_column.extract(env)?; + let output_uri_str: String = output_uri.extract(env)?; + let metric_str: String = metric.extract(env)?; + + let params = ExternalIvfPqIndexParams::builder() + .num_partitions(num_partitions as usize) + .num_sub_vectors(num_sub_vectors as usize) + .num_bits_per_sub_vector(num_bits_per_sub_vector as usize) + .metric(parse_metric(&metric_str)?) + .max_iters(max_iters as usize) + .sample_rate(sample_rate as usize) + .seed(seed as u64) + .build(); + + let uuid = RT.block_on(async move { + ExternalIvfPqIndex::build(files, &vector_column_str, &output_uri_str, params).await + })?; + + let uuid_str = uuid.to_string(); + let j = env.new_string(&uuid_str)?; + Ok(j.into()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeOpen<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + uri: JString<'local>, +) -> jlong { + ok_or_throw_with_return!(env, inner_open(&mut env, uri), 0i64) +} + +fn inner_open(env: &mut JNIEnv, uri: JString) -> Result { + let uri_str: String = uri.extract(env)?; + let idx = RT.block_on(async move { ExternalIvfPqIndex::open(&uri_str).await })?; + let boxed = Box::new(Arc::new(idx)); + Ok(Box::into_raw(boxed) as jlong) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeClose( + _env: JNIEnv, + _class: JClass, + handle: jlong, +) { + if handle == 0 { + return; + } + // SAFETY: handle came from a previous Box::into_raw on an Arc; + // taking it back into a Box drops the Arc. + unsafe { + let _ = Box::from_raw(handle as *mut Arc); + } +} + +fn handle_to_idx(handle: jlong) -> Result> { + if handle == 0 { + return Err(Error::input_error( + "ExternalIvfPqIndex handle is null (closed?)".to_string(), + )); + } + // SAFETY: handle is a non-null pointer from Box::into_raw; cloning the Arc + // inside doesn't take ownership. + let arc_ref = unsafe { &*(handle as *const Arc) }; + Ok(arc_ref.clone()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeSearch<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + query: jni::objects::JFloatArray<'local>, + k: jint, + nprobes: jint, + refine_factor: jint, + deleted_bitmap: JByteArray<'local>, +) -> JObject<'local> { + ok_or_throw!( + env, + inner_search( + &mut env, + handle, + query, + k, + nprobes, + refine_factor, + deleted_bitmap, + ) + ) +} + +fn inner_search<'local>( + env: &mut JNIEnv<'local>, + handle: jlong, + query: jni::objects::JFloatArray<'local>, + k: jint, + nprobes: jint, + refine_factor: jint, + deleted_bitmap: JByteArray<'local>, +) -> Result> { + let idx = handle_to_idx(handle)?; + let q_len = env.get_array_length(&query)?; + let mut q_buf: Vec = vec![0.0; q_len as usize]; + env.get_float_array_region(&query, 0, &mut q_buf)?; + + // Build path → file_id map for the optional filter. + let mut file_id_by_path: std::collections::HashMap = + std::collections::HashMap::with_capacity(idx.num_files()); + for fid in 0..idx.num_files() as u32 { + if let Some(p) = idx.file_path(fid) { + file_id_by_path.insert(p.to_string(), fid); + } + } + let filter = build_filter_from_bytes(env, &deleted_bitmap, file_id_by_path)?; + + let results: Vec = RT.block_on(async { + idx.search( + &q_buf, + k as usize, + nprobes as usize, + refine_factor as usize, + filter.as_ref().map(|f| f as &dyn RowFilter), + ) + .await + })?; + + // Build SearchResult[] in Java. + let result_class = env.find_class("org/lance/index/external/SearchResult")?; + let array = env.new_object_array(results.len() as i32, &result_class, JObject::null())?; + for (i, r) in results.iter().enumerate() { + let path = env.new_string(&r.file_path)?; + let obj = env.new_object( + &result_class, + "(Ljava/lang/String;JF)V", + &[ + JValue::Object(&path), + JValue::Long(r.row_index as i64), + JValue::Float(r.distance), + ], + )?; + env.set_object_array_element(&array, i as i32, obj)?; + } + Ok(array.into()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeFetchRows<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + file_paths: JObjectArray<'local>, + row_indices: jlongArray, + projection: JObjectArray<'local>, +) -> jbyteArray { + ok_or_throw_with_return!( + env, + inner_fetch_rows(&mut env, handle, file_paths, row_indices, projection), + std::ptr::null_mut::() as jbyteArray + ) +} + +fn inner_fetch_rows<'local>( + env: &mut JNIEnv<'local>, + handle: jlong, + file_paths: JObjectArray<'local>, + row_indices: jlongArray, + projection: JObjectArray<'local>, +) -> Result { + let idx = handle_to_idx(handle)?; + + // Decode file_paths + let n_paths = env.get_array_length(&file_paths)?; + let mut paths: Vec = Vec::with_capacity(n_paths as usize); + for i in 0..n_paths { + let elem: JString = env.get_object_array_element(&file_paths, i)?.into(); + let path: String = elem.extract(env)?; + paths.push(path); + } + + // Decode row_indices + let row_indices_obj = unsafe { jni::objects::JLongArray::from_raw(row_indices) }; + let n_rows = env.get_array_length(&row_indices_obj)?; + if n_rows != n_paths { + return Err(Error::input_error(format!( + "fetchRows: row_indices length {n_rows} != file_paths length {n_paths}" + ))); + } + let mut rids_i64: Vec = vec![0; n_rows as usize]; + env.get_long_array_region(&row_indices_obj, 0, &mut rids_i64)?; + + let row_keys: Vec = paths + .into_iter() + .zip(rids_i64.iter().map(|&v| v as u64)) + .map(|(p, r)| ParquetRowKey::of(p, r)) + .collect(); + + // Decode projection + let n_proj = env.get_array_length(&projection)?; + let mut proj_strings: Vec = Vec::with_capacity(n_proj as usize); + for i in 0..n_proj { + let elem: JString = env.get_object_array_element(&projection, i)?.into(); + let s: String = elem.extract(env)?; + proj_strings.push(s); + } + let proj_refs: Vec<&str> = proj_strings.iter().map(|s| s.as_str()).collect(); + + // Run fetch + let batch = RT.block_on(async { idx.fetch_rows(&row_keys, &proj_refs).await })?; + + // Serialize to Arrow IPC stream bytes. + let mut buf: Vec = Vec::with_capacity(8 * 1024); + { + let mut writer = StreamWriter::try_new(&mut buf, &batch.schema()) + .map_err(|e| Error::io_error(format!("ipc writer init: {e}")))?; + writer + .write(&batch) + .map_err(|e| Error::io_error(format!("ipc write: {e}")))?; + writer + .finish() + .map_err(|e| Error::io_error(format!("ipc finish: {e}")))?; + } + let jbyte_slice: &[jni::sys::jbyte] = unsafe { + std::slice::from_raw_parts(buf.as_ptr() as *const jni::sys::jbyte, buf.len()) + }; + let array = env.new_byte_array(buf.len() as i32)?; + env.set_byte_array_region(&array, 0, jbyte_slice)?; + Ok(array.into_raw()) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeNumPartitions( + _env: JNIEnv, + _class: JClass, + handle: jlong, +) -> jint { + let idx = match handle_to_idx(handle) { + Ok(i) => i, + Err(_) => return -1, + }; + idx.num_partitions() as jint +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeNumFiles( + _env: JNIEnv, + _class: JClass, + handle: jlong, +) -> jint { + let idx = match handle_to_idx(handle) { + Ok(i) => i, + Err(_) => return -1, + }; + idx.num_files() as jint +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_index_external_ExternalIvfPqIndex_nativeVectorColumn<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, +) -> JObject<'local> { + let idx = match handle_to_idx(handle) { + Ok(i) => i, + Err(e) => { + e.throw(&mut env); + return JObject::null(); + } + }; + match env.new_string(idx.vector_column()) { + Ok(s) => s.into(), + Err(e) => { + Error::runtime_error(format!("new_string: {e}")).throw(&mut env); + JObject::null() + } + } +} diff --git a/java/lance-jni/src/fragment.rs b/java/lance-jni/src/fragment.rs index 5e2812ddd75..a6798c2f237 100644 --- a/java/lance-jni/src/fragment.rs +++ b/java/lance-jni/src/fragment.rs @@ -101,6 +101,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiArray<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -120,6 +123,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiArray<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -142,6 +148,9 @@ fn inner_create_with_ffi_array<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -170,6 +179,9 @@ fn inner_create_with_ffi_array<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -191,6 +203,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiStream<'a>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -209,6 +224,9 @@ pub extern "system" fn Java_org_lance_Fragment_createWithFfiStream<'a>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -230,6 +248,9 @@ fn inner_create_with_ffi_stream<'local>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -248,6 +269,9 @@ fn inner_create_with_ffi_stream<'local>( enable_stable_row_ids, data_storage_version, storage_options_obj, + base_store_params_obj, + initial_bases, + target_bases, namespace_obj, table_id_obj, allow_external_blob_outside_bases, @@ -267,6 +291,9 @@ fn create_fragment<'a>( enable_stable_row_ids: JObject, // Optional data_storage_version: JObject, // Optional storage_options_obj: JObject, // Map + base_store_params_obj: JObject, // Map> + initial_bases: JObject, // Optional> + target_bases: JObject, // Optional> namespace_obj: JObject, // LanceNamespace (can be null) table_id_obj: JObject, // List (can be null) allow_external_blob_outside_bases: JObject, // Optional @@ -285,9 +312,9 @@ fn create_fragment<'a>( &data_storage_version, None, &storage_options_obj, - &JObject::null(), // base store params are not used when creating fragments - &JObject::null(), // not used when creating fragments - &JObject::null(), // not used when creating fragments + &base_store_params_obj, + &initial_bases, + &target_bases, &allow_external_blob_outside_bases, &blob_pack_file_size_threshold, )?; diff --git a/java/lance-jni/src/lib.rs b/java/lance-jni/src/lib.rs index 622ce5e8785..dc0221b30bc 100644 --- a/java/lance-jni/src/lib.rs +++ b/java/lance-jni/src/lib.rs @@ -46,6 +46,7 @@ mod blocking_scanner; mod delta; mod dispatcher; pub mod error; +mod external_index; pub mod ffi; mod file_reader; mod file_writer; diff --git a/java/lance-jni/src/mem_wal.rs b/java/lance-jni/src/mem_wal.rs index d5f9da750ab..f62b0274c8e 100644 --- a/java/lance-jni/src/mem_wal.rs +++ b/java/lance-jni/src/mem_wal.rs @@ -7,6 +7,7 @@ //! [`ShardWriter`], an LSM-aware [`LsmScanner`], an [`ExecutionPlan`] wrapper, //! and the point-lookup / vector-search planners. +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -21,7 +22,7 @@ use datafusion::common::ScalarValue; use datafusion::physical_plan::{ExecutionPlan, collect, displayable}; use datafusion::prelude::SessionContext; use jni::JNIEnv; -use jni::objects::{JClass, JObject, JString, JValueGen}; +use jni::objects::{JClass, JMap, JObject, JString, JValueGen}; use jni::sys::{jint, jlong}; use lance::dataset::Dataset as LanceDataset; use lance::dataset::mem_wal::scanner::{ @@ -30,6 +31,7 @@ use lance::dataset::mem_wal::scanner::{ use lance::dataset::mem_wal::write::{MemTableStats, WriteStatsSnapshot}; use lance::dataset::mem_wal::{ DatasetMemWalExt, LsmScanner, ShardSnapshot, ShardWriter, ShardWriterConfig, + evaluate_sharding_spec_with_source_columns, }; use lance::dataset::scanner::DatasetRecordBatchStream; use lance_index::mem_wal::{MemWalIndexDetails, ShardManifest, ShardingField, ShardingSpec}; @@ -42,6 +44,7 @@ use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; use crate::traits::{IntoJava, export_vec, import_vec, import_vec_to_rust}; +use crate::utils::to_rust_map; const NATIVE_SHARD_WRITER: &str = "nativeShardWriterHandle"; const NATIVE_LSM_SCANNER: &str = "nativeLsmScannerHandle"; @@ -599,6 +602,63 @@ fn inner_plan_open_stream(env: &mut JNIEnv, this: JObject, stream_addr: jlong) - Ok(()) } +////////////////////////// +// Sharding evaluation // +////////////////////////// + +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_memwal_ShardingEvaluator_nativeEvaluate<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + arrow_array_addr: jlong, + arrow_schema_addr: jlong, + sharding_spec: JObject<'local>, + source_id_to_column: JObject<'local>, + stream_addr: jlong, +) { + ok_or_throw_without_return!( + env, + inner_evaluate_sharding( + &mut env, + arrow_array_addr, + arrow_schema_addr, + sharding_spec, + source_id_to_column, + stream_addr, + ) + ); +} + +fn inner_evaluate_sharding( + env: &mut JNIEnv, + arrow_array_addr: jlong, + arrow_schema_addr: jlong, + sharding_spec: JObject, + source_id_to_column: JObject, + stream_addr: jlong, +) -> Result<()> { + let input = import_ffi_array(arrow_array_addr, arrow_schema_addr)?; + let struct_array = input + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::input_error( + "ShardingEvaluator expects a VectorSchemaRoot struct array".to_string(), + ) + })?; + let input_batch = RecordBatch::from(struct_array.clone()); + let spec = sharding_spec_from_java(env, &sharding_spec)?; + let source_id_to_column = i32_string_map_from_java(env, source_id_to_column)?; + let result = + evaluate_sharding_spec_with_source_columns(&input_batch, &spec, &source_id_to_column)?; + let schema = result.schema(); + let reader: Box = + Box::new(RecordBatchIterator::new([Ok(result)].into_iter(), schema)); + let ffi_stream = FFI_ArrowArrayStream::new(reader); + unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } + Ok(()) +} + #[unsafe(no_mangle)] pub extern "system" fn Java_org_lance_memwal_ExecutionPlan_releaseNativeExecutionPlan( mut env: JNIEnv, @@ -767,14 +827,15 @@ fn inner_create_vector_planner( let dist_type = parse_distance_type(distance_type.as_deref().unwrap_or("l2"))?; let vector_dim = get_vector_dim(&dataset, &vector_column)?; - let collector = LsmDataSourceCollector::new(dataset, snapshots); + let collector = LsmDataSourceCollector::new(dataset.clone(), snapshots); let planner = LsmVectorSearchPlanner::new( collector, pk_columns, base_schema.clone(), vector_column, dist_type, - ); + ) + .with_dataset(dataset); let blocking = BlockingLsmVectorSearchPlanner { planner, @@ -794,10 +855,25 @@ pub extern "system" fn Java_org_lance_memwal_LsmVectorSearchPlanner_nativePlanSe k: jint, nprobes: jint, columns: JObject<'local>, + refine_factor: jint, ) -> JObject<'local> { + let refine = if refine_factor > 0 { + Some(refine_factor as u32) + } else { + None + }; ok_or_throw!( env, - inner_plan_search(&mut env, this, array_addr, schema_addr, k, nprobes, columns) + inner_plan_search( + &mut env, + this, + array_addr, + schema_addr, + k, + nprobes, + columns, + refine + ) ) } @@ -810,6 +886,7 @@ fn inner_plan_search<'local>( k: jint, nprobes: jint, columns: JObject<'local>, + refine_factor: Option, ) -> Result> { let query = import_ffi_array(array_addr, schema_addr)?; let columns = env.get_strings_opt(&columns)?; @@ -841,6 +918,7 @@ fn inner_plan_search<'local>( k as usize, nprobes as usize, columns.as_deref(), + refine_factor, ))?; (plan, guard.dataset_schema.clone()) }; @@ -1301,6 +1379,87 @@ fn sharding_field_to_java<'a>(env: &mut JNIEnv<'a>, field: &ShardingField) -> Re )?) } +fn sharding_spec_from_java(env: &mut JNIEnv, spec: &JObject) -> Result { + if spec.is_null() { + return Err(Error::input_error( + "ShardingSpec must not be null".to_string(), + )); + } + + env.with_local_frame(32, |env| { + let spec_id = env.call_method(spec, "specId", "()I", &[])?.i()? as u32; + let fields_obj = env + .call_method(spec, "fields", "()Ljava/util/List;", &[])? + .l()?; + let fields_list = env.get_list(&fields_obj)?; + let mut iter = fields_list.iter(env)?; + let mut fields = Vec::with_capacity(fields_list.size(env)? as usize); + while let Some(field_obj) = iter.next(env)? { + fields.push(sharding_field_from_java(env, &field_obj)?); + } + + Ok::<_, Error>(ShardingSpec { spec_id, fields }) + }) +} + +fn sharding_field_from_java(env: &mut JNIEnv, field: &JObject) -> Result { + env.with_local_frame(32, |env| { + let field_id = string_from_method(env, field, "fieldId")?; + let source_ids_obj = env + .call_method(field, "sourceIds", "()Ljava/util/List;", &[])? + .l()?; + let source_ids = env.get_integers(&source_ids_obj)?; + let transform = env.get_optional_string_from_method(field, "transform")?; + let expression = env.get_optional_string_from_method(field, "expression")?; + let result_type = string_from_method(env, field, "resultType")?; + let parameters_obj = env + .call_method(field, "parameters", "()Ljava/util/Map;", &[])? + .l()?; + let parameters = if parameters_obj.is_null() { + HashMap::new() + } else { + let parameters_map = JMap::from_env(env, ¶meters_obj)?; + to_rust_map(env, ¶meters_map)? + }; + + Ok::<_, Error>(ShardingField { + field_id, + source_ids, + transform, + expression, + result_type, + parameters, + }) + }) +} + +fn i32_string_map_from_java(env: &mut JNIEnv, map_obj: JObject) -> Result> { + if map_obj.is_null() { + return Ok(HashMap::new()); + } + + env.with_local_frame(32, |env| { + let jmap = JMap::from_env(env, &map_obj)?; + let mut iter = jmap.iter(env)?; + let mut map = HashMap::new(); + while let Some((key, value)) = iter.next(env)? { + let key = env.call_method(&key, "intValue", "()I", &[])?.i()?; + let value = JString::from(value); + let value: String = env.get_string(&value)?.into(); + map.insert(key, value); + } + Ok::<_, Error>(map) + }) +} + +fn string_from_method(env: &mut JNIEnv, obj: &JObject, method: &str) -> Result { + let value = env + .call_method(obj, method, "()Ljava/lang/String;", &[])? + .l()?; + let value = JString::from(value); + Ok(env.get_string(&value)?.into()) +} + fn int_list_to_java<'a>(env: &mut JNIEnv<'a>, ints: &[i32]) -> Result> { let list = env.new_object("java/util/ArrayList", "()V", &[])?; for &value in ints { diff --git a/java/lance-jni/src/namespace.rs b/java/lance-jni/src/namespace.rs index 39acefa83c4..f0da7ff79ae 100644 --- a/java/lance-jni/src/namespace.rs +++ b/java/lance-jni/src/namespace.rs @@ -2432,6 +2432,23 @@ pub extern "system" fn Java_org_lance_namespace_DirectoryNamespace_updateTableTa .into_raw() } +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_namespace_DirectoryNamespace_createMaterializedViewNative( + mut env: JNIEnv, + _obj: JObject, + handle: jlong, + request_json: JString, +) -> jstring { + ok_or_throw_with_return!( + env, + call_namespace_method(&mut env, handle, request_json, |namespace_client, req| { + RT.block_on(namespace_client.inner.create_materialized_view(req)) + }), + std::ptr::null_mut() + ) + .into_raw() +} + #[unsafe(no_mangle)] pub extern "system" fn Java_org_lance_namespace_DirectoryNamespace_retrieveOpsMetricsNative( mut env: JNIEnv, @@ -3375,6 +3392,23 @@ pub extern "system" fn Java_org_lance_namespace_RestNamespace_updateTableTagNati .into_raw() } +#[unsafe(no_mangle)] +pub extern "system" fn Java_org_lance_namespace_RestNamespace_createMaterializedViewNative( + mut env: JNIEnv, + _obj: JObject, + handle: jlong, + request_json: JString, +) -> jstring { + ok_or_throw_with_return!( + env, + call_rest_namespace_method(&mut env, handle, request_json, |namespace_client, req| { + RT.block_on(namespace_client.inner.create_materialized_view(req)) + }), + std::ptr::null_mut() + ) + .into_raw() +} + #[unsafe(no_mangle)] pub extern "system" fn Java_org_lance_namespace_RestNamespace_retrieveOpsMetricsNative( mut env: JNIEnv, diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index c8a63c8672b..1321e8e71e3 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -477,6 +477,7 @@ pub fn get_vector_index_params( stages, version: IndexFileVersion::V3, skip_transpose: false, + runtime_hints: Default::default(), }) }, )?; diff --git a/java/pom.xml b/java/pom.xml index 9b44c467605..4fcb5e2a9de 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -7,7 +7,7 @@ org.lance lance-core Lance Core - 7.0.0-beta.15 + 7.1.0-beta.1 jar Lance Format Java API @@ -109,12 +109,12 @@ org.lance lance-namespace-core - 0.7.5 + 0.7.7 org.lance lance-namespace-apache-client - 0.7.5 + 0.7.7 com.fasterxml.jackson.core diff --git a/java/src/main/java/org/lance/Dataset.java b/java/src/main/java/org/lance/Dataset.java index 1647204582b..02d61fb708b 100644 --- a/java/src/main/java/org/lance/Dataset.java +++ b/java/src/main/java/org/lance/Dataset.java @@ -2030,8 +2030,9 @@ private native MergeInsertResult nativeMergeInsert( /** * Initialize MemWAL on this dataset. * - *

Must be called once before any call to {@link #memWalWriter}. The dataset schema must have - * at least one field carrying the {@code lance-schema:unenforced-primary-key} metadata. + *

Must be called once before any call to {@link #memWalWriter}. Append-only tables may omit + * primary-key metadata; primary keys are only required for primary-key lookup and last-write-wins + * deduplication workflows. * * @param params MemWAL initialization parameters */ diff --git a/java/src/main/java/org/lance/Fragment.java b/java/src/main/java/org/lance/Fragment.java index 43091269382..b27b189bf48 100644 --- a/java/src/main/java/org/lance/Fragment.java +++ b/java/src/main/java/org/lance/Fragment.java @@ -278,6 +278,9 @@ static List create( params.getEnableStableRowIds(), params.getDataStorageVersion(), params.getStorageOptions(), + params.getBaseStoreParams(), + params.getInitialBases(), + params.getTargetBases(), namespaceClient, tableId, params.getAllowExternalBlobOutsideBases(), @@ -305,6 +308,9 @@ static List create( params.getEnableStableRowIds(), params.getDataStorageVersion(), params.getStorageOptions(), + params.getBaseStoreParams(), + params.getInitialBases(), + params.getTargetBases(), namespaceClient, tableId, params.getAllowExternalBlobOutsideBases(), @@ -323,6 +329,9 @@ private static native List createWithFfiArray( Optional enableStableRowIds, Optional dataStorageVersion, Map storageOptions, + Map> baseStoreParams, + Optional> initialBases, + Optional> targetBases, LanceNamespace namespaceClient, List tableId, Optional allowExternalBlobOutsideBases, @@ -339,6 +348,9 @@ private static native List createWithFfiStream( Optional enableStableRowIds, Optional dataStorageVersion, Map storageOptions, + Map> baseStoreParams, + Optional> initialBases, + Optional> targetBases, LanceNamespace namespaceClient, List tableId, Optional allowExternalBlobOutsideBases, diff --git a/java/src/main/java/org/lance/WriteFragmentBuilder.java b/java/src/main/java/org/lance/WriteFragmentBuilder.java index 693b5b6bc87..5d7dc1a42b2 100644 --- a/java/src/main/java/org/lance/WriteFragmentBuilder.java +++ b/java/src/main/java/org/lance/WriteFragmentBuilder.java @@ -123,6 +123,22 @@ public WriteFragmentBuilder storageOptions(Map storageOptions) { return this; } + /** + * Set runtime-only object store parameters for registered base paths. + * + *

Entries are keyed by the exact {@link BasePath#getPath()} value persisted in the manifest. + * Each value is the storage options map used for that base. Bases without an explicit entry use + * {@link #storageOptions(Map)} as the fallback. + * + * @param baseStoreParams object store parameters keyed by base path URI + * @return this builder + */ + public WriteFragmentBuilder baseStoreParams(Map> baseStoreParams) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withBaseStoreParams(baseStoreParams); + return this; + } + /** * Set the namespace client for automatic credential refresh. * @@ -223,6 +239,30 @@ public WriteFragmentBuilder dataStorageVersion(String version) { return this; } + /** + * Register base paths when creating a new dataset from fragments. + * + * @param bases base paths to register + * @return this builder + */ + public WriteFragmentBuilder initialBases(List bases) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withInitialBases(bases); + return this; + } + + /** + * Set base names or paths where new fragment files should be written. + * + * @param targetBases base names or exact paths + * @return this builder + */ + public WriteFragmentBuilder targetBases(List targetBases) { + ensureWriteParamsBuilder(); + this.writeParamsBuilder.withTargetBases(targetBases); + return this; + } + /** * Execute the fragment write operation. * diff --git a/java/src/main/java/org/lance/index/external/ExternalIvfPqIndex.java b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndex.java new file mode 100644 index 00000000000..ca4f728f32b --- /dev/null +++ b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndex.java @@ -0,0 +1,209 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import org.lance.JniLoader; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; + +/** + * Java handle for an IVF-PQ index built over caller-registered parquet files. + * + *

Lance owns the parquet reader and refinement; callers see {@link SearchResult} as {@code + * (filePath, rowIndex, distance)} and never have to decode an internal rid. + * + *

Lifecycle: {@link #build} writes an index file under an output URI, returning a UUID that + * names the index directory; {@link #open} returns an open handle. {@link #close} releases the + * handle. The handle is {@link AutoCloseable}; use try-with-resources where possible. + * + *

Example: + * + *

{@code
+ * String uuid = ExternalIvfPqIndex.build(
+ *     List.of("/data/embeddings-0.parquet", "/data/embeddings-1.parquet"),
+ *     "vec",
+ *     "/index/v1",
+ *     ExternalIvfPqIndexParams.builder().numPartitions(256).numSubVectors(16).build());
+ *
+ * try (ExternalIvfPqIndex idx = ExternalIvfPqIndex.open("/index/v1/" + uuid)) {
+ *   List hits = idx.search(query, 10, 16, 8, null);
+ *   byte[] arrowIpc = idx.fetchRows(
+ *       hits.stream().map(h -> ParquetRowKey.of(h.getFilePath(), h.getRowIndex())).toList(),
+ *       List.of("doc_id", "title"));
+ * }
+ * }
+ */ +public final class ExternalIvfPqIndex implements AutoCloseable { + + static { + JniLoader.ensureLoaded(); + } + + private long handle; + + private ExternalIvfPqIndex(long handle) { + this.handle = handle; + } + + // ---- build / open --------------------------------------------------------- + + /** + * Build an external IVF-PQ index over the registered parquet files. + * + *

Writes a single index directory under {@code outputUri} named by the returned UUID. The + * directory contains {@code index.idx} (IVF model + PQ codebooks) and {@code manifest.json} + * (parquet file list + build params). + * + *

The {@code file_id} encoded into rids is implicit in {@code filePaths}'s position; + * reordering invalidates the index. + * + * @return UUID directory name (not the full URI); join with {@code outputUri} to get the open + * URI. + */ + public static String build( + List filePaths, + String vectorColumn, + String outputUri, + ExternalIvfPqIndexParams params) { + String[] paths = filePaths.toArray(new String[0]); + return nativeBuild( + paths, + vectorColumn, + outputUri, + params.getNumPartitions(), + params.getNumSubVectors(), + params.getNumBitsPerSubVector(), + params.getMetric().toRustString(), + params.getMaxIters(), + params.getSampleRate(), + params.getSeed()); + } + + /** + * Open an external IVF-PQ index by URI. Cheap: reads {@code manifest.json} and the index file + * footer. + */ + public static ExternalIvfPqIndex open(String uri) { + long handle = nativeOpen(uri); + return new ExternalIvfPqIndex(handle); + } + + /** + * Run a vector query and return up to {@code k} refined results. + * + * @param query Query vector. Length must match the index's dimension. + * @param k Number of results to return. + * @param nprobes Number of IVF partitions to probe. + * @param refineFactor Re-rank multiplier; {@code k * refineFactor} approximate candidates are + * fetched, refined exactly, then trimmed to {@code k}. + * @param deletedRids Optional little-endian {@code u64} packed array of deleted rids encoded as + * {@code (file_id << 32) | row_index}. Survivors will exclude these. {@code null} means no + * filter. + */ + public List search( + float[] query, int k, int nprobes, int refineFactor, byte[] deletedRids) { + SearchResult[] arr = nativeSearch(handle, query, k, nprobes, refineFactor, deletedRids); + return java.util.Arrays.asList(arr); + } + + /** + * Fetch arbitrary projection columns for {@code rowKeys} from the registered parquet files. + * + *

Returns Arrow IPC stream bytes; decode with {@code ArrowStreamReader} on the Java side. The + * batch has one row per input key, in caller-input order. + * + *

Use this for post-topK materialization: pass the {@code (filePath, rowIndex)} pairs from + * {@link #search} and project only the columns you need. + */ + public byte[] fetchRows(List rowKeys, List projection) { + String[] paths = new String[rowKeys.size()]; + long[] rows = new long[rowKeys.size()]; + for (int i = 0; i < rowKeys.size(); i++) { + paths[i] = rowKeys.get(i).getFilePath(); + rows[i] = rowKeys.get(i).getRowIndex(); + } + String[] proj = projection.toArray(new String[0]); + return nativeFetchRows(handle, paths, rows, proj); + } + + /** Pack a list of {@code (file_id, row_index)} deletes into the byte format {@link #search}. */ + public static byte[] packDeletedRids(List deletedFileRowPairs) { + ByteBuffer buf = + ByteBuffer.allocate(deletedFileRowPairs.size() * Long.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (long[] pair : deletedFileRowPairs) { + if (pair.length != 2) { + throw new IllegalArgumentException( + "expected (file_id, row_index) pair; got length " + pair.length); + } + long rid = (pair[0] << 32) | (pair[1] & 0xFFFFFFFFL); + buf.putLong(rid); + } + return buf.array(); + } + + /** Number of IVF partitions. */ + public int getNumPartitions() { + return nativeNumPartitions(handle); + } + + /** Number of registered parquet files. */ + public int getNumFiles() { + return nativeNumFiles(handle); + } + + /** Vector column the index was built over. */ + public String getVectorColumn() { + return nativeVectorColumn(handle); + } + + @Override + public void close() { + if (handle != 0L) { + nativeClose(handle); + handle = 0L; + } + } + + // ---- native methods ------------------------------------------------------- + + private static native String nativeBuild( + String[] filePaths, + String vectorColumn, + String outputUri, + int numPartitions, + int numSubVectors, + int numBitsPerSubVector, + String metric, + int maxIters, + int sampleRate, + long seed); + + private static native long nativeOpen(String uri); + + private static native void nativeClose(long handle); + + private static native SearchResult[] nativeSearch( + long handle, float[] query, int k, int nprobes, int refineFactor, byte[] deletedRids); + + private static native byte[] nativeFetchRows( + long handle, String[] filePaths, long[] rowIndices, String[] projection); + + private static native int nativeNumPartitions(long handle); + + private static native int nativeNumFiles(long handle); + + private static native String nativeVectorColumn(long handle); +} diff --git a/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java new file mode 100644 index 00000000000..cce3f7852a8 --- /dev/null +++ b/java/src/main/java/org/lance/index/external/ExternalIvfPqIndexParams.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +/** + * Build configuration for {@link ExternalIvfPqIndex#build}. Defaults match Lance's IVF-PQ defaults: + * {@code numPartitions=256}, {@code numSubVectors=16}, {@code numBitsPerSubVector=8}, metric = + * {@link Metric#L2}. + */ +public final class ExternalIvfPqIndexParams { + + /** Distance metric used by the index. Mirrors lance::Distance. */ + public enum Metric { + L2, + Cosine, + Dot; + + String toRustString() { + switch (this) { + case L2: + return "L2"; + case Cosine: + return "Cosine"; + case Dot: + return "Dot"; + default: + throw new IllegalStateException("unknown metric: " + this); + } + } + } + + private final int numPartitions; + private final int numSubVectors; + private final int numBitsPerSubVector; + private final Metric metric; + private final int maxIters; + private final int sampleRate; + private final long seed; + + private ExternalIvfPqIndexParams(Builder b) { + this.numPartitions = b.numPartitions; + this.numSubVectors = b.numSubVectors; + this.numBitsPerSubVector = b.numBitsPerSubVector; + this.metric = b.metric; + this.maxIters = b.maxIters; + this.sampleRate = b.sampleRate; + this.seed = b.seed; + } + + public int getNumPartitions() { + return numPartitions; + } + + public int getNumSubVectors() { + return numSubVectors; + } + + public int getNumBitsPerSubVector() { + return numBitsPerSubVector; + } + + public Metric getMetric() { + return metric; + } + + public int getMaxIters() { + return maxIters; + } + + public int getSampleRate() { + return sampleRate; + } + + public long getSeed() { + return seed; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private int numPartitions = 256; + private int numSubVectors = 16; + private int numBitsPerSubVector = 8; + private Metric metric = Metric.L2; + private int maxIters = 50; + private int sampleRate = 256; + private long seed = 0xCAFEBABEDEADBEEFL; + + public Builder numPartitions(int n) { + this.numPartitions = n; + return this; + } + + public Builder numSubVectors(int n) { + this.numSubVectors = n; + return this; + } + + public Builder numBitsPerSubVector(int n) { + this.numBitsPerSubVector = n; + return this; + } + + public Builder metric(Metric m) { + this.metric = m; + return this; + } + + public Builder maxIters(int n) { + this.maxIters = n; + return this; + } + + public Builder sampleRate(int n) { + this.sampleRate = n; + return this; + } + + public Builder seed(long seed) { + this.seed = seed; + return this; + } + + public ExternalIvfPqIndexParams build() { + return new ExternalIvfPqIndexParams(this); + } + } +} diff --git a/java/src/main/java/org/lance/index/external/ParquetRowKey.java b/java/src/main/java/org/lance/index/external/ParquetRowKey.java new file mode 100644 index 00000000000..d9f7dc79cef --- /dev/null +++ b/java/src/main/java/org/lance/index/external/ParquetRowKey.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import java.util.Objects; + +/** {@code (file_path, row_index)} input to {@link ExternalIvfPqIndex#fetchRows}. */ +public final class ParquetRowKey { + private final String filePath; + private final long rowIndex; + + public ParquetRowKey(String filePath, long rowIndex) { + this.filePath = filePath; + this.rowIndex = rowIndex; + } + + public static ParquetRowKey of(String filePath, long rowIndex) { + return new ParquetRowKey(filePath, rowIndex); + } + + public String getFilePath() { + return filePath; + } + + public long getRowIndex() { + return rowIndex; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ParquetRowKey)) { + return false; + } + ParquetRowKey that = (ParquetRowKey) o; + return rowIndex == that.rowIndex && Objects.equals(filePath, that.filePath); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, rowIndex); + } +} diff --git a/java/src/main/java/org/lance/index/external/SearchResult.java b/java/src/main/java/org/lance/index/external/SearchResult.java new file mode 100644 index 00000000000..1bd82c62681 --- /dev/null +++ b/java/src/main/java/org/lance/index/external/SearchResult.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import java.util.Objects; + +/** + * One result from {@link ExternalIvfPqIndex#search}. The result is already refined; the {@code + * distance} is the exact distance under the index's metric. + */ +public final class SearchResult { + private final String filePath; + private final long rowIndex; + private final float distance; + + public SearchResult(String filePath, long rowIndex, float distance) { + this.filePath = filePath; + this.rowIndex = rowIndex; + this.distance = distance; + } + + /** Path of the parquet file the row lives in (one of the registered file specs). */ + public String getFilePath() { + return filePath; + } + + /** Zero-based row index within {@link #getFilePath()}. */ + public long getRowIndex() { + return rowIndex; + } + + /** Exact distance under the index's distance metric. */ + public float getDistance() { + return distance; + } + + @Override + public String toString() { + return "SearchResult{" + filePath + "@" + rowIndex + " d=" + distance + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SearchResult)) { + return false; + } + SearchResult that = (SearchResult) o; + return rowIndex == that.rowIndex + && Float.compare(distance, that.distance) == 0 + && Objects.equals(filePath, that.filePath); + } + + @Override + public int hashCode() { + return Objects.hash(filePath, rowIndex, distance); + } +} diff --git a/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java b/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java index b4567db3acd..1fe0e8bd020 100644 --- a/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java +++ b/java/src/main/java/org/lance/memwal/LsmVectorSearchPlanner.java @@ -93,9 +93,12 @@ private native void nativeCreate( * @param nprobes number of IVF partitions to probe * @param columns columns to project; pass {@code null} to return all columns plus {@code * _distance} + * @param refineFactor when positive, the base-table arm over-fetches {@code k * refineFactor} + * candidates and re-ranks them with exact distances. Pass {@code 0} or negative to disable. * @return an executable plan */ - public ExecutionPlan planSearch(Float4Vector query, int k, int nprobes, List columns) { + public ExecutionPlan planSearch( + Float4Vector query, int k, int nprobes, List columns, int refineFactor) { Preconditions.checkNotNull(query, "query must not be null"); Preconditions.checkArgument(k > 0, "k must be positive, got %s", k); Preconditions.checkArgument(nprobes > 0, "nprobes must be positive, got %s", nprobes); @@ -111,20 +114,40 @@ public ExecutionPlan planSearch(Float4Vector query, int k, int nprobes, List columns) { + return planSearch(query, k, nprobes, columns, 0); + } + /** Plan a KNN vector search with default {@code nprobes} of 20. */ public ExecutionPlan planSearch(Float4Vector query, int k) { - return planSearch(query, k, 20, null); + return planSearch(query, k, 20, null, 0); } private native ExecutionPlan nativePlanSearch( - long arrayAddress, long schemaAddress, int k, int nprobes, Optional> columns); + long arrayAddress, + long schemaAddress, + int k, + int nprobes, + Optional> columns, + int refineFactor); /** * Close the planner and release native resources. If the planner is already closed, invoking this diff --git a/java/src/main/java/org/lance/memwal/ShardingEvaluator.java b/java/src/main/java/org/lance/memwal/ShardingEvaluator.java new file mode 100644 index 00000000000..50e6cdb1ea3 --- /dev/null +++ b/java/src/main/java/org/lance/memwal/ShardingEvaluator.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.memwal; + +import org.lance.JniLoader; +import org.lance.schema.LanceField; +import org.lance.schema.LanceSchema; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** Evaluates MemWAL sharding specs against Arrow record batches. */ +public final class ShardingEvaluator { + static { + JniLoader.ensureLoaded(); + } + + private ShardingEvaluator() {} + + /** + * Evaluate {@code spec} against {@code root}. + * + * @param allocator allocator used for Arrow C data interface structs + * @param root input record batch + * @param spec MemWAL sharding spec to evaluate + * @param schema Lance table schema used to resolve spec source field IDs to input column names + * @return an Arrow reader containing one result batch with the derived sharding fields + */ + public static ArrowReader evaluate( + BufferAllocator allocator, VectorSchemaRoot root, ShardingSpec spec, LanceSchema schema) { + Preconditions.checkNotNull(allocator, "allocator must not be null"); + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(spec, "spec must not be null"); + Preconditions.checkNotNull(schema, "schema must not be null"); + + try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); + ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); + nativeEvaluate( + arrowArray.memoryAddress(), + arrowSchema.memoryAddress(), + spec, + sourceIdToColumnMap(schema), + stream.memoryAddress()); + return Data.importArrayStream(allocator, stream); + } + } + + /** Evaluate {@code spec} against {@code root} when the spec embeds enough column information. */ + public static ArrowReader evaluate( + BufferAllocator allocator, VectorSchemaRoot root, ShardingSpec spec) { + return evaluate(allocator, root, spec, Collections.emptyMap()); + } + + private static ArrowReader evaluate( + BufferAllocator allocator, + VectorSchemaRoot root, + ShardingSpec spec, + Map sourceIdToColumn) { + Preconditions.checkNotNull(allocator, "allocator must not be null"); + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(spec, "spec must not be null"); + Preconditions.checkNotNull(sourceIdToColumn, "sourceIdToColumn must not be null"); + + try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); + ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); + nativeEvaluate( + arrowArray.memoryAddress(), + arrowSchema.memoryAddress(), + spec, + sourceIdToColumn, + stream.memoryAddress()); + return Data.importArrayStream(allocator, stream); + } + } + + private static Map sourceIdToColumnMap(LanceSchema schema) { + Map result = new HashMap<>(); + for (LanceField field : schema.fields()) { + collectFieldIds(field, "", result); + } + return result; + } + + private static void collectFieldIds( + LanceField field, String prefix, Map result) { + String fullName = prefix.isEmpty() ? field.getName() : prefix + "." + field.getName(); + result.put(field.getId(), fullName); + for (LanceField child : field.getChildren()) { + collectFieldIds(child, fullName, result); + } + } + + private static native void nativeEvaluate( + long arrowArrayAddress, + long arrowSchemaAddress, + ShardingSpec spec, + Map sourceIdToColumn, + long streamAddress); +} diff --git a/java/src/main/java/org/lance/namespace/DirectoryNamespace.java b/java/src/main/java/org/lance/namespace/DirectoryNamespace.java index 7bd5f9c7f9f..d26bbff9135 100644 --- a/java/src/main/java/org/lance/namespace/DirectoryNamespace.java +++ b/java/src/main/java/org/lance/namespace/DirectoryNamespace.java @@ -26,6 +26,8 @@ import org.lance.namespace.model.BatchDeleteTableVersionsRequest; import org.lance.namespace.model.BatchDeleteTableVersionsResponse; import org.lance.namespace.model.CountTableRowsRequest; +import org.lance.namespace.model.CreateMaterializedViewRequest; +import org.lance.namespace.model.CreateMaterializedViewResponse; import org.lance.namespace.model.CreateNamespaceRequest; import org.lance.namespace.model.CreateNamespaceResponse; import org.lance.namespace.model.CreateTableIndexRequest; @@ -662,6 +664,15 @@ public UpdateTableTagResponse updateTableTag(UpdateTableTagRequest request) { return fromJson(responseJson, UpdateTableTagResponse.class); } + @Override + public CreateMaterializedViewResponse createMaterializedView( + CreateMaterializedViewRequest request) { + ensureInitialized(); + String requestJson = toJson(request); + String responseJson = createMaterializedViewNative(nativeDirectoryNamespaceHandle, requestJson); + return fromJson(responseJson, CreateMaterializedViewResponse.class); + } + @Override public void close() { if (nativeDirectoryNamespaceHandle != 0) { @@ -830,6 +841,8 @@ private native String mergeInsertIntoTableNative( private native String updateTableTagNative(long handle, String requestJson); + private native String createMaterializedViewNative(long handle, String requestJson); + private native Map retrieveOpsMetricsNative(long handle); private native void resetOpsMetricsNative(long handle); diff --git a/java/src/main/java/org/lance/namespace/RestNamespace.java b/java/src/main/java/org/lance/namespace/RestNamespace.java index fbf58bd55c6..9cbbc588660 100644 --- a/java/src/main/java/org/lance/namespace/RestNamespace.java +++ b/java/src/main/java/org/lance/namespace/RestNamespace.java @@ -26,6 +26,8 @@ import org.lance.namespace.model.BatchDeleteTableVersionsRequest; import org.lance.namespace.model.BatchDeleteTableVersionsResponse; import org.lance.namespace.model.CountTableRowsRequest; +import org.lance.namespace.model.CreateMaterializedViewRequest; +import org.lance.namespace.model.CreateMaterializedViewResponse; import org.lance.namespace.model.CreateNamespaceRequest; import org.lance.namespace.model.CreateNamespaceResponse; import org.lance.namespace.model.CreateTableIndexRequest; @@ -568,6 +570,15 @@ public UpdateTableTagResponse updateTableTag(UpdateTableTagRequest request) { return fromJson(responseJson, UpdateTableTagResponse.class); } + @Override + public CreateMaterializedViewResponse createMaterializedView( + CreateMaterializedViewRequest request) { + ensureInitialized(); + String requestJson = toJson(request); + String responseJson = createMaterializedViewNative(nativeRestNamespaceHandle, requestJson); + return fromJson(responseJson, CreateMaterializedViewResponse.class); + } + @Override public void close() { if (nativeRestNamespaceHandle != 0) { @@ -735,6 +746,8 @@ private native String mergeInsertIntoTableNative( private native String updateTableTagNative(long handle, String requestJson); + private native String createMaterializedViewNative(long handle, String requestJson); + private native Map retrieveOpsMetricsNative(long handle); private native void resetOpsMetricsNative(long handle); diff --git a/java/src/test/java/org/lance/DatasetTest.java b/java/src/test/java/org/lance/DatasetTest.java index c01154b0b71..315b010da1e 100644 --- a/java/src/test/java/org/lance/DatasetTest.java +++ b/java/src/test/java/org/lance/DatasetTest.java @@ -528,7 +528,12 @@ void testOpenSerializedManifest(@TempDir Path tempDir) throws IOException { assertEquals(1, dataset1.version()); Path manifestPath = datasetPath.resolve("_versions"); try (Stream fileStream = Files.list(manifestPath)) { - assertEquals(1, fileStream.count()); + // Ignore the version hint file, which is not a manifest. + assertEquals( + 1, + fileStream + .filter(p -> !p.getFileName().toString().startsWith("latest_version_hint")) + .count()); ByteBuffer manifestBuffer = readManifest(manifestPath.resolve("1.manifest")); try (Dataset dataset2 = testDataset.write(1, 5)) { assertEquals(2, dataset2.version()); diff --git a/java/src/test/java/org/lance/MultiBaseTest.java b/java/src/test/java/org/lance/MultiBaseTest.java index 802b4a2ca31..e4f1d982e14 100644 --- a/java/src/test/java/org/lance/MultiBaseTest.java +++ b/java/src/test/java/org/lance/MultiBaseTest.java @@ -36,13 +36,18 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class MultiBaseTest { private BufferAllocator allocator; @@ -103,6 +108,51 @@ private ArrowStreamReader makeReader(int startId, int count) throws Exception { } } + private VectorSchemaRoot makeRoot(int startId, int count) { + List fields = + Arrays.asList( + new Field("id", FieldType.notNullable(new ArrowType.Int(32, true)), null), + new Field("value", FieldType.nullable(new ArrowType.Utf8()), null)); + Schema schema = new Schema(fields); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + IntVector idVec = (IntVector) root.getVector("id"); + idVec.allocateNew(count); + VarCharVector valVec = (VarCharVector) root.getVector("value"); + valVec.allocateNew(); + for (int i = 0; i < count; i++) { + int id = startId + i; + idVec.setSafe(i, id); + byte[] b = ("val_" + id).getBytes(); + valVec.setSafe(i, b, 0, b.length); + } + root.setRowCount(count); + return root; + } + + private boolean hasLanceFile(String basePath) throws Exception { + try (Stream paths = Files.walk(Path.of(basePath))) { + return paths.anyMatch(path -> path.toString().endsWith(".lance")); + } + } + + private Set baseIds(Dataset dataset) { + return dataset.getFragments().stream() + .flatMap(f -> f.metadata().getFiles().stream()) + .map(DataFile::getBaseId) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + } + + private Set baseIds(List fragments) { + return fragments.stream() + .flatMap(fragment -> fragment.getFiles().stream()) + .map(DataFile::getBaseId) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(Collectors.toSet()); + } + @Test public void testCreateMode() throws Exception { ArrowStreamReader reader = makeReader(0, 500); @@ -211,12 +261,7 @@ public void testTargetByPathUri() throws Exception { .maxRowsPerFile(50) .execute(); - Set baseIds = - ds.getFragments().stream() - .flatMap(f -> f.metadata().getFiles().stream().map(DataFile::getBaseId)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toSet()); + Set baseIds = baseIds(ds); assertEquals(1, baseIds.size()); ArrowStreamReader append = makeReader(100, 50); @@ -231,12 +276,103 @@ public void testTargetByPathUri() throws Exception { .execute(); assertEquals(150, updated.countRows()); - baseIds = - updated.getFragments().stream() - .flatMap(f -> f.metadata().getFiles().stream().map(DataFile::getBaseId)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(Collectors.toSet()); + baseIds = baseIds(updated); assertEquals(2, baseIds.size()); } + + @Test + public void testFragmentCreateWithMultiBaseParams() throws Exception { + ArrowStreamReader reader = makeReader(0, 100); + List bases = + Arrays.asList( + new BasePath(0, Optional.of("base1"), base1, false), + new BasePath(0, Optional.of("base2"), base2, false)); + + Dataset ds = + Dataset.write() + .allocator(allocator) + .reader(reader) + .uri(primary) + .mode(WriteParams.WriteMode.CREATE) + .initialBases(bases) + .targetBases(Arrays.asList("base1")) + .maxRowsPerFile(50) + .execute(); + Set initialBaseIds = baseIds(ds); + assertEquals(1, initialBaseIds.size()); + assertTrue(hasLanceFile(base1)); + + Map> baseStoreParams = new HashMap<>(); + baseStoreParams.put(base2, new HashMap<>()); + WriteParams params = + new WriteParams.Builder() + .withTargetBases(Arrays.asList("base2")) + .withBaseStoreParams(baseStoreParams) + .withMaxRowsPerFile(25) + .build(); + + List fragments; + try (VectorSchemaRoot root = makeRoot(100, 50)) { + fragments = Fragment.create(primary, allocator, root, params); + } + + assertEquals(2, fragments.size()); + Set fragmentBaseIds = baseIds(fragments); + assertEquals(1, fragmentBaseIds.size()); + assertTrue(Collections.disjoint(initialBaseIds, fragmentBaseIds)); + assertTrue(hasLanceFile(base2)); + + FragmentOperation.Append append = new FragmentOperation.Append(fragments); + Dataset updated = Dataset.commit(allocator, primary, append, Optional.of(ds.version())); + assertEquals(150, updated.countRows()); + assertEquals(2, baseIds(updated).size()); + } + + @Test + public void testFragmentWriteWithMultiBaseParams() throws Exception { + ArrowStreamReader reader = makeReader(0, 50); + List bases = + Arrays.asList( + new BasePath(0, Optional.of("base1"), base1, false), + new BasePath(0, Optional.of("base2"), base2, false)); + + Dataset ds = + Dataset.write() + .allocator(allocator) + .reader(reader) + .uri(primary) + .mode(WriteParams.WriteMode.CREATE) + .initialBases(bases) + .targetBases(Arrays.asList("base1")) + .maxRowsPerFile(50) + .execute(); + Set initialBaseIds = baseIds(ds); + assertEquals(1, initialBaseIds.size()); + assertTrue(hasLanceFile(base1)); + + Map> baseStoreParams = new HashMap<>(); + baseStoreParams.put(base2, new HashMap<>()); + + List fragments; + try (VectorSchemaRoot root = makeRoot(50, 25)) { + fragments = + Fragment.write() + .datasetUri(primary) + .allocator(allocator) + .data(root) + .targetBases(Arrays.asList("base2")) + .baseStoreParams(baseStoreParams) + .maxRowsPerFile(25) + .execute(); + } + + Set fragmentBaseIds = baseIds(fragments); + assertEquals(1, fragmentBaseIds.size()); + assertTrue(Collections.disjoint(initialBaseIds, fragmentBaseIds)); + FragmentOperation.Append append = new FragmentOperation.Append(fragments); + Dataset updated = Dataset.commit(allocator, primary, append, Optional.of(ds.version())); + assertEquals(75, updated.countRows()); + assertEquals(2, baseIds(updated).size()); + assertTrue(hasLanceFile(base2)); + } } diff --git a/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java b/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java new file mode 100644 index 00000000000..d96e79f7d6f --- /dev/null +++ b/java/src/test/java/org/lance/index/external/ExternalIvfPqIndexJniTest.java @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.index.external; + +import org.junit.jupiter.api.Test; +import org.lance.JniLoader; + +import java.util.ArrayList; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Smoke tests for the JNI surface. Full end-to-end coverage (build + open + search + + * fetchRows) lives in {@code rust/lance/tests/external_index_phase1.rs}, which exercises + * the same code path the JNI calls into. Java-side seeding of parquet test data + * requires a parquet writer dependency and is deferred to phase 1.7-followup. + */ +public class ExternalIvfPqIndexJniTest { + static { + JniLoader.ensureLoaded(); + } + + @Test + public void packDeletedRidsLittleEndian() { + java.util.List deletes = new ArrayList<>(); + deletes.add(new long[] {0L, 0L}); + deletes.add(new long[] {1L, 42L}); + byte[] packed = ExternalIvfPqIndex.packDeletedRids(deletes); + assertEquals(16, packed.length, "two u64s = 16 bytes"); + + // First rid: (0 << 32) | 0 = 0 + long first = java.nio.ByteBuffer.wrap(packed, 0, 8) + .order(java.nio.ByteOrder.LITTLE_ENDIAN) + .getLong(); + assertEquals(0L, first); + + // Second rid: (1 << 32) | 42 = 4294967338 + long second = java.nio.ByteBuffer.wrap(packed, 8, 8) + .order(java.nio.ByteOrder.LITTLE_ENDIAN) + .getLong(); + assertEquals((1L << 32) | 42L, second); + } + + @Test + public void packDeletedRidsValidatesPairLength() { + java.util.List bad = Collections.singletonList(new long[] {1L, 2L, 3L}); + assertThrows(IllegalArgumentException.class, () -> ExternalIvfPqIndex.packDeletedRids(bad)); + } + + @Test + public void openMissingDirThrows() { + // Validates the JNI exception bridge. LanceError::IO maps to java.io.IOException. + assertThrows( + java.io.IOException.class, + () -> ExternalIvfPqIndex.open("/tmp/this-path-does-not-exist-9f3a8e2c")); + } + + @Test + public void paramsBuilderDefaults() { + ExternalIvfPqIndexParams p = ExternalIvfPqIndexParams.builder().build(); + assertEquals(256, p.getNumPartitions()); + assertEquals(16, p.getNumSubVectors()); + assertEquals(8, p.getNumBitsPerSubVector()); + assertEquals(ExternalIvfPqIndexParams.Metric.L2, p.getMetric()); + assertNotNull(p.getMetric().toRustString()); + } + + @Test + public void paramsBuilderOverrides() { + ExternalIvfPqIndexParams p = + ExternalIvfPqIndexParams.builder() + .numPartitions(32) + .numSubVectors(4) + .metric(ExternalIvfPqIndexParams.Metric.Cosine) + .maxIters(10) + .seed(42L) + .build(); + assertEquals(32, p.getNumPartitions()); + assertEquals(4, p.getNumSubVectors()); + assertEquals(ExternalIvfPqIndexParams.Metric.Cosine, p.getMetric()); + assertEquals("Cosine", p.getMetric().toRustString()); + assertEquals(10, p.getMaxIters()); + assertEquals(42L, p.getSeed()); + } +} diff --git a/java/src/test/java/org/lance/memwal/MemWalTest.java b/java/src/test/java/org/lance/memwal/MemWalTest.java index 6e1bd6bd4f7..ee26932dd59 100644 --- a/java/src/test/java/org/lance/memwal/MemWalTest.java +++ b/java/src/test/java/org/lance/memwal/MemWalTest.java @@ -72,6 +72,14 @@ public class MemWalTest { new Field( "id", new FieldType(false, new ArrowType.Int(64, true), null, PK_META), null), Field.nullable("name", new ArrowType.Utf8()))); + private static final Schema APPEND_ONLY_SCHEMA = + new Schema( + Arrays.asList( + new Field( + "id", + new FieldType(false, new ArrowType.Int(64, true), null, Collections.emptyMap()), + null), + Field.nullable("name", new ArrowType.Utf8()))); /** Build a single-batch root where {@code name = "{prefix}_{id}"}. */ private static VectorSchemaRoot lookupRoot(BufferAllocator allocator, long[] ids, String prefix) { @@ -88,6 +96,22 @@ private static VectorSchemaRoot lookupRoot(BufferAllocator allocator, long[] ids return root; } + /** Build a single-batch append-only root without primary-key metadata. */ + private static VectorSchemaRoot appendOnlyRoot( + BufferAllocator allocator, long[] ids, String prefix) { + VectorSchemaRoot root = VectorSchemaRoot.create(APPEND_ONLY_SCHEMA, allocator); + BigIntVector idVector = (BigIntVector) root.getVector("id"); + VarCharVector nameVector = (VarCharVector) root.getVector("name"); + idVector.allocateNew(ids.length); + nameVector.allocateNew(); + for (int i = 0; i < ids.length; i++) { + idVector.set(i, ids[i]); + nameVector.setSafe(i, (prefix + "_" + ids[i]).getBytes(StandardCharsets.UTF_8)); + } + root.setRowCount(ids.length); + return root; + } + /** Wrap an in-memory root into an {@link ArrowReader} via the Arrow IPC stream format. */ private static ArrowReader toReader(BufferAllocator allocator, VectorSchemaRoot root) throws Exception { @@ -109,6 +133,15 @@ private static Dataset writeLookupDataset( } } + /** Write an append-only base dataset of `(id, name)` rows at {@code path}. */ + private static Dataset writeAppendOnlyDataset( + BufferAllocator allocator, String path, long[] ids, String prefix) throws Exception { + try (VectorSchemaRoot root = appendOnlyRoot(allocator, ids, prefix); + ArrowReader reader = toReader(allocator, root)) { + return Dataset.write().allocator(allocator).reader(reader).uri(path).execute(); + } + } + /** Read an LSM scanner fully into an {@code id -> name} map. */ private static Map readByName(ArrowReader reader) throws Exception { Map byId = new HashMap<>(); @@ -145,6 +178,127 @@ void testInitializeMemWalUnsharded(@TempDir Path tempDir) throws Exception { } } + @Test + void testInitializeMemWalBucketShardingWithoutPrimaryKey(@TempDir Path tempDir) throws Exception { + String path = tempDir.resolve("append_only").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeAppendOnlyDataset(allocator, path, new long[] {1, 2, 3}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("id", 4)); + + Optional details = dataset.memWalIndexDetails(); + assertTrue(details.isPresent()); + assertEquals(4L, details.get().numShards()); + ShardingField field = details.get().shardingSpecs().get(0).fields().get(0); + assertEquals("bucket", field.transform().get()); + assertEquals("4", field.parameters().get("num_buckets")); + } + } + + @Test + void testInitializeMemWalBucketShardingUsesConfiguredColumn(@TempDir Path tempDir) + throws Exception { + String path = tempDir.resolve("base").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeLookupDataset(allocator, path, new long[] {1, 2, 3}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("name", 4)); + + MemWalIndexDetails details = dataset.memWalIndexDetails().get(); + ShardingField field = details.shardingSpecs().get(0).fields().get(0); + int nameFieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals("name")) + .findFirst() + .get() + .getId(); + assertEquals("bucket", field.transform().get()); + assertEquals(nameFieldId, field.sourceIds().get(0)); + } + } + + @Test + void testShardingEvaluatorBucketAndIdentity(@TempDir Path tempDir) throws Exception { + String path = tempDir.resolve("append_only").toString(); + try (BufferAllocator allocator = new RootAllocator(); + Dataset dataset = writeAppendOnlyDataset(allocator, path, new long[] {1}, "base")) { + dataset.initializeMemWal(new InitializeMemWalParams().withBucketSharding("id", 4)); + ShardingSpec bucketSpec = dataset.memWalIndexDetails().get().shardingSpecs().get(0); + ShardingField bucketField = bucketSpec.fields().get(0); + + try (VectorSchemaRoot root = appendOnlyRoot(allocator, new long[] {1, 2, 3}, "eval"); + ArrowReader reader = + ShardingEvaluator.evaluate(allocator, root, bucketSpec, dataset.getLanceSchema())) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot result = reader.getVectorSchemaRoot(); + IntVector buckets = (IntVector) result.getVector(bucketField.fieldId()); + assertEquals(3, result.getRowCount()); + assertEquals(0, buckets.get(0)); + assertEquals(0, buckets.get(1)); + assertEquals(3, buckets.get(2)); + assertFalse(reader.loadNextBatch()); + } + + int nameFieldId = + dataset.getLanceSchema().fields().stream() + .filter(f -> f.getName().equals("name")) + .findFirst() + .get() + .getId(); + ShardingSpec identitySpec = + new ShardingSpec( + 7, + Collections.singletonList( + new ShardingField( + "name_identity", + Collections.singletonList(nameFieldId), + "identity", + null, + "utf8", + Collections.emptyMap()))); + try (VectorSchemaRoot root = appendOnlyRoot(allocator, new long[] {1}, "eval"); + ArrowReader reader = + ShardingEvaluator.evaluate(allocator, root, identitySpec, dataset.getLanceSchema())) { + assertTrue(reader.loadNextBatch()); + VarCharVector names = + (VarCharVector) reader.getVectorSchemaRoot().getVector("name_identity"); + assertEquals("eval_1", new String(names.get(0), StandardCharsets.UTF_8)); + assertFalse(reader.loadNextBatch()); + } + + Map stringBucketParameters = new HashMap<>(); + stringBucketParameters.put("column", "key"); + stringBucketParameters.put("num_buckets", "8"); + ShardingSpec stringBucketSpec = + new ShardingSpec( + 8, + Collections.singletonList( + new ShardingField( + "key_bucket", + Collections.emptyList(), + "bucket", + null, + "int32", + stringBucketParameters))); + Schema stringSchema = + new Schema(Collections.singletonList(Field.nullable("key", new ArrowType.Utf8()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(stringSchema, allocator)) { + VarCharVector keyVector = (VarCharVector) root.getVector("key"); + keyVector.allocateNew(); + keyVector.setSafe(0, "a".getBytes(StandardCharsets.UTF_8)); + keyVector.setSafe(1, "b".getBytes(StandardCharsets.UTF_8)); + keyVector.setNull(2); + root.setRowCount(3); + try (ArrowReader reader = ShardingEvaluator.evaluate(allocator, root, stringBucketSpec)) { + assertTrue(reader.loadNextBatch()); + IntVector buckets = (IntVector) reader.getVectorSchemaRoot().getVector("key_bucket"); + assertEquals(1, buckets.get(0)); + assertEquals(5, buckets.get(1)); + assertEquals(0, buckets.get(2)); + assertFalse(reader.loadNextBatch()); + } + } + } + } + @Test void testInitializeMemWalRejectsConflictingSharding(@TempDir Path tempDir) throws Exception { String path = tempDir.resolve("base").toString(); diff --git a/protos/index.proto b/protos/index.proto index 1fb51f3291c..ea21c70387d 100644 --- a/protos/index.proto +++ b/protos/index.proto @@ -184,6 +184,62 @@ message VectorIndex { VectorMetricType metric_type = 4; } +// Details for vector indexes, stored in the manifest's index_details field. +message VectorIndexDetails { + VectorMetricType metric_type = 1; + + // The target number of vectors per partition. + // 0 means unset. + uint64 target_partition_size = 2; + + // Optional HNSW index configuration. If set, the index has an HNSW layer. + optional HnswParameters hnsw_index_config = 3; + + message ProductQuantization { + uint32 num_bits = 1; + uint32 num_sub_vectors = 2; + } + message ScalarQuantization { + uint32 num_bits = 1; + } + message RabitQuantization { + enum RotationType { + FAST = 0; + MATRIX = 1; + } + uint32 num_bits = 1; + RotationType rotation_type = 2; + } + + // No quantization; vectors are stored as-is. + message FlatCompression {} + + oneof compression { + ProductQuantization pq = 4; + ScalarQuantization sq = 5; + RabitQuantization rq = 6; + FlatCompression flat = 8; + } + + // Runtime hints: optional build preferences that don't affect index structure. + // Keys use reverse-DNS namespacing (e.g., "lance.ivf.max_iters", "lancedb.accelerator"). + // Unrecognized keys must be silently ignored by all runtimes. + map runtime_hints = 9; +} + +// Hierarchical Navigable Small World (HNSW) parameters, used as an optional configuration for IVF indexes. +message HnswParameters { + // The maximum number of outgoing edges per node in the HNSW graph. Higher values + // means more connections, better recall, but more memory and slower builds. + // Referred to as "M" in the HNSW literature. + uint32 max_connections = 1; + // "construction exploration factor": The size of the dynamic list used during + // index construction. + uint32 construction_ef = 2; + // The maximum number of levels in the HNSW graph. + uint32 max_level = 3; +} + message JsonIndexDetails { string path = 1; google.protobuf.Any target_details = 2; diff --git a/protos/table.proto b/protos/table.proto index 9fe687bb03e..d298809d5d8 100644 --- a/protos/table.proto +++ b/protos/table.proto @@ -474,8 +474,7 @@ message ExternalFile { uint64 size = 3; } -// Empty details messages for older indexes that don't take advantage of the details field. -message VectorIndexDetails {} +// VectorIndexDetails and HnswParameters (formerly HnswIndexDetails) moved to index.proto message FragmentReuseIndexDetails { diff --git a/python/Cargo.lock b/python/Cargo.lock index fa439bf442a..ae5c73e68c5 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -2679,9 +2679,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "encoding_rs" @@ -2853,7 +2853,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" [[package]] name = "fsst" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "rand 0.9.4", @@ -3975,7 +3975,7 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a" [[package]] name = "lance" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arc-swap", "arrow", @@ -4023,6 +4023,7 @@ dependencies = [ "lance-table", "lance-tokenizer", "log", + "moka", "object_store", "permutation", "pin-project", @@ -4046,7 +4047,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4066,7 +4067,7 @@ dependencies = [ [[package]] name = "lance-bitpacking" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrayref", "paste", @@ -4075,7 +4076,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4110,7 +4111,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4142,7 +4143,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4160,7 +4161,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -4195,7 +4196,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-arith", "arrow-array", @@ -4226,7 +4227,7 @@ dependencies = [ [[package]] name = "lance-geo" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "datafusion", "geo-traits", @@ -4240,7 +4241,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arc-swap", "arrow", @@ -4308,7 +4309,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-arith", @@ -4350,7 +4351,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow-array", "arrow-buffer", @@ -4366,7 +4367,7 @@ dependencies = [ [[package]] name = "lance-namespace" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "async-trait", @@ -4378,7 +4379,7 @@ dependencies = [ [[package]] name = "lance-namespace-impls" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-ipc", @@ -4408,9 +4409,9 @@ dependencies = [ [[package]] name = "lance-namespace-reqwest-client" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65e31bdaa13e01dab6e7cf566da31df243c34a542f0d915d3601ec0e01e61d2" +checksum = "6369eee4682fb11edf538388b43c61ce288b8302fe89bb40944d7daa7faaae99" dependencies = [ "reqwest 0.12.28", "serde", @@ -4422,7 +4423,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", @@ -4461,7 +4462,7 @@ dependencies = [ [[package]] name = "lance-tokenizer" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "jieba-rs", "lindera", @@ -5881,7 +5882,7 @@ dependencies = [ [[package]] name = "pylance" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" dependencies = [ "arrow", "arrow-array", diff --git a/python/Cargo.toml b/python/Cargo.toml index c9b7d918a7c..c1bc0630d47 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "7.0.0-beta.15" +version = "7.1.0-beta.1" edition = "2024" authors = ["Lance Devs "] license = "Apache-2.0" diff --git a/python/pyproject.toml b/python/pyproject.toml index 6574bd24559..97e023bb7c7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "pylance" dynamic = ["version"] -dependencies = ["pyarrow>=14", "numpy>=1.22", "lance-namespace>=0.7.5,<0.8"] +dependencies = ["pyarrow>=14", "numpy>=1.22", "lance-namespace>=0.7.7,<0.8"] description = "python wrapper for Lance columnar format" authors = [{ name = "Lance Devs", email = "dev@lance.org" }] license = { file = "LICENSE" } diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py index 78f97418133..f58b169a47a 100644 --- a/python/python/lance/__init__.py +++ b/python/python/lance/__init__.py @@ -39,10 +39,11 @@ LsmScanner, LsmVectorSearchPlanner, MergedGeneration, - RegionField, - RegionSnapshot, - RegionSpec, - RegionWriter, + ShardingField, + ShardingSpec, + ShardSnapshot, + ShardWriter, + evaluate_sharding_spec, ) from .namespace import ( DescribeTableRequest, @@ -99,10 +100,11 @@ "LsmScanner", "LsmVectorSearchPlanner", "MergedGeneration", - "RegionField", - "RegionSpec", - "RegionSnapshot", - "RegionWriter", + "ShardSnapshot", + "ShardWriter", + "ShardingField", + "ShardingSpec", + "evaluate_sharding_spec", ] diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index a57781480c1..5737ec013b5 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -4740,19 +4740,19 @@ def initialize_mem_wal( ) -> None: """Initialize MemWAL on this dataset. - Must be called once before any calls to `mem_wal_writer`. The dataset - schema must have at least one field annotated with the - ``lance-schema:unenforced-primary-key`` Arrow field metadata. + Must be called once before any calls to `mem_wal_writer`. Append-only + tables may omit primary-key metadata; primary keys are only required + for primary-key lookup and last-write-wins deduplication workflows. At most one sharding mode may be selected: bucket sharding (``bucket_column`` + ``num_buckets``), identity sharding (``identity_column``), or ``unsharded``. With none selected, shards are - managed manually by passing region IDs to `mem_wal_writer`. + managed manually by passing shard IDs to `mem_wal_writer`. Any writer-configuration keyword arguments (``durable_write``, ``max_memtable_size``, ``max_wal_flush_interval_ms``, etc. — the same knobs accepted by `mem_wal_writer`) are recorded as the default - `~lance.mem_wal.RegionWriter` configuration in the MemWAL index, so + `~lance.mem_wal.ShardWriter` configuration in the MemWAL index, so every writer starts from the same defaults. Parameters @@ -4761,8 +4761,7 @@ def initialize_mem_wal( Names of existing indexes to keep updated as data is written through the MemWAL. Must reference indexes that already exist. bucket_column : str, optional - With ``num_buckets``, hash-bucket writes by this column, which must - be the single-column unenforced primary key. + With ``num_buckets``, hash-bucket writes by this scalar column. num_buckets : int, optional Number of hash buckets (shards). Required with ``bucket_column``. identity_column : str, optional @@ -4773,7 +4772,6 @@ def initialize_mem_wal( Raises ------ IOError - - Dataset has no ``lance-schema:unenforced-primary-key`` field. - An entry in *maintained_indexes* does not exist on the dataset. - MemWAL has already been initialized on this dataset. ValueError @@ -4815,7 +4813,7 @@ def mem_wal_index_details(self) -> Optional[dict]: def mem_wal_writer( self, - region_id: str, + shard_id: str, *, durable_write: Optional[bool] = None, sync_indexed_write: Optional[bool] = None, @@ -4830,17 +4828,17 @@ def mem_wal_writer( async_index_interval_ms: Optional[int] = None, backpressure_log_interval_ms: Optional[int] = None, stats_log_interval_ms: Optional[int] = None, - ) -> "mem_wal.RegionWriter": - """Get a RegionWriter for the specified region. + ) -> "mem_wal.ShardWriter": + """Get a ShardWriter for the specified shard. `initialize_mem_wal` must be called before using this method. - Each *region* is an independent write shard; use different region IDs + Each shard is an independent write path; use different shard IDs to achieve parallel ingestion without writer contention. Parameters ---------- - region_id : str - UUID string identifying the write region (e.g. + shard_id : str + UUID string identifying the write shard (e.g. ``str(uuid.uuid4())``). durable_write : bool, optional Whether to fsync WAL writes (default: ``True``). @@ -4873,8 +4871,8 @@ def mem_wal_writer( Returns ------- - RegionWriter - A context-manager-compatible writer for the specified region. + ShardWriter + A context-manager-compatible writer for the specified shard. Examples -------- @@ -4893,9 +4891,9 @@ def mem_wal_writer( ... tmpdir, ... ) ... ds.initialize_mem_wal() - ... region_id = str(uuid.uuid4()) + ... shard_id = str(uuid.uuid4()) ... new_data = pa.table({"id": [2], "val": [0.2]}, schema=schema) - ... with ds.mem_wal_writer(region_id) as writer: + ... with ds.mem_wal_writer(shard_id) as writer: ... writer.put(new_data) """ import lance.mem_wal as _mw @@ -4919,8 +4917,8 @@ def mem_wal_writer( ] if val is not None } - raw = self._ds.mem_wal_writer(region_id, **kwargs) - return _mw.RegionWriter(raw) + raw = self._ds.mem_wal_writer(shard_id, **kwargs) + return _mw.ShardWriter(raw) class SqlQuery: diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index e76baf7e5dd..75722ac82e1 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -1011,6 +1011,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> Transaction: ... @@ -1033,6 +1034,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> List[FragmentMetadata]: ... @@ -1055,6 +1057,7 @@ def write_fragments( enable_stable_row_ids: bool = False, target_bases: Optional[List[str]] = None, initial_bases: Optional[List["DatasetBasePath"]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, namespace_client: Optional[LanceNamespace] = None, table_id: Optional[List[str]] = None, ) -> List[FragmentMetadata] | Transaction: @@ -1132,6 +1135,13 @@ def write_fragments( **Only valid in CREATE mode**. Will raise an error if used with APPEND/OVERWRITE modes. + base_store_params : dict of str to dict, optional + Runtime-only object store parameters keyed by exact base path URI. + Each value is a dict of storage options for that base. These settings + are not persisted to the manifest. When a base has no explicit entry, + top-level ``storage_options`` is used as a fallback. If ``dataset_uri`` + is a LanceDataset and this is omitted, the dataset's base store params + are inherited. namespace_client : optional, LanceNamespace A namespace client for automatic credential refresh. When provided with `table_id`, a storage options provider will be created automatically to @@ -1173,6 +1183,10 @@ def write_fragments( if isinstance(dataset_uri, Path): dataset_uri = str(dataset_uri) elif isinstance(dataset_uri, LanceDataset): + if base_store_params is None: + base_store_params = dataset_uri._base_store_params + if storage_options is None: + storage_options = dataset_uri._storage_options dataset_uri = dataset_uri._ds elif not isinstance(dataset_uri, str): raise TypeError(f"Unknown dataset_uri type {type(dataset_uri)}") @@ -1204,6 +1218,7 @@ def write_fragments( enable_stable_row_ids=enable_stable_row_ids, target_bases=target_bases, initial_bases=initial_bases, + base_store_params=base_store_params, ) diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 6edcef5e080..dd91fb19614 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -559,6 +559,9 @@ def _write_fragments( namespace_client: Optional[LanceNamespace], table_id: Optional[List[str]], enable_stable_row_ids: bool, + target_bases: Optional[List[str]] = None, + initial_bases: Optional[List[Any]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, ): ... def _write_fragments_transaction( dataset_uri: str | Path | _Dataset, @@ -573,11 +576,90 @@ def _write_fragments_transaction( namespace_client: Optional[LanceNamespace], table_id: Optional[List[str]], enable_stable_row_ids: bool, + target_bases: Optional[List[str]] = None, + initial_bases: Optional[List[Any]] = None, + base_store_params: Optional[Dict[str, Dict[str, str]]] = None, ) -> Transaction: ... def _json_to_schema(schema_json: str) -> pa.Schema: ... def _schema_to_json(schema: pa.Schema) -> str: ... def _parse_field_path(path: str) -> list[str]: ... def _format_field_path(segments: list[str]) -> str: ... +def _evaluate_sharding_spec( + batch: pa.RecordBatch, + spec: Dict[str, Any], + schema: LanceSchema, +) -> pa.RecordBatch: ... + +class _MergedGeneration: + shard_id: str + generation: int + def __init__(self, shard_id: str, generation: int) -> None: ... + +class _ShardSnapshot: + shard_id: str + def __init__(self, shard_id: str) -> None: ... + def with_spec_id(self, spec_id: int) -> Self: ... + def with_current_generation(self, generation: int) -> Self: ... + def with_flushed_generation(self, generation: int, path: str) -> Self: ... + +class _ShardWriter: + shard_id: str + def put(self, data: Any) -> None: ... + def close(self) -> None: ... + def stats(self) -> Dict[str, Any]: ... + def memtable_stats(self) -> Dict[str, Any]: ... + def lsm_scanner( + self, shard_snapshots: Optional[List[_ShardSnapshot]] = None + ) -> _LsmScanner: ... + +class _LsmScanner: + @staticmethod + def from_snapshots( + dataset: _Dataset, shard_snapshots: List[_ShardSnapshot] + ) -> _LsmScanner: ... + def project(self, columns: List[str]) -> Self: ... + def filter(self, expr: str) -> Self: ... + def limit(self, n: int, offset: Optional[int] = None) -> Self: ... + def with_row_address(self) -> Self: ... + def with_memtable_gen(self) -> Self: ... + def to_batch(self) -> pa.RecordBatch: ... + def to_batches(self) -> List[pa.RecordBatch]: ... + def count_rows(self) -> int: ... + +class _ExecutionPlan: + schema: pa.Schema + dataset_schema: pa.Schema + def explain(self) -> str: ... + def to_reader(self) -> pa.RecordBatchReader: ... + def to_batches(self) -> List[pa.RecordBatch]: ... + +class _LsmPointLookupPlanner: + def __init__( + self, + dataset: _Dataset, + shard_snapshots: List[_ShardSnapshot], + pk_columns: Optional[List[str]] = None, + ) -> None: ... + def plan_lookup( + self, pk_value: pa.Array, columns: Optional[List[str]] = None + ) -> _ExecutionPlan: ... + +class _LsmVectorSearchPlanner: + def __init__( + self, + dataset: _Dataset, + shard_snapshots: List[_ShardSnapshot], + vector_column: str, + pk_columns: Optional[List[str]] = None, + distance_type: Optional[str] = None, + ) -> None: ... + def plan_search( + self, + query: pa.Array, + k: int = 10, + nprobes: int = 20, + columns: Optional[List[str]] = None, + ) -> _ExecutionPlan: ... class _Hnsw: @staticmethod diff --git a/python/python/lance/mem_wal.py b/python/python/lance/mem_wal.py index 87609435e53..d2ccd463775 100644 --- a/python/python/lance/mem_wal.py +++ b/python/python/lance/mem_wal.py @@ -15,19 +15,20 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Union import pyarrow as pa from .lance import ( + _evaluate_sharding_spec, _ExecutionPlan, _LsmPointLookupPlanner, _LsmScanner, _LsmVectorSearchPlanner, _MergedGeneration, - _RegionSnapshot, - _RegionWriter, + _ShardSnapshot, + _ShardWriter, ) from .types import _coerce_reader @@ -35,11 +36,12 @@ import lance __all__ = [ - "RegionField", - "RegionSpec", + "ShardingField", + "ShardingSpec", + "evaluate_sharding_spec", "MergedGeneration", - "RegionSnapshot", - "RegionWriter", + "ShardSnapshot", + "ShardWriter", "LsmScanner", "ExecutionPlan", "LsmPointLookupPlanner", @@ -48,13 +50,13 @@ # --------------------------------------------------------------------------- -# RegionSpec +# ShardingSpec # --------------------------------------------------------------------------- @dataclass -class RegionField: - """Defines one derived field used in region partitioning. +class ShardingField: + """Defines one MemWAL sharding field. Parameters ---------- @@ -81,11 +83,43 @@ class RegionField: @dataclass -class RegionSpec: - """Partitioning specification for deriving MemWAL region IDs.""" +class ShardingSpec: + """Specification for deriving MemWAL shard routing values.""" spec_id: int - fields: List[RegionField] + fields: List[ShardingField] + + +def evaluate_sharding_spec( + batch: pa.RecordBatch, + spec: Union[ShardingSpec, Mapping[str, object]], + schema: "lance.schema.LanceSchema", +) -> pa.RecordBatch: + """Evaluate a MemWAL sharding spec against one PyArrow RecordBatch. + + Parameters + ---------- + batch : pyarrow.RecordBatch + Input batch containing the sharding source columns. + spec : ShardingSpec or dict + MemWAL sharding spec to evaluate. + schema : LanceSchema + Lance table schema used to resolve source field IDs in the spec to + input batch column names. + """ + if not isinstance(batch, pa.RecordBatch): + raise TypeError(f"Expected pyarrow.RecordBatch, got {type(batch)!r}") + return _evaluate_sharding_spec( + batch, + _sharding_spec_to_dict(spec), + schema, + ) + + +def _sharding_spec_to_dict(spec: Union[ShardingSpec, Mapping[str, object]]) -> dict: + if isinstance(spec, Mapping): + return dict(spec) + return asdict(spec) @dataclass @@ -97,45 +131,45 @@ class MergedGeneration: Parameters ---------- - region_id : str - UUID string for the write region. + shard_id : str + UUID string for the write shard. generation : int Generation number (from - :attr:`RegionSnapshot.flushed_generations`). + :attr:`ShardSnapshot.flushed_generations`). """ - region_id: str + shard_id: str generation: int -class RegionSnapshot: - """Snapshot of a MemWAL region's state, used when constructing scanners. +class ShardSnapshot: + """Snapshot of a MemWAL shard's state, used when constructing scanners. Parameters ---------- - region_id : str - UUID string for the write region. + shard_id : str + UUID string for the write shard. """ - def __init__(self, region_id: str) -> None: - self._raw = _RegionSnapshot(region_id) + def __init__(self, shard_id: str) -> None: + self._raw = _ShardSnapshot(shard_id) @property - def region_id(self) -> str: - """UUID string for this region.""" - return self._raw.region_id + def shard_id(self) -> str: + """UUID string for this shard.""" + return self._raw.shard_id - def with_spec_id(self, spec_id: int) -> "RegionSnapshot": - """Set the RegionSpec ID.""" + def with_spec_id(self, spec_id: int) -> "ShardSnapshot": + """Set the sharding spec ID.""" self._raw = self._raw.with_spec_id(spec_id) return self - def with_current_generation(self, generation: int) -> "RegionSnapshot": + def with_current_generation(self, generation: int) -> "ShardSnapshot": """Set the current (active) generation number.""" self._raw = self._raw.with_current_generation(generation) return self - def with_flushed_generation(self, generation: int, path: str) -> "RegionSnapshot": + def with_flushed_generation(self, generation: int, path: str) -> "ShardSnapshot": """Add a flushed generation with its storage path.""" self._raw = self._raw.with_flushed_generation(generation, path) return self @@ -144,28 +178,28 @@ def __repr__(self) -> str: return repr(self._raw) -class RegionWriter: - """Stateful writer for one MemWAL region. +class ShardWriter: + """Stateful writer for one MemWAL shard. Obtain an instance via mem_wal_writer. Use as a context manager so the writer is closed automatically:: - with dataset.mem_wal_writer(region_id) as writer: + with dataset.mem_wal_writer(shard_id) as writer: writer.put(batch) Parameters ---------- - _raw : _RegionWriter + _raw : _ShardWriter Internal PyO3 object — do not construct directly. """ - def __init__(self, _raw: _RegionWriter) -> None: + def __init__(self, _raw: _ShardWriter) -> None: self._raw = _raw @property - def region_id(self) -> str: - """UUID string for this writer's region.""" - return self._raw.region_id + def shard_id(self) -> str: + """UUID string for this writer's shard.""" + return self._raw.shard_id def put(self, data, *, schema: Optional[pa.Schema] = None) -> None: """Write data to the MemWAL. @@ -222,7 +256,7 @@ def memtable_stats(self) -> dict: return self._raw.memtable_stats() def lsm_scanner( - self, region_snapshots: Optional[List[RegionSnapshot]] = None + self, shard_snapshots: Optional[List[ShardSnapshot]] = None ) -> "LsmScanner": """Create an LSM scanner that includes the active MemTable. @@ -232,18 +266,18 @@ def lsm_scanner( Parameters ---------- - region_snapshots : list of RegionSnapshot, optional - Snapshots of other regions to include. This writer's own region + shard_snapshots : list of ShardSnapshot, optional + Snapshots of other shards to include. This writer's own shard is automatically included. Returns ------- LsmScanner """ - raw_snaps = [s._raw for s in (region_snapshots or [])] + raw_snaps = [s._raw for s in (shard_snapshots or [])] return LsmScanner(self._raw.lsm_scanner(raw_snaps)) - def __enter__(self) -> "RegionWriter": + def __enter__(self) -> "ShardWriter": return self def __exit__(self, exc_type, exc_val, exc_tb) -> bool: @@ -257,7 +291,7 @@ class LsmScanner: Deduplicates by primary key, always returning the newest version of each row across base table, flushed MemTables, and the active MemTable. - Obtain an instance from `RegionWriter.lsm_scanner` (includes + Obtain an instance from `ShardWriter.lsm_scanner` (includes active MemTable) or `LsmScanner.from_snapshots` (flushed only). The builder methods (`project`, `filter`, `limit`) @@ -276,23 +310,21 @@ def __init__(self, _raw: _LsmScanner) -> None: @staticmethod def from_snapshots( dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], ) -> "LsmScanner": - """Create a scanner from dataset and region snapshots. + """Create a scanner from dataset and shard snapshots. Does **not** include the active MemTable; use - `RegionWriter.lsm_scanner` for that. + `ShardWriter.lsm_scanner` for that. Parameters ---------- dataset : LanceDataset The base dataset to scan. - region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. """ - raw = _LsmScanner.from_snapshots( - dataset._ds, [s._raw for s in region_snapshots] - ) + raw = _LsmScanner.from_snapshots(dataset._ds, [s._raw for s in shard_snapshots]) return LsmScanner(raw) def project(self, columns: List[str]) -> "LsmScanner": @@ -389,8 +421,8 @@ class LsmPointLookupPlanner: ---------- dataset : LanceDataset The base dataset. - region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. pk_columns : list of str, optional Primary key column names. Inferred from schema metadata if omitted. @@ -404,12 +436,12 @@ class LsmPointLookupPlanner: def __init__( self, dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], pk_columns: Optional[List[str]] = None, ) -> None: self._raw = _LsmPointLookupPlanner( dataset._ds, - [s._raw for s in region_snapshots], + [s._raw for s in shard_snapshots], pk_columns, ) @@ -448,8 +480,8 @@ class LsmVectorSearchPlanner: ---------- dataset : LanceDataset The base dataset. - region_snapshots : list of RegionSnapshot - Region snapshots specifying flushed generations to include. + shard_snapshots : list of ShardSnapshot + Shard snapshots specifying flushed generations to include. vector_column : str Name of the ``FixedSizeList`` vector column. pk_columns : list of str, optional @@ -470,7 +502,7 @@ class LsmVectorSearchPlanner: def __init__( self, dataset: "lance.LanceDataset", - region_snapshots: List[RegionSnapshot], + shard_snapshots: List[ShardSnapshot], vector_column: str, pk_columns: Optional[List[str]] = None, distance_type: Optional[str] = None, @@ -482,7 +514,7 @@ def __init__( kwargs["distance_type"] = distance_type self._raw = _LsmVectorSearchPlanner( dataset._ds, - [s._raw for s in region_snapshots], + [s._raw for s in shard_snapshots], vector_column, **kwargs, ) @@ -493,6 +525,7 @@ def plan_search( k: int = 10, nprobes: int = 20, columns: Optional[List[str]] = None, + refine_factor: Optional[int] = None, ) -> ExecutionPlan: """Plan a KNN vector search. @@ -507,6 +540,11 @@ def plan_search( columns : list of str, optional Columns to project. Returns all columns + ``_distance`` if omitted. + refine_factor : int, optional + When set, the base-table arm fetches ``k * refine_factor`` + candidates and re-ranks with exact distances. Useful when + the base index is approximate (e.g. IVF-PQ). Memtable arms + use exact HNSW and are unaffected. Returns ------- @@ -514,19 +552,21 @@ def plan_search( Physical plan for the vector search. Execute it via `to_table`, `to_reader`, or `to_batches`. """ - return ExecutionPlan(self._raw.plan_search(query, k, nprobes, columns)) + return ExecutionPlan( + self._raw.plan_search(query, k, nprobes, columns, refine_factor) + ) -def _unwrap_region_id(region_id: str) -> str: - """Validate region_id is a UUID string.""" +def _unwrap_shard_id(shard_id: str) -> str: + """Validate shard_id is a UUID string.""" import uuid as _uuid - _uuid.UUID(region_id) # raises ValueError if invalid - return region_id + _uuid.UUID(shard_id) # raises ValueError if invalid + return shard_id def _to_raw_merged_generations( generations: Iterable[MergedGeneration], ) -> list: """Convert Python MergedGeneration list to PyO3 _MergedGeneration list.""" - return [_MergedGeneration(g.region_id, g.generation) for g in generations] + return [_MergedGeneration(g.shard_id, g.generation) for g in generations] diff --git a/python/python/lance/namespace.py b/python/python/lance/namespace.py index 76a67695f30..f448e5c3368 100644 --- a/python/python/lance/namespace.py +++ b/python/python/lance/namespace.py @@ -28,6 +28,8 @@ AlterTransactionResponse, AnalyzeTableQueryPlanRequest, CountTableRowsRequest, + CreateMaterializedViewRequest, + CreateMaterializedViewResponse, CreateNamespaceRequest, CreateNamespaceResponse, CreateTableIndexRequest, @@ -806,6 +808,13 @@ def refresh_materialized_view( response_dict = self._inner.refresh_materialized_view(request.model_dump()) return RefreshMaterializedViewResponse.from_dict(response_dict) + def create_materialized_view( + self, request: CreateMaterializedViewRequest + ) -> CreateMaterializedViewResponse: + """Create a materialized view backed by an optional UDTF/chunker.""" + response_dict = self._inner.create_materialized_view(request.model_dump()) + return CreateMaterializedViewResponse.from_dict(response_dict) + # Table tag operations def list_table_tags(self, request: ListTableTagsRequest) -> ListTableTagsResponse: @@ -1369,6 +1378,13 @@ def refresh_materialized_view( response_dict = self._inner.refresh_materialized_view(request.model_dump()) return RefreshMaterializedViewResponse.from_dict(response_dict) + def create_materialized_view( + self, request: CreateMaterializedViewRequest + ) -> CreateMaterializedViewResponse: + """Create a materialized view backed by an optional UDTF/chunker.""" + response_dict = self._inner.create_materialized_view(request.model_dump()) + return CreateMaterializedViewResponse.from_dict(response_dict) + # Table tag operations def list_table_tags(self, request: ListTableTagsRequest) -> ListTableTagsResponse: diff --git a/python/python/tests/compat/test_vector_indices.py b/python/python/tests/compat/test_vector_indices.py index 194435c095a..b98ffdf63e3 100644 --- a/python/python/tests/compat/test_vector_indices.py +++ b/python/python/tests/compat/test_vector_indices.py @@ -71,6 +71,17 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + elif self.compat_version >= "0.39.0": + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -140,6 +151,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -209,6 +232,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0]["name"] + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and rebuild index.""" ds = lance.dataset(self.path) @@ -226,9 +261,9 @@ def check_write(self): ds.optimize.compact_files() -@compat_test(min_version="0.39.0") +@compat_test(min_version="4.0.0-beta.8") class IvfRqVectorIndex(UpgradeDowngradeTest): - """Test IVF_RQ vector index compatibility.""" + """Test IVF_RQ vector index compatibility. V2 was introduced in v4.0.0-beta.8""" def __init__(self, path: Path): self.path = path @@ -273,6 +308,18 @@ def check_read(self): ) assert result.num_rows == 4 + if hasattr(ds, "describe_indices"): + indices = ds.describe_indices() + assert len(indices) >= 1 + name = indices[0].name + else: + indices = ds.list_indices() + assert len(indices) >= 1 + name = indices[0].name + + stats = ds.stats.index_stats(name) + assert stats["num_indexed_rows"] > 0 + def check_write(self): """Verify can insert vectors and run optimize workflows.""" ds = lance.dataset(self.path) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 428b85095ce..dc2030c4b2b 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -436,18 +436,27 @@ def test_has_stable_row_ids_property(tmp_path: Path): assert lance.dataset(non_stable_path).has_stable_row_ids is False +def _list_manifests(versions_dir): + # Ignore the version hint file, which is not a manifest. + return [ + name + for name in os.listdir(versions_dir) + if not name.startswith("latest_version_hint") + ] + + def test_v2_manifest_paths(tmp_path: Path): lance.write_dataset( pa.table({"a": range(100)}), tmp_path, enable_v2_manifest_paths=True ) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) def test_default_v2_manifest_paths(tmp_path: Path): lance.write_dataset(pa.table({"a": range(100)}), tmp_path) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) @@ -457,12 +466,12 @@ def test_v2_manifest_paths_migration(tmp_path: Path): lance.write_dataset( pa.table({"a": range(100)}), tmp_path, enable_v2_manifest_paths=False ) - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert manifest_path == ["1.manifest"] # Migrate to v2 manifest paths lance.dataset(tmp_path).migrate_manifest_paths_v2() - manifest_path = os.listdir(tmp_path / "_versions") + manifest_path = _list_manifests(tmp_path / "_versions") assert len(manifest_path) == 1 assert re.match(r"\d{20}\.manifest", manifest_path[0]) diff --git a/python/python/tests/test_log.py b/python/python/tests/test_log.py index b00fe5813a0..1eb02957044 100644 --- a/python/python/tests/test_log.py +++ b/python/python/tests/test_log.py @@ -148,6 +148,36 @@ def test_lance_log_file_with_directory_creation(tmp_path): assert len(log_content.strip()) > 0, "Log file is empty" +@pytest.mark.skipif( + sys.platform == "win32", + reason="subprocess does not work correctly in CI on Windows", +) +def test_lance_log_filters_trace_event_targets(tmp_path): + log_file = tmp_path / "lance_rust.log" + + result = subprocess.run( + [ + sys.executable, + "-c", + "import lance; import pyarrow as pa; " + "lance.write_dataset(pa.table({'x': range(10)}), 'memory://test')", + ], + capture_output=True, + env={ + "LANCE_LOG": "warn,lance::events::dataset_events=info", + "LANCE_LOG_FILE": str(log_file), + }, + ) + + assert result.returncode == 0, f"Command failed: {result.stderr.decode()}" + + log_content = log_file.read_text() + assert "lance::events::dataset_events" in log_content + assert 'target="lance::dataset_events"' in log_content + assert "lance::events::file_audit" not in log_content + assert "lance::file_audit" not in log_content + + @pytest.mark.skipif( sys.platform == "win32", reason="subprocess does not work correctly in CI on Windows", diff --git a/python/python/tests/test_mem_wal.py b/python/python/tests/test_mem_wal.py index e63aacff57b..b8c859cb637 100644 --- a/python/python/tests/test_mem_wal.py +++ b/python/python/tests/test_mem_wal.py @@ -11,8 +11,12 @@ from lance.mem_wal import ( LsmPointLookupPlanner, LsmScanner, - RegionSnapshot, + ShardingField, + ShardingSpec, + ShardSnapshot, + evaluate_sharding_spec, ) +from lance.schema import LanceSchema _PK_META = {"lance-schema:unenforced-primary-key": "true"} _LOOKUP_SCHEMA = pa.schema( @@ -21,6 +25,12 @@ pa.field("name", pa.utf8()), ] ) +_APPEND_ONLY_SCHEMA = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.utf8()), + ] +) def _lookup_table(ids, prefix: str) -> pa.Table: @@ -34,13 +44,24 @@ def _lookup_table(ids, prefix: str) -> pa.Table: ) -def _write_flushed_gen(base_path: str, region_id: str, gen_folder: str, data: pa.Table): +def _append_only_table(ids, prefix: str) -> pa.Table: + """Build a table without primary-key metadata.""" + return pa.table( + { + "id": pa.array(ids, pa.int64()), + "name": pa.array([f"{prefix}_{i}" for i in ids], pa.utf8()), + }, + schema=_APPEND_ONLY_SCHEMA, + ) + + +def _write_flushed_gen(base_path: str, shard_id: str, gen_folder: str, data: pa.Table): """Write a flushed-generation Lance dataset at the expected sub-path. The collector resolves flushed generation paths as: - {base_dataset_path}/_mem_wal/{region_id}/{gen_folder} + {base_dataset_path}/_mem_wal/{shard_id}/{gen_folder} """ - gen_path = os.path.join(base_path, "_mem_wal", region_id, gen_folder) + gen_path = os.path.join(base_path, "_mem_wal", shard_id, gen_folder) lance.write_dataset(data, gen_path, schema=_LOOKUP_SCHEMA) @@ -54,10 +75,10 @@ def test_point_lookup_with_memtables(tmp_path): base : ids [1, 2, 3] names ["base_1", "base_2", "base_3"] gen_1 : ids [2] names ["gen1_2"] ← update to id=2 - RegionSnapshot: flushed_generation(gen=1, path="gen_1"), current_generation=2 + ShardSnapshot: flushed_generation(gen=1, path="gen_1"), current_generation=2 """ ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) # --- Base dataset --- base_ds = lance.write_dataset( @@ -66,11 +87,11 @@ def test_point_lookup_with_memtables(tmp_path): base_ds.initialize_mem_wal() # --- Flushed generation: overwrites id=2 --- - _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + _write_flushed_gen(ds_path, shard_id, "gen_1", _lookup_table([2], "gen1")) - # --- RegionSnapshot describing the flushed state --- + # --- ShardSnapshot describing the flushed state --- snap = ( - RegionSnapshot(region_id) + ShardSnapshot(shard_id) .with_flushed_generation(1, "gen_1") .with_current_generation(2) ) @@ -110,17 +131,17 @@ def test_lsm_scanner_with_memtables(tmp_path): Expected result: 3 unique rows — id=2 from gen_1, id=1 and id=3 from base. """ ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) base_ds = lance.write_dataset( _lookup_table([1, 2, 3], "base"), ds_path, schema=_LOOKUP_SCHEMA ) base_ds.initialize_mem_wal() - _write_flushed_gen(ds_path, region_id, "gen_1", _lookup_table([2], "gen1")) + _write_flushed_gen(ds_path, shard_id, "gen_1", _lookup_table([2], "gen1")) snap = ( - RegionSnapshot(region_id) + ShardSnapshot(shard_id) .with_flushed_generation(1, "gen_1") .with_current_generation(2) ) @@ -136,14 +157,14 @@ def test_lsm_scanner_with_memtables(tmp_path): assert name_by_id[3] == "base_3" -def test_region_writer_lsm_scanner_includes_own_flushed_generations(tmp_path): +def test_shard_writer_lsm_scanner_includes_own_flushed_generations(tmp_path): ds_path = str(tmp_path / "base") - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) ds = lance.write_dataset(_lookup_table([0], "base"), ds_path, schema=_LOOKUP_SCHEMA) ds.initialize_mem_wal() with ds.mem_wal_writer( - region_id, + shard_id, durable_write=True, max_wal_buffer_size=1, max_wal_flush_interval_ms=10, @@ -232,15 +253,15 @@ def _e2e_batch(schema, start_id: int, num_rows: int) -> pa.RecordBatch: ) -def test_region_writer_e2e_correctness(tmp_path): +def test_shard_writer_e2e_correctness(tmp_path): """ - End-to-end correctness test for RegionWriter covering: + End-to-end correctness test for ShardWriter covering: - Multi-round writes that trigger WAL and MemTable flushes - - File-system layout verification (_mem_wal//wal/ and manifest/) + - File-system layout verification (_mem_wal//wal/ and manifest/) - Flushed generation data readable via LsmScanner - New writer created after close can write and scan correctly - Mirrors Rust test: region_writer_tests::test_region_writer_e2e_correctness + Mirrors Rust test: shard_writer_tests::test_shard_writer_e2e_correctness """ schema = _e2e_schema() ds_path = str(tmp_path / "ds") @@ -255,9 +276,9 @@ def test_region_writer_e2e_correctness(tmp_path): ds.initialize_mem_wal(maintained_indexes=["id_btree"]) # Small buffers to trigger WAL and MemTable flushes during the test - region_id = str(uuid.uuid4()) + shard_id = str(uuid.uuid4()) writer = ds.mem_wal_writer( - region_id, + shard_id, durable_write=True, sync_indexed_write=True, max_wal_buffer_size=10 * 1024, # 10 KB @@ -287,7 +308,7 @@ def test_region_writer_e2e_correctness(tmp_path): assert closed_memtable_stats["generation"] >= 1 # === File-system layout === - mem_wal_dir = os.path.join(ds_path, "_mem_wal", region_id) + mem_wal_dir = os.path.join(ds_path, "_mem_wal", shard_id) assert os.path.isdir(mem_wal_dir), f"MemWAL directory missing: {mem_wal_dir}" wal_dir = os.path.join(mem_wal_dir, "wal") @@ -308,9 +329,9 @@ def test_region_writer_e2e_correctness(tmp_path): # === New writer: write and read back via active MemTable scanner === ds2 = lance.dataset(ds_path) - region_id2 = str(uuid.uuid4()) + shard_id2 = str(uuid.uuid4()) with ds2.mem_wal_writer( - region_id2, durable_write=False, sync_indexed_write=True + shard_id2, durable_write=False, sync_indexed_write=True ) as writer2: verify_batch = _e2e_batch(schema, start_id=10000, num_rows=10) writer2.put(pa.Table.from_batches([verify_batch])) @@ -358,6 +379,31 @@ def test_initialize_mem_wal_unsharded(tmp_path): def test_initialize_mem_wal_bucket_sharding(tmp_path): ds = _mem_wal_dataset(tmp_path) + ds.initialize_mem_wal(bucket_column="name", num_buckets=8) + + details = ds.mem_wal_index_details() + assert details["num_shards"] == 8 + field = details["sharding_specs"][0]["fields"][0] + assert field["transform"] == "bucket" + assert "expression" in field + assert field["parameters"]["num_buckets"] == "8" + assert len(field["source_ids"]) == 1 + + batch = _lookup_table([1, 2, 3], "base").to_batches()[0] + result = evaluate_sharding_spec( + batch, details["sharding_specs"][0], LanceSchema.from_pyarrow(batch.schema) + ) + assert result.column_names == [field["field_id"]] + assert result.num_rows == batch.num_rows + + +def test_initialize_mem_wal_bucket_sharding_without_primary_key(tmp_path): + ds_path = str(tmp_path / "append_only") + ds = lance.write_dataset( + _append_only_table([1, 2, 3], "base"), + ds_path, + schema=_APPEND_ONLY_SCHEMA, + ) ds.initialize_mem_wal(bucket_column="id", num_buckets=8) details = ds.mem_wal_index_details() @@ -423,6 +469,55 @@ def test_initialize_mem_wal_rejects_partial_bucket(tmp_path): ds.initialize_mem_wal(bucket_column="id") +def test_evaluate_sharding_spec_python_binding(): + batch = pa.record_batch( + [pa.array([1, 2, None, 3], type=pa.int32())], + names=["id"], + ) + spec = ShardingSpec( + 1, + [ + ShardingField( + field_id="bucket", + source_ids=[0], + transform="bucket", + result_type="int32", + parameters={"num_buckets": "8"}, + ) + ], + ) + + result = evaluate_sharding_spec(batch, spec, LanceSchema.from_pyarrow(batch.schema)) + + assert result.column_names == ["bucket"] + assert result.column(0).to_pylist() == [2, 7, 0, 1] + + +def test_evaluate_sharding_spec_python_binding_column_parameter(): + batch = pa.record_batch( + [pa.array(["a", "b", None], type=pa.utf8())], + names=["key"], + ) + spec = { + "spec_id": 1, + "fields": [ + { + "field_id": "key_bucket", + "source_ids": [0], + "transform": "bucket", + "expression": None, + "result_type": "int32", + "parameters": {"num_buckets": "8"}, + } + ], + } + + result = evaluate_sharding_spec(batch, spec, LanceSchema.from_pyarrow(batch.schema)) + + assert result.column_names == ["key_bucket"] + assert result.column(0).to_pylist() == [1, 5, 0] + + def test_mem_wal_index_details_none_before_init(tmp_path): ds = _mem_wal_dataset(tmp_path) assert ds.mem_wal_index_details() is None diff --git a/python/python/tests/test_multi_base.py b/python/python/tests/test_multi_base.py index baa437b4fc6..113f3224d14 100644 --- a/python/python/tests/test_multi_base.py +++ b/python/python/tests/test_multi_base.py @@ -1031,6 +1031,7 @@ def test_write_fragments_with_target_bases(self): dataset, mode="append", target_bases=["base2"], + base_store_params={self.base2_uri: {}}, max_rows_per_file=25, ) @@ -1054,6 +1055,42 @@ def test_write_fragments_with_target_bases(self): data_files = list(base2_path.glob("**/*.lance")) assert len(data_files) > 0, "Expected data files in base2" + def test_write_fragments_with_base_store_params(self): + """Test write_fragments inherits base_store_params from LanceDataset.""" + initial_data = pd.DataFrame({"id": range(10), "value": range(10)}) + + dataset = lance.write_dataset( + initial_data, + self.primary_uri, + mode="create", + initial_bases=[ + DatasetBasePath(self.base1_uri, name="base1"), + DatasetBasePath(self.base2_uri, name="base2"), + ], + target_bases=["base1"], + base_store_params={self.base2_uri: {}}, + max_rows_per_file=10, + ) + + fragment_data = pd.DataFrame({"id": range(10, 20), "value": range(10, 20)}) + fragments = write_fragments( + pa.Table.from_pandas(fragment_data), + dataset, + mode="append", + target_bases=["base2"], + max_rows_per_file=10, + ) + + operation = lance.LanceOperation.Append(fragments) + dataset = lance.LanceDataset.commit( + dataset, operation, read_version=dataset.version + ) + + result = dataset.to_table().to_pandas() + assert len(result) == 20 + assert set(result["id"]) == set(range(20)) + assert list(Path(self.base2_uri).glob("**/*.lance")) + def test_write_fragments_transaction_with_target_bases(self): """Test write_fragments with return_transaction and target_bases.""" # Create initial dataset diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index b04693ee85c..356f72a5e66 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1678,7 +1678,7 @@ def test_describe_vector_index(indexed_dataset: LanceDataset): info = indexed_dataset.describe_indices()[0] assert info.name == "vector_idx" - assert info.type_url == "/lance.table.VectorIndexDetails" + assert info.type_url == "/lance.index.pb.VectorIndexDetails" assert info.index_type == "IVF_PQ" assert info.num_rows_indexed == 1000 assert info.fields == [0] @@ -1689,6 +1689,31 @@ def test_describe_vector_index(indexed_dataset: LanceDataset): assert info.segments[0].index_version == 1 assert info.segments[0].created_at is not None + details = info.details + assert details["metric_type"] == "L2" + assert details["compression"]["type"] == "pq" + assert details["compression"]["num_bits"] == 8 + assert details["compression"]["num_sub_vectors"] == 16 + + +def test_describe_index_runtime_hints_stored(tmp_path): + tbl = create_table(nvec=300, ndim=16) + dataset = lance.write_dataset(tbl, tmp_path) + dataset = dataset.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=4, + max_iters=100, + sample_rate=512, + ) + details = dataset.describe_indices()[0].details + hints = details.get("runtime_hints", {}) + assert hints.get("lance.ivf.max_iters") == "100" + assert hints.get("lance.ivf.sample_rate") == "512" + assert hints.get("lance.pq.max_iters") == "100" + assert hints.get("lance.pq.sample_rate") == "512" + def test_optimize_indices(indexed_dataset): data = create_table() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 128f3f38b53..c868504e87c 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -3228,9 +3228,9 @@ impl Dataset { /// Initialize MemWAL on this dataset. /// - /// Must be called once before any `mem_wal_writer()` calls. Requires the - /// dataset schema to have at least one field with the - /// `lance-schema:unenforced-primary-key` metadata. + /// Must be called once before any `mem_wal_writer()` calls. Append-only + /// tables may omit primary-key metadata; primary keys are only required + /// for primary-key lookup and last-write-wins deduplication workflows. /// /// At most one sharding mode may be selected: bucket sharding /// (`bucket_column` + `num_buckets`), identity sharding (`identity_column`), @@ -3375,6 +3375,7 @@ impl Dataset { field_dict.set_item("field_id", &field.field_id)?; field_dict.set_item("source_ids", field.source_ids.clone())?; field_dict.set_item("transform", field.transform.clone())?; + field_dict.set_item("expression", field.expression.clone())?; field_dict.set_item("result_type", &field.result_type)?; field_dict.set_item("parameters", field.parameters.clone())?; fields.append(field_dict)?; @@ -3386,12 +3387,12 @@ impl Dataset { Ok(Some(dict)) } - /// Get a RegionWriter for the specified region. + /// Get a ShardWriter for the specified shard. /// /// `initialize_mem_wal()` must be called before using this method. #[allow(clippy::too_many_arguments)] #[pyo3(signature=( - region_id, + shard_id, *, durable_write=None, sync_indexed_write=None, @@ -3410,7 +3411,7 @@ impl Dataset { fn mem_wal_writer( &self, py: Python<'_>, - region_id: String, + shard_id: String, durable_write: Option, sync_indexed_write: Option, max_wal_buffer_size: Option, @@ -3424,11 +3425,11 @@ impl Dataset { async_index_interval_ms: Option, backpressure_log_interval_ms: Option, stats_log_interval_ms: Option, - ) -> PyResult { + ) -> PyResult { use lance::dataset::mem_wal::DatasetMemWalExt; - let uuid = uuid::Uuid::parse_str(®ion_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + let uuid = uuid::Uuid::parse_str(&shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; let config = writer_config_from_kwargs( durable_write, @@ -3455,7 +3456,7 @@ impl Dataset { )? .map_err(|e| PyIOError::new_err(e.to_string()))?; - Ok(crate::mem_wal::PyRegionWriter::new( + Ok(crate::mem_wal::PyShardWriter::new( writer, uuid, self.ds.clone(), @@ -4286,6 +4287,12 @@ fn prepare_vector_index_params( sq_params.sample_rate = sample_rate; } + if let Some(max_iters) = kwargs.get_item("max_iters")? { + let max_iters: usize = max_iters.extract()?; + ivf_params.max_iters = max_iters; + pq_params.max_iters = max_iters; + } + // Parse IVF params if let Some(n) = kwargs.get_item("num_partitions")? { ivf_params.num_partitions = Some(n.extract()?) @@ -4443,6 +4450,13 @@ fn prepare_vector_index_params( }?; params.version(index_file_version); params.skip_transpose(skip_transpose); + if let Some(kwargs) = kwargs + && let Some(acc) = kwargs.get_item("accelerator")? + { + params + .runtime_hints + .insert("lancedb.accelerator".to_string(), acc.to_string()); + } Ok(params) } diff --git a/python/src/indices.rs b/python/src/indices.rs index 5f00e9f5726..cf93579b867 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -537,8 +537,7 @@ async fn do_load_shuffled_vectors( dataset_version: ds.manifest.version, fragment_bitmap: Some(ds.fragments().iter().map(|f| f.id as u32).collect()), index_details: Some(Arc::new( - prost_types::Any::from_msg(&lance_table::format::pb::VectorIndexDetails::default()) - .unwrap(), + prost_types::Any::from_msg(&lance_index::pb::VectorIndexDetails::default()).unwrap(), )), index_version: IndexType::IvfPq.version(), created_at: Some(Utc::now()), diff --git a/python/src/lib.rs b/python/src/lib.rs index f1d384d275e..cf29b26c46a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -286,12 +286,13 @@ fn lance(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; // MemWAL classes m.add_class::()?; - m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_wrapped(wrap_pyfunction!(mem_wal::py_evaluate_sharding_spec))?; m.add_wrapped(wrap_pyfunction!(bfloat16_array))?; m.add_wrapped(wrap_pyfunction!(write_dataset))?; m.add_wrapped(wrap_pyfunction!(write_fragments))?; diff --git a/python/src/mem_wal.rs b/python/src/mem_wal.rs index 68bd57a58bc..5a8516045ab 100644 --- a/python/src/mem_wal.rs +++ b/python/src/mem_wal.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::sync::Arc; use arrow::ffi_stream::ArrowArrayStreamReader; @@ -20,10 +21,10 @@ use lance::dataset::mem_wal::scanner::{ FlushedGeneration, LsmDataSourceCollector, LsmPointLookupPlanner, LsmVectorSearchPlanner, }; use lance::dataset::mem_wal::write::{MemTableStats, WriteStatsSnapshot}; -use lance::dataset::mem_wal::{ - LsmScanner, ShardSnapshot as RegionSnapshot, ShardWriter as RegionWriter, +use lance::dataset::mem_wal::{LsmScanner, ShardSnapshot, ShardWriter, evaluate_sharding_spec}; +use lance_index::mem_wal::{ + MergedGeneration as LanceMergedGeneration, ShardingField, ShardingSpec, }; -use lance_index::mem_wal::MergedGeneration as LanceMergedGeneration; use lance_linalg::distance::DistanceType; use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; @@ -33,28 +34,82 @@ use uuid::Uuid; use crate::dataset::Dataset as PyDataset; use crate::rt; +use crate::schema::LanceSchema as PyLanceSchema; + +/// Evaluate a MemWAL sharding spec against one PyArrow RecordBatch. +#[pyfunction(name = "_evaluate_sharding_spec", signature = (batch, spec, schema))] +pub fn py_evaluate_sharding_spec<'py>( + py: Python<'py>, + batch: PyArrowType, + spec: &Bound<'_, PyAny>, + schema: &PyLanceSchema, +) -> PyResult> { + let PyArrowType(batch) = batch; + let spec = sharding_spec_from_py(spec)?; + let result = evaluate_sharding_spec(&batch, &spec, &schema.0) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + result.to_pyarrow(py) +} + +fn sharding_spec_from_py(spec: &Bound<'_, PyAny>) -> PyResult { + let spec_id = get_py_value(spec, "spec_id")?.extract::()?; + let fields_obj = get_py_value(spec, "fields")?; + let mut fields = Vec::new(); + for field_obj in fields_obj.try_iter()? { + fields.push(sharding_field_from_py(&field_obj?)?); + } + Ok(ShardingSpec { spec_id, fields }) +} -/// Represents a single generation of a MemWAL region that has been merged +fn sharding_field_from_py(field: &Bound<'_, PyAny>) -> PyResult { + Ok(ShardingField { + field_id: get_py_value(field, "field_id")?.extract::()?, + source_ids: get_py_value(field, "source_ids")?.extract::>()?, + transform: optional_string(get_py_value(field, "transform")?)?, + expression: optional_string(get_py_value(field, "expression")?)?, + result_type: get_py_value(field, "result_type")?.extract::()?, + parameters: get_py_value(field, "parameters")?.extract::>()?, + }) +} + +fn get_py_value<'py>(obj: &Bound<'py, PyAny>, name: &str) -> PyResult> { + if let Ok(dict) = obj.cast::() { + dict.get_item(name)? + .ok_or_else(|| PyValueError::new_err(format!("Missing sharding spec field '{name}'"))) + } else { + obj.getattr(name) + } +} + +fn optional_string(value: Bound<'_, PyAny>) -> PyResult> { + if value.is_none() { + Ok(None) + } else { + value.extract::().map(Some) + } +} + +/// Represents a single generation of a MemWAL shard that has been merged /// into the base table. Used with `MergeInsertBuilder.mark_generations_as_merged()`. #[pyclass(name = "_MergedGeneration", module = "_lib")] pub struct PyMergedGeneration { - pub region_id: String, + pub shard_id: String, pub generation: u64, } #[pymethods] impl PyMergedGeneration { #[new] - pub fn new(region_id: String, generation: u64) -> Self { + pub fn new(shard_id: String, generation: u64) -> Self { Self { - region_id, + shard_id, generation, } } #[getter] - pub fn region_id(&self) -> &str { - &self.region_id + pub fn shard_id(&self) -> &str { + &self.shard_id } #[getter] @@ -64,42 +119,42 @@ impl PyMergedGeneration { pub fn __repr__(&self) -> String { format!( - "_MergedGeneration(region_id='{}', generation={})", - self.region_id, self.generation + "_MergedGeneration(shard_id='{}', generation={})", + self.shard_id, self.generation ) } } impl PyMergedGeneration { pub fn to_lance(&self) -> PyResult { - let uuid = Uuid::parse_str(&self.region_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + let uuid = Uuid::parse_str(&self.shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; Ok(LanceMergedGeneration::new(uuid, self.generation)) } } -/// Snapshot of a MemWAL region's state at a point in time. +/// Snapshot of a MemWAL shard's state at a point in time. /// /// Used to specify which flushed generations to include when creating an /// `_LsmScanner`. Supports a builder pattern for adding generations. -#[pyclass(name = "_RegionSnapshot", module = "_lib", skip_from_py_object)] +#[pyclass(name = "_ShardSnapshot", module = "_lib", skip_from_py_object)] #[derive(Clone)] -pub struct PyRegionSnapshot { - pub inner: RegionSnapshot, +pub struct PyShardSnapshot { + pub inner: ShardSnapshot, } #[pymethods] -impl PyRegionSnapshot { +impl PyShardSnapshot { #[new] - pub fn new(region_id: String) -> PyResult { - let uuid = Uuid::parse_str(®ion_id) - .map_err(|e| PyValueError::new_err(format!("Invalid region_id UUID: {}", e)))?; + pub fn new(shard_id: String) -> PyResult { + let uuid = Uuid::parse_str(&shard_id) + .map_err(|e| PyValueError::new_err(format!("Invalid shard_id UUID: {}", e)))?; Ok(Self { - inner: RegionSnapshot::new(uuid), + inner: ShardSnapshot::new(uuid), }) } - /// Set the RegionSpec ID for this snapshot. + /// Set the sharding spec ID for this snapshot. pub fn with_spec_id(mut slf: PyRefMut<'_, Self>, spec_id: u32) -> PyRefMut<'_, Self> { slf.inner = slf.inner.clone().with_spec_id(spec_id); slf @@ -125,13 +180,13 @@ impl PyRegionSnapshot { } #[getter] - pub fn region_id(&self) -> String { + pub fn shard_id(&self) -> String { self.inner.shard_id.to_string() } pub fn __repr__(&self) -> String { format!( - "_RegionSnapshot(region_id='{}', current_gen={}, flushed_gens={})", + "_ShardSnapshot(shard_id='{}', current_gen={}, flushed_gens={})", self.inner.shard_id, self.inner.current_generation, self.inner.flushed_generations.len() @@ -139,26 +194,26 @@ impl PyRegionSnapshot { } } -/// Long-lived stateful writer for a MemWAL region. +/// Long-lived stateful writer for a MemWAL shard. /// /// Supports writing batches, querying statistics, creating LSM scanners, /// and graceful shutdown. Supports the Python context manager protocol. -#[pyclass(name = "_RegionWriter", module = "_lib")] -pub struct PyRegionWriter { - inner: Arc>>, - closed_state: Arc>>, - region_id: Uuid, +#[pyclass(name = "_ShardWriter", module = "_lib")] +pub struct PyShardWriter { + inner: Arc>>, + closed_state: Arc>>, + shard_id: Uuid, dataset: Arc, } #[derive(Clone)] -struct ClosedRegionWriterState { +struct ClosedShardWriterState { stats: WriteStatsSnapshot, memtable_stats: MemTableStats, } #[pymethods] -impl PyRegionWriter { +impl PyShardWriter { /// Write data batches to the MemWAL. /// /// Accepts any PyArrow-compatible data source (RecordBatch, Table, @@ -180,7 +235,7 @@ impl PyRegionWriter { match guard.as_ref() { Some(writer) => writer.put(batches).await.map(|_| ()), None => Err(lance_core::Error::invalid_input( - "RegionWriter is already closed", + "ShardWriter is already closed", )), } })? @@ -205,7 +260,7 @@ impl PyRegionWriter { writer.close().await?; let closed_memtable_stats = closed_memtable_stats(memtable_stats_before_close); let mut closed_guard = closed_state.lock().await; - *closed_guard = Some(ClosedRegionWriterState { + *closed_guard = Some(ClosedShardWriterState { stats: stats_snapshot, memtable_stats: closed_memtable_stats, }); @@ -235,7 +290,7 @@ impl PyRegionWriter { .as_ref() .map(|state| state.stats.clone()) .ok_or_else(|| { - lance_core::Error::invalid_input("RegionWriter is already closed") + lance_core::Error::invalid_input("ShardWriter is already closed") }) } })? @@ -262,7 +317,7 @@ impl PyRegionWriter { .as_ref() .map(|state| state.memtable_stats.clone()) .ok_or_else(|| { - lance_core::Error::invalid_input("RegionWriter is already closed") + lance_core::Error::invalid_input("ShardWriter is already closed") }) } } @@ -275,13 +330,13 @@ impl PyRegionWriter { /// Create an LSM scanner that includes the active MemTable for strong consistency. /// /// The scanner covers: base table + given flushed generations + current active MemTable. - #[pyo3(signature = (region_snapshots=vec![]))] + #[pyo3(signature = (shard_snapshots=vec![]))] pub fn lsm_scanner( &self, py: Python<'_>, - region_snapshots: Vec>, + shard_snapshots: Vec>, ) -> PyResult { - let mut snapshots: Vec = region_snapshots + let mut snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -289,7 +344,7 @@ impl PyRegionWriter { let pk_columns = get_pk_columns(&self.dataset)?; let inner = self.inner.clone(); let dataset = self.dataset.clone(); - let region_id = self.region_id; + let shard_id = self.shard_id; let (active_ref, writer_snapshot) = rt() .block_on(Some(py), async move { @@ -300,31 +355,31 @@ impl PyRegionWriter { let writer_snapshot = w .manifest() .await? - .map(region_snapshot_from_manifest) - .unwrap_or_else(|| RegionSnapshot::new(region_id)); + .map(shard_snapshot_from_manifest) + .unwrap_or_else(|| ShardSnapshot::new(shard_id)); Ok((active_ref, writer_snapshot)) } None => Err(lance_core::Error::invalid_input( - "RegionWriter is already closed", + "ShardWriter is already closed", )), } })? .map_err(|e| PyIOError::new_err(e.to_string()))?; - snapshots.retain(|snapshot| snapshot.shard_id != region_id); + snapshots.retain(|snapshot| snapshot.shard_id != shard_id); snapshots.push(writer_snapshot); let scanner = LsmScanner::new(dataset, snapshots, pk_columns) - .with_active_memtable(region_id, active_ref); + .with_active_memtable(shard_id, active_ref); Ok(PyLsmScanner { inner: Some(scanner), }) } - /// Return the region ID as a UUID string. + /// Return the shard ID as a UUID string. #[getter] - pub fn region_id(&self) -> String { - self.region_id.to_string() + pub fn shard_id(&self) -> String { + self.shard_id.to_string() } pub fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { @@ -343,13 +398,13 @@ impl PyRegionWriter { } } -impl PyRegionWriter { - /// Create from a Rust RegionWriter and dataset reference. - pub fn new(writer: RegionWriter, region_id: Uuid, dataset: Arc) -> Self { +impl PyShardWriter { + /// Create from a Rust ShardWriter and dataset reference. + pub fn new(writer: ShardWriter, shard_id: Uuid, dataset: Arc) -> Self { Self { inner: Arc::new(TokioMutex::new(Some(writer))), closed_state: Arc::new(TokioMutex::new(None)), - region_id, + shard_id, dataset, } } @@ -438,14 +493,14 @@ pub struct PyLsmScanner { #[pymethods] impl PyLsmScanner { - /// Create a scanner from dataset and region snapshots (without active MemTable). + /// Create a scanner from dataset and shard snapshots (without active MemTable). #[staticmethod] pub fn from_snapshots( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, ) -> PyResult { let ds = dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -584,14 +639,14 @@ pub struct PyLsmPointLookupPlanner { #[pymethods] impl PyLsmPointLookupPlanner { #[new] - #[pyo3(signature = (dataset, region_snapshots, pk_columns=None))] + #[pyo3(signature = (dataset, shard_snapshots, pk_columns=None))] pub fn new( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, pk_columns: Option>, ) -> PyResult { let ds = dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -649,16 +704,16 @@ pub struct PyLsmVectorSearchPlanner { #[pymethods] impl PyLsmVectorSearchPlanner { #[new] - #[pyo3(signature = (dataset, region_snapshots, vector_column, pk_columns=None, distance_type=None))] + #[pyo3(signature = (dataset, shard_snapshots, vector_column, pk_columns=None, distance_type=None))] pub fn new( dataset: &Bound<'_, PyDataset>, - region_snapshots: Vec>, + shard_snapshots: Vec>, vector_column: String, pk_columns: Option>, distance_type: Option, ) -> PyResult { let ds = dataset.borrow().ds.clone(); - let snapshots: Vec = region_snapshots + let snapshots: Vec = shard_snapshots .iter() .map(|s| s.borrow().inner.clone()) .collect(); @@ -672,14 +727,15 @@ impl PyLsmVectorSearchPlanner { let vector_dim = get_vector_dim(&ds, &vector_column)?; - let collector = LsmDataSourceCollector::new(ds, snapshots); + let collector = LsmDataSourceCollector::new(ds.clone(), snapshots); let planner = LsmVectorSearchPlanner::new( collector, pk_cols, base_schema.clone(), vector_column, dist_type, - ); + ) + .with_dataset(ds); Ok(Self { planner, @@ -691,7 +747,7 @@ impl PyLsmVectorSearchPlanner { /// Plan a KNN vector search. /// /// `query` should be a flat PyArrow Float32Array with `vector_dim` elements. - #[pyo3(signature = (query, k=10, nprobes=20, columns=None))] + #[pyo3(signature = (query, k=10, nprobes=20, columns=None, refine_factor=None))] pub fn plan_search( &self, py: Python<'_>, @@ -699,6 +755,7 @@ impl PyLsmVectorSearchPlanner { k: usize, nprobes: usize, columns: Option>, + refine_factor: Option, ) -> PyResult { let query_array = make_array(query.0); let float32_array = query_array @@ -732,7 +789,7 @@ impl PyLsmVectorSearchPlanner { let plan = rt() .block_on(Some(py), async { planner_ref - .plan_search(&fsl, k, nprobes, columns.as_deref()) + .plan_search(&fsl, k, nprobes, columns.as_deref(), refine_factor) .await })? .map_err(|e| PyIOError::new_err(e.to_string()))?; @@ -878,8 +935,8 @@ fn scalar_values_from_pk_value( Ok(pk_values) } -fn region_snapshot_from_manifest(manifest: lance_index::mem_wal::ShardManifest) -> RegionSnapshot { - RegionSnapshot { +fn shard_snapshot_from_manifest(manifest: lance_index::mem_wal::ShardManifest) -> ShardSnapshot { + ShardSnapshot { shard_id: manifest.shard_id, spec_id: manifest.shard_spec_id, current_generation: manifest.current_generation, diff --git a/python/src/namespace.rs b/python/src/namespace.rs index 15af9f07280..cf5f7c41b0f 100644 --- a/python/src/namespace.rs +++ b/python/src/namespace.rs @@ -12,16 +12,17 @@ use lance_namespace::LanceNamespace as LanceNamespaceTrait; use lance_namespace::models::{ AlterTableAddColumnsRequest, AlterTableAlterColumnsRequest, AlterTableBackfillColumnsRequest, AlterTableDropColumnsRequest, AlterTransactionRequest, AnalyzeTableQueryPlanRequest, - CountTableRowsRequest, CreateTableIndexRequest, CreateTableTagRequest, - CreateTableVersionRequest, CreateTableVersionResponse, DeleteFromTableRequest, - DeleteTableTagRequest, DescribeTableIndexStatsRequest, DescribeTableRequest, - DescribeTableResponse, DescribeTableVersionRequest, DescribeTableVersionResponse, - DescribeTransactionRequest, DropTableIndexRequest, ExplainTableQueryPlanRequest, - GetTableStatsRequest, GetTableTagVersionRequest, InsertIntoTableRequest, - ListTableIndicesRequest, ListTableTagsRequest, ListTableVersionsRequest, - ListTableVersionsResponse, ListTablesRequest, MergeInsertIntoTableRequest, QueryTableRequest, - RefreshMaterializedViewRequest, RestoreTableRequest, UpdateTableRequest, - UpdateTableSchemaMetadataRequest, UpdateTableTagRequest, + CountTableRowsRequest, CreateMaterializedViewRequest, CreateTableIndexRequest, + CreateTableTagRequest, CreateTableVersionRequest, CreateTableVersionResponse, + DeleteFromTableRequest, DeleteTableTagRequest, DescribeTableIndexStatsRequest, + DescribeTableRequest, DescribeTableResponse, DescribeTableVersionRequest, + DescribeTableVersionResponse, DescribeTransactionRequest, DropTableIndexRequest, + ExplainTableQueryPlanRequest, GetTableStatsRequest, GetTableTagVersionRequest, + InsertIntoTableRequest, ListTableIndicesRequest, ListTableTagsRequest, + ListTableVersionsRequest, ListTableVersionsResponse, ListTablesRequest, + MergeInsertIntoTableRequest, QueryTableRequest, RefreshMaterializedViewRequest, + RestoreTableRequest, UpdateTableRequest, UpdateTableSchemaMetadataRequest, + UpdateTableTagRequest, }; use lance_namespace_impls::RestNamespaceBuilder; use lance_namespace_impls::{ConnectBuilder, RestAdapter, RestAdapterConfig, RestAdapterHandle}; @@ -685,6 +686,18 @@ impl PyDirectoryNamespace { pythonize(py, &response).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } + fn create_materialized_view<'py>( + &self, + py: Python<'py>, + request: &Bound<'_, PyAny>, + ) -> PyResult> { + let request: CreateMaterializedViewRequest = depythonize(request)?; + let response = crate::rt() + .block_on(Some(py), self.inner.create_materialized_view(request))? + .infer_error()?; + pythonize(py, &response).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) + } + // Table tag operations fn list_table_tags<'py>( @@ -1335,6 +1348,18 @@ impl PyRestNamespace { pythonize(py, &response).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) } + fn create_materialized_view<'py>( + &self, + py: Python<'py>, + request: &Bound<'_, PyAny>, + ) -> PyResult> { + let request: CreateMaterializedViewRequest = depythonize(request)?; + let response = crate::rt() + .block_on(Some(py), self.inner.create_materialized_view(request))? + .infer_error()?; + pythonize(py, &response).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) + } + // Table tag operations fn list_table_tags<'py>( diff --git a/python/src/tracing.rs b/python/src/tracing.rs index fdeb3c49220..3aadae60060 100644 --- a/python/src/tracing.rs +++ b/python/src/tracing.rs @@ -219,6 +219,11 @@ impl Visit for EventToMap { } } +fn trace_event_log_target(target: &str) -> String { + let target = target.strip_prefix("lance::").unwrap_or(target); + format!("lance::events::{target}") +} + #[derive(Clone)] pub struct LoggingPassthroughRef(Arc>>); @@ -248,7 +253,14 @@ impl tracing_subscriber::Layer for LoggingPassthroughRef { let mut fields = EventToStr::default(); event.record(&mut fields); - log::log!(target: "lance::events", state.level, "target=\"{}\" {}", event.metadata().target(), fields.str); + let log_target = trace_event_log_target(event.metadata().target()); + log::log!( + target: &log_target, + state.level, + "target=\"{}\" {}", + event.metadata().target(), + fields.str + ); if let Some(callback_sender) = state.callback_sender.as_ref() { let mut args = EventToMap::default(); diff --git a/rust/lance-core/Cargo.toml b/rust/lance-core/Cargo.toml index 9dff4b001a4..ccab121b43f 100644 --- a/rust/lance-core/Cargo.toml +++ b/rust/lance-core/Cargo.toml @@ -48,11 +48,16 @@ log.workspace = true libc = { version = "0.2" } [dev-dependencies] +criterion.workspace = true proptest.workspace = true rstest.workspace = true [features] datafusion = ["dep:datafusion-common", "dep:datafusion-sql"] +[[bench]] +name = "row_addr_mask" +harness = false + [lints] workspace = true diff --git a/rust/lance-core/benches/row_addr_mask.rs b/rust/lance-core/benches/row_addr_mask.rs new file mode 100644 index 00000000000..c1f09484b69 --- /dev/null +++ b/rust/lance-core/benches/row_addr_mask.rs @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmarks for `RowAddrMask` / `RowAddrTreeMap`. +//! +//! These benchmarks are deliberately structured to expose the row-cardinality +//! scaling weakness of the current per-row bitmap representation. Producers +//! (e.g. scalar-index `search` implementations) and consumers (e.g. +//! `mask_to_offset_ranges`) are frequently range-shaped, but every operation +//! must round-trip through `Partial(RoaringBitmap)` and therefore costs O(N) +//! in the number of rows, not O(R) in the number of distinct ranges. +//! +//! Each benchmark varies the number of rows while keeping the number of +//! ranges fixed at 1. A range-aware representation should make these +//! near-constant time; today they are linear in N. +//! +//! Run with `cargo bench -p lance-core --bench row_addr_mask`. + +use std::ops::Range; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; + +/// Row counts we sweep across. Chosen to cover the realistic range of +/// matches a zonemap produces for an `IS NULL`-like predicate on a single +/// fragment: a few thousand rows up through tens of millions. +const ROW_COUNTS: &[u64] = &[10_000, 100_000, 1_000_000, 10_000_000]; + +fn make_range_mask(num_rows: u64) -> RowAddrTreeMap { + // Build a mask covering a single contiguous run in fragment 0. + // This is the exact shape a scalar-index search produces when it + // determines a contiguous chunk of zones matches. + let mut map = RowAddrTreeMap::new(); + map.insert_range(0..num_rows); + map +} + +/// Producer cost: building a mask from one contiguous Range. +/// +/// Today this is O(N) — every bit gets inserted into a roaring bitmap. +/// With a range-aware representation it would be O(1) (push a single run). +fn bench_insert_range(c: &mut Criterion) { + let mut group = c.benchmark_group("insert_range_single_run"); + for &n in ROW_COUNTS { + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &n| { + b.iter(|| { + let mut map = RowAddrTreeMap::new(); + map.insert_range(0..n); + std::hint::black_box(map); + }); + }); + } + group.finish(); +} + +/// Consumer cost: iterating every row address in a dense mask. +/// +/// `into_addr_iter` walks set bits one at a time. For a contiguous run +/// of N rows this is O(N) — even though the rows are trivially +/// representable as a single Range. This is what `mask_to_offset_ranges` +/// does after intersecting with a source segment: it pays per-row +/// iteration cost only to immediately collapse the addresses back into +/// ranges via `GroupingIterator`. +fn bench_iter_addrs(c: &mut Criterion) { + let mut group = c.benchmark_group("into_addr_iter_single_run"); + for &n in ROW_COUNTS { + let map = make_range_mask(n); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + // SAFETY: the map only contains Partial selections; no Full entries. + let count: u64 = unsafe { map.clone().into_addr_iter() }.count() as u64; + std::hint::black_box(count); + }); + }); + } + group.finish(); +} + +/// Best-achievable iteration over the same data. +/// +/// `Iter::next_range` walks the bitmap's run containers in O(num_runs). +/// For a single contiguous run this should be ~constant time — the +/// public `RowAddrMask` API gives no way to surface that today, so the +/// performance is currently inaccessible to callers. Comparing this to +/// `into_addr_iter_single_run` quantifies the speedup a range-aware +/// representation could deliver to consumers. +fn bench_iter_runs(c: &mut Criterion) { + let mut group = c.benchmark_group("next_range_iter_single_run"); + for &n in ROW_COUNTS { + // Use the same underlying roaring bitmap shape that `make_range_mask` + // produces internally (one fragment, one contiguous run). + let mut bitmap = roaring::RoaringBitmap::new(); + bitmap.insert_range(0..(n as u32)); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + let mut iter = bitmap.iter(); + let mut runs: u64 = 0; + while iter.next_range().is_some() { + runs += 1; + } + std::hint::black_box(runs); + }); + }); + } + group.finish(); +} + +/// Range-aware consumer cost via the public `RowAddrTreeMap::iter_runs` +/// API. The map is built the ordinary way (`insert_range` → Partial +/// bitmap); `iter_runs` walks the bitmap's run containers via +/// `Iter::next_range`. Compare against `into_addr_iter_single_run` to see +/// the consumer-side speedup callers get without changing the underlying +/// representation. +fn bench_iter_runs_partial(c: &mut Criterion) { + let mut group = c.benchmark_group("iter_runs_partial_single_run"); + for &n in ROW_COUNTS { + let map = make_range_mask(n); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + // SAFETY: map only contains Partial selections. + let mut runs: u64 = 0; + for _ in unsafe { map.iter_runs() } { + runs += 1; + } + std::hint::black_box(runs); + }); + }); + } + group.finish(); +} + +/// Set intersection of two range-shaped masks. +/// +/// Both inputs are single contiguous runs that overlap in their middle +/// half (so the output is itself a single contiguous run). With per-row +/// bitmaps this is O(N) — the entire bitmap participates in the AND. +/// With ranges it would be O(1). +fn bench_intersect_ranges(c: &mut Criterion) { + let mut group = c.benchmark_group("intersect_two_runs"); + for &n in ROW_COUNTS { + let lhs = make_range_mask(n); + let rhs_range = (n / 4)..(3 * n / 4); + let mut rhs = RowAddrTreeMap::new(); + rhs.insert_range(rhs_range); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + let mut tmp = lhs.clone(); + tmp &= &rhs; + std::hint::black_box(tmp); + }); + }); + } + group.finish(); +} + +/// Full round trip: build a source range bitmap, AND with a mask, iterate +/// each surviving bit. This is the exact slow path of +/// `mask_to_offset_ranges` in `lance-table/src/rowids.rs:387`. Profiling +/// a 10M-row zonemap `IS NULL` query showed this consuming ~55% of the +/// hot-loop time (~495 ms of 889 ms). The benchmark separates the +/// per-row producer/consumer cost from the rest of the scan pipeline so +/// it can be tracked in isolation. +fn bench_range_to_ranges_round_trip(c: &mut Criterion) { + let mut group = c.benchmark_group("mask_to_offset_ranges_inner_loop"); + for &n in ROW_COUNTS { + // The mask selects the back half of a 2N-row fragment. + let mask_range = n..(2 * n); + let mask = RowAddrMask::AllowList(RowAddrTreeMap::from(mask_range)); + // The source segment covers the whole fragment. + let src: Range = 0..(2 * n); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + // Mimic the slow path: materialize source range, AND with mask, + // iterate to count survivors (a stand-in for whatever the + // consumer actually does — e.g. GroupingIterator). + let mut ids = RowAddrTreeMap::from(src.clone()); + ids.mask(&mask); + let count = unsafe { ids.into_addr_iter() }.count(); + std::hint::black_box(count); + }); + }); + } + group.finish(); +} + +/// Same end-to-end shape as `mask_to_offset_ranges_inner_loop`, but the +/// final per-bit walk is replaced by `iter_runs`. Quantifies the speedup +/// the consumer side gets purely from switching iteration APIs — no +/// representation change. +fn bench_range_to_ranges_round_trip_runs(c: &mut Criterion) { + let mut group = c.benchmark_group("mask_to_offset_ranges_inner_loop_runs"); + for &n in ROW_COUNTS { + let mask_range = n..(2 * n); + let mask = RowAddrMask::AllowList(RowAddrTreeMap::from(mask_range)); + let src: Range = 0..(2 * n); + group.throughput(Throughput::Elements(n)); + group.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, _| { + b.iter(|| { + let mut ids = RowAddrTreeMap::from(src.clone()); + ids.mask(&mask); + // SAFETY: only Partial selections in play. + let count: u64 = unsafe { ids.iter_runs() } + .map(|(_, r)| (*r.end() as u64) - (*r.start() as u64) + 1) + .sum(); + std::hint::black_box(count); + }); + }); + } + group.finish(); +} + +/// Many small runs vs one big run with the same total cardinality. +/// +/// A range-aware representation should be O(num_runs), so the +/// `single_run` case should be ~K times faster than the `K_runs` case. +/// Today they are essentially equal: the cost is dictated by the number +/// of rows, not the number of runs. +fn bench_runs_vs_rows(c: &mut Criterion) { + let total_rows: u64 = 1_000_000; + let mut group = c.benchmark_group("insert_runs_constant_cardinality"); + + group.throughput(Throughput::Elements(total_rows)); + group.bench_function("single_run_1M", |b| { + b.iter(|| { + let mut map = RowAddrTreeMap::new(); + map.insert_range(0..total_rows); + std::hint::black_box(map); + }); + }); + + for k in [10u64, 100, 1_000, 10_000] { + let run_size = total_rows / k; + // Stride between runs is 2 * run_size so the bitmap is half full. + let stride = run_size * 2; + group.bench_function(format!("{k}_runs_1M_total"), |b| { + b.iter(|| { + let mut map = RowAddrTreeMap::new(); + for i in 0..k { + let start = i * stride; + map.insert_range(start..(start + run_size)); + } + std::hint::black_box(map); + }); + }); + } + group.finish(); +} + +criterion_group!( + benches, + bench_insert_range, + bench_iter_addrs, + bench_iter_runs, + bench_iter_runs_partial, + bench_intersect_ranges, + bench_range_to_ranges_round_trip, + bench_range_to_ranges_round_trip_runs, + bench_runs_vs_rows, +); +criterion_main!(benches); diff --git a/rust/lance-core/src/utils/mask.rs b/rust/lance-core/src/utils/mask.rs index 0ee1b5d17fa..b904a0e3748 100644 --- a/rust/lance-core/src/utils/mask.rs +++ b/rust/lance-core/src/utils/mask.rs @@ -13,6 +13,7 @@ use deepsize::DeepSizeOf; use itertools::Itertools; use roaring::{MultiOps, RoaringBitmap, RoaringTreemap}; +use crate::cache::CacheCodecImpl; use crate::{Error, Result}; use super::address::RowAddress; @@ -659,6 +660,39 @@ impl RowAddrTreeMap { }), }) } + + /// Iterate the selected row addresses as `(fragment_id, run)` pairs. + /// + /// Range-shaped counterpart to [`Self::into_addr_iter`]. A contiguous + /// run of N selected rows yields one item, not N. Uses roaring's + /// `Iter::next_range`, which walks each underlying container's runs + /// rather than its individual bits, so dense ranges cost + /// O(num_containers) (roughly num_rows / 65536) instead of O(num_rows). + /// + /// # Safety + /// Same contract as [`Self::into_addr_iter`]: panics if any entry is + /// `Full`, since the fragment size is unknown at this layer. + pub unsafe fn iter_runs(&self) -> impl Iterator)> + '_ { + self.inner + .iter() + .flat_map(|(&fragment, selection)| match selection { + RowAddrSelection::Full => panic!("Size of full fragment is unknown"), + RowAddrSelection::Partial(bitmap) => { + let mut iter = bitmap.iter(); + std::iter::from_fn(move || iter.next_range().map(|r| (fragment, r))) + } + }) + } +} + +impl CacheCodecImpl for RowAddrTreeMap { + fn serialize(&self, writer: &mut dyn Write) -> Result<()> { + self.serialize_into(writer) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + Self::deserialize_from(data.as_ref()) + } } impl std::ops::BitOr for RowAddrTreeMap { @@ -1555,6 +1589,38 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_map_iter_runs() { + // Three contiguous regions across two fragments. + let mut mask = RowAddrTreeMap::default(); + mask.insert_range(0..3); + mask.insert_range(10..15); + mask.insert_range((1u64 << 32) + 100..(1u64 << 32) + 103); + + // SAFETY: only Partial entries. + let runs: Vec<(u32, RangeInclusive)> = unsafe { mask.iter_runs().collect() }; + assert_eq!(runs, vec![(0, 0..=2), (0, 10..=14), (1, 100..=102)]); + } + + #[test] + fn test_map_iter_runs_matches_into_addr_iter() { + // Confirm iter_runs and into_addr_iter agree on a non-trivial shape. + let mut mask = RowAddrTreeMap::default(); + mask.insert_range(5..7); + mask.insert_range(11..12); + mask.insert_range(20..25); + mask.insert_range((1u64 << 32)..(1u64 << 32) + 3); + + let from_runs: Vec = unsafe { mask.iter_runs() } + .flat_map(|(frag, run)| { + let frag = u64::from(frag); + (*run.start()..=*run.end()).map(move |v| (frag << 32) | u64::from(v)) + }) + .collect(); + let from_bits: Vec = unsafe { mask.clone().into_addr_iter() }.collect(); + assert_eq!(from_runs, from_bits); + } + #[test] fn test_map_from() { let map = RowAddrTreeMap::from(10..12); diff --git a/rust/lance-core/src/utils/tracing.rs b/rust/lance-core/src/utils/tracing.rs index e126fade06d..603a666e313 100644 --- a/rust/lance-core/src/utils/tracing.rs +++ b/rust/lance-core/src/utils/tracing.rs @@ -84,3 +84,4 @@ pub const DATASET_DELETING_EVENT: &str = "deleting"; pub const DATASET_COMPACTING_EVENT: &str = "compacting"; pub const DATASET_CLEANING_EVENT: &str = "cleaning"; pub const DATASET_LOADING_EVENT: &str = "loading"; +pub const TRACE_OBJECT_STORE_THROTTLE: &str = "lance::object_store::throttle"; diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 45027cc7b63..76d387f92b7 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -19,10 +19,14 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{StreamExt, TryStreamExt, stream}; +use lance_arrow::ipc::{ + read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream, + write_len_prefixed_bytes, +}; use lance_core::utils::mask::RowSetOps; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheKey, LanceCache, WeakLanceCache}, + cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ mask::{NullableRowAddrSet, RowAddrTreeMap}, @@ -145,6 +149,151 @@ impl CacheKey for BitmapKey { fn type_name() -> &'static str { "Bitmap" } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } +} + +/// The serializable state of a [`BitmapIndex`]. +/// +/// `BitmapIndex` holds non-serializable infrastructure (an `IndexStore`, a +/// cache handle, a lazy reader, a fragment-reuse index). `BitmapIndexState` +/// captures just the data needed to rebuild it: the value→file-offset map, +/// the null bitmap, and the value type. +#[derive(Debug, Clone)] +pub struct BitmapIndexState { + /// Value-to-row-offset lookup, encoded as an Arrow `RecordBatch` so we can + /// reuse the existing IPC utilities for zero-copy round trips. + /// + /// Schema: `keys: `, `offsets: UInt64`. Iteration order of + /// `index_map` is preserved on serialize and the `BTreeMap` resorts the + /// entries on deserialize, so the wire form does not need to be sorted. + lookup_batch: RecordBatch, + /// Already-remapped null bitmap (remapping is applied during load, so the + /// cached state matches the in-memory representation). + null_map: Arc, + /// Cached separately from the schema for the empty-index case where the + /// `lookup_batch` is empty but we still need to remember the column type. + value_type: DataType, +} + +impl DeepSizeOf for BitmapIndexState { + fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { + self.lookup_batch.get_array_memory_size() + self.null_map.deep_size_of_children(context) + } +} + +impl BitmapIndexState { + pub(crate) fn from_index(index: &BitmapIndex) -> Result { + Ok(Self { + lookup_batch: build_lookup_batch(&index.index_map, &index.value_type)?, + null_map: index.null_map.clone(), + value_type: index.value_type.clone(), + }) + } + + pub(crate) fn into_bitmap_index( + self, + store: Arc, + index_cache: &LanceCache, + frag_reuse_index: Option>, + ) -> Result> { + let index_map = parse_lookup_batch(&self.lookup_batch)?; + Ok(Arc::new(BitmapIndex::new( + index_map, + self.null_map, + self.value_type, + store, + WeakLanceCache::from(index_cache), + frag_reuse_index, + ))) + } +} + +fn build_lookup_batch( + index_map: &BTreeMap, + value_type: &DataType, +) -> Result { + let keys = if index_map.is_empty() { + arrow_array::new_empty_array(value_type) + } else { + ScalarValue::iter_to_array(index_map.keys().map(|k| k.0.clone()))? + }; + let offsets = Arc::new(UInt64Array::from_iter_values( + index_map.values().map(|v| *v as u64), + )); + let schema = Arc::new(Schema::new(vec![ + Field::new("keys", value_type.clone(), true), + Field::new("offsets", DataType::UInt64, false), + ])); + Ok(RecordBatch::try_new(schema, vec![keys, offsets])?) +} + +fn parse_lookup_batch(batch: &RecordBatch) -> Result> { + let keys = batch.column(0); + let offsets = batch + .column(1) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal("BitmapIndexState: expected UInt64 offsets column".to_string()) + })?; + let mut index_map = BTreeMap::new(); + for idx in 0..batch.num_rows() { + let value = OrderableScalarValue(ScalarValue::try_from_array(keys, idx)?); + index_map.insert(value, offsets.value(idx) as usize); + } + Ok(index_map) +} + +impl CacheCodecImpl for BitmapIndexState { + /// Wire format: + /// ```text + /// [u64 null_map_len][null_map bytes] + /// [arrow IPC stream: (keys: , offsets: UInt64)] + /// ``` + /// The value type is recovered from the IPC stream schema. + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + let mut null_bytes = Vec::with_capacity(self.null_map.serialized_size()); + self.null_map.serialize_into(&mut null_bytes)?; + write_len_prefixed_bytes(writer, &null_bytes)?; + write_ipc_stream(&self.lookup_batch, writer)?; + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + let mut offset = 0; + let null_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let null_map = Arc::new(RowAddrTreeMap::deserialize_from(null_bytes.as_ref())?); + let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + let value_type = lookup_batch.schema().field(0).data_type().clone(); + Ok(Self { + lookup_batch, + null_map, + value_type, + }) + } +} + +/// Cache key for a [`BitmapIndexState`]. The cache is already namespaced +/// per-index by the caller, so a constant key suffices. +struct BitmapIndexStateKey; + +impl CacheKey for BitmapIndexStateKey { + type ValueType = BitmapIndexState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "BitmapIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } } impl BitmapIndex { @@ -1542,6 +1691,34 @@ impl ScalarIndexPlugin for BitmapIndexPlugin { Ok(BitmapIndex::load(index_store, frag_reuse_index, cache).await? as Arc) } + async fn get_from_cache( + &self, + index_store: Arc, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + let Some(state) = cache.get_with_key(&BitmapIndexStateKey).await else { + return Ok(None); + }; + let state = (*state).clone(); + let index = state.into_bitmap_index(index_store, cache, frag_reuse_index)?; + Ok(Some(index as Arc)) + } + + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + let bitmap = index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal("BitmapIndexPlugin::put_in_cache called with a non-bitmap index") + })?; + let state = BitmapIndexState::from_index(bitmap)?; + cache + .insert_with_key(&BitmapIndexStateKey, Arc::new(state)) + .await; + Ok(()) + } + async fn load_statistics( &self, index_store: Arc, @@ -1589,6 +1766,41 @@ mod tests { use lance_io::object_store::ObjectStore; use std::collections::HashMap; + fn assert_state_roundtrips(state: &BitmapIndexState) { + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BitmapIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + assert_eq!(restored.lookup_batch, state.lookup_batch); + assert_eq!(&*restored.null_map, &*state.null_map); + assert_eq!(restored.value_type, state.value_type); + } + + #[test] + fn test_bitmap_index_state_codec_roundtrip() { + // Non-empty state with a few keys and a populated null map. + let mut index_map = BTreeMap::new(); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(1))), 0); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(7))), 1); + index_map.insert(OrderableScalarValue(ScalarValue::Int32(Some(42))), 2); + let mut null_map = RowAddrTreeMap::new(); + null_map.insert(RowAddress::new_from_parts(0, 3).into()); + null_map.insert(RowAddress::new_from_parts(0, 5).into()); + let state = BitmapIndexState { + lookup_batch: build_lookup_batch(&index_map, &DataType::Int32).unwrap(), + null_map: Arc::new(null_map), + value_type: DataType::Int32, + }; + assert_state_roundtrips(&state); + + // Empty state: no keys, empty null map. Schema still carries the type. + let empty_state = BitmapIndexState { + lookup_batch: build_lookup_batch(&BTreeMap::new(), &DataType::Utf8).unwrap(), + null_map: Arc::new(RowAddrTreeMap::new()), + value_type: DataType::Utf8, + }; + assert_state_roundtrips(&empty_state); + } + #[tokio::test] async fn test_bitmap_lazy_loading_and_cache() { // Create a temporary directory for the index diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index f2be3241b85..eba6f1c6205 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -44,9 +44,10 @@ use futures::{ future::BoxFuture, stream::{self}, }; +use lance_arrow::ipc::{read_ipc_stream_single_at, write_ipc_stream}; use lance_core::{ Error, ROW_ID, Result, - cache::{CacheKey, LanceCache, WeakLanceCache}, + cache::{CacheCodec, CacheCodecImpl, CacheKey, LanceCache, WeakLanceCache}, error::LanceOptionExt, utils::{ mask::NullableRowAddrSet, @@ -998,6 +999,173 @@ impl CacheKey for BTreePageKey { fn type_name() -> &'static str { "BTreePage" } + + fn codec() -> Option { + // Pages are cached as `FlatIndex` values (see `ValueType` above). + Some(CacheCodec::from_impl::()) + } +} + +/// The serializable state of a [`BTreeIndex`]. +/// +/// A `BTreeIndex` holds non-serializable infrastructure (an `IndexStore`, a +/// cache handle, a fragment-reuse index). `BTreeIndexState` captures just the +/// data needed to rebuild it: the `page_lookup.lance` batch (from which +/// `BTreeIndex::try_from_serialized` reconstructs the in-memory lookup with +/// no IO) plus the page batch size and range-partition map. +#[derive(Debug, Clone)] +pub struct BTreeIndexState { + lookup_batch: RecordBatch, + batch_size: u64, + ranges_to_files: Option>>, +} + +impl DeepSizeOf for BTreeIndexState { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + // `ranges_to_files` is tiny and `RangeInclusiveMap` is not `DeepSizeOf`; + // the lookup batch dominates, matching how `BTreeIndex` accounts for itself. + self.lookup_batch.get_array_memory_size() + } +} + +impl BTreeIndexState { + fn reconstruct( + &self, + store: Arc, + index_cache: &LanceCache, + frag_reuse_index: Option>, + ) -> Result> { + let index = BTreeIndex::try_from_serialized( + self.lookup_batch.clone(), + store, + index_cache, + self.batch_size, + self.ranges_to_files.clone(), + frag_reuse_index, + )?; + Ok(Arc::new(index) as Arc) + } +} + +impl CacheCodecImpl for BTreeIndexState { + /// Wire format (no stability guarantees yet — the cache is rebuilt from + /// source on any version mismatch): + /// ```text + /// u64 batch_size (LE) + /// u8 has_ranges (0 = None, 1 = Some) + /// if has_ranges: + /// u32 entry_count (LE) + /// per entry: u32 start | u32 end | u32 offset | u32 path_len | path bytes + /// lookup batch (Arrow IPC stream) + /// ``` + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + writer.write_all(&self.batch_size.to_le_bytes())?; + match &self.ranges_to_files { + None => writer.write_all(&[0u8])?, + Some(ranges) => { + writer.write_all(&[1u8])?; + let count = u32::try_from(ranges.len()).map_err(|_| { + Error::io("BTreeIndexState: ranges_to_files exceeds u32::MAX entries") + })?; + writer.write_all(&count.to_le_bytes())?; + for (range, (path, page_offset)) in ranges.iter() { + writer.write_all(&range.start().to_le_bytes())?; + writer.write_all(&range.end().to_le_bytes())?; + writer.write_all(&page_offset.to_le_bytes())?; + let path_len = u32::try_from(path.len()).map_err(|_| { + Error::io("BTreeIndexState: ranges_to_files path exceeds u32::MAX bytes") + })?; + writer.write_all(&path_len.to_le_bytes())?; + writer.write_all(path.as_bytes())?; + } + } + } + write_ipc_stream(&self.lookup_batch, writer)?; + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result { + let mut offset = 0; + let batch_size = read_u64_le(data, &mut offset)?; + let has_ranges = read_u8(data, &mut offset)?; + let ranges_to_files = match has_ranges { + 0 => None, + 1 => { + let count = read_u32_le(data, &mut offset)? as usize; + let mut entries = Vec::with_capacity(count); + for _ in 0..count { + let start = read_u32_le(data, &mut offset)?; + let end = read_u32_le(data, &mut offset)?; + let page_offset = read_u32_le(data, &mut offset)?; + let path_len = read_u32_le(data, &mut offset)? as usize; + let path = read_bytes(data, &mut offset, path_len)?; + let path = std::str::from_utf8(&path) + .map_err(|e| Error::io(format!("BTreeIndexState path: {e}")))? + .to_string(); + entries.push((start..=end, (path, page_offset))); + } + Some(Arc::new(entries.into_iter().collect())) + } + other => { + return Err(Error::io(format!( + "BTreeIndexState: invalid has_ranges tag {other}" + ))); + } + }; + let lookup_batch = read_ipc_stream_single_at(data, &mut offset)?; + Ok(Self { + lookup_batch, + batch_size, + ranges_to_files, + }) + } +} + +fn read_bytes(data: &bytes::Bytes, offset: &mut usize, len: usize) -> Result { + if data.len() < *offset + len { + return Err(Error::io(format!( + "BTreeIndexState: short read of {len} bytes at offset {offset} (have {})", + data.len() + ))); + } + let slice = data.slice(*offset..*offset + len); + *offset += len; + Ok(slice) +} + +fn read_u8(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = read_bytes(data, offset, 1)?; + Ok(bytes[0]) +} + +fn read_u32_le(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = read_bytes(data, offset, 4)?; + Ok(u32::from_le_bytes(bytes.as_ref().try_into().unwrap())) +} + +fn read_u64_le(data: &bytes::Bytes, offset: &mut usize) -> Result { + let bytes = read_bytes(data, offset, 8)?; + Ok(u64::from_le_bytes(bytes.as_ref().try_into().unwrap())) +} + +/// Cache key for a [`BTreeIndexState`]. The cache it is used with is already +/// namespaced per-index, so the key string is a constant. +struct BTreeIndexStateKey; + +impl CacheKey for BTreeIndexStateKey { + type ValueType = BTreeIndexState; + + fn key(&self) -> std::borrow::Cow<'_, str> { + "state".into() + } + + fn type_name() -> &'static str { + "BTreeIndexState" + } + + fn codec() -> Option { + Some(CacheCodec::from_impl::()) + } } /// Note: this is very similar to the IVF index except we store the IVF part in a btree @@ -1040,13 +1208,26 @@ pub struct BTreeIndex { /// - The system now knows to read page `42` from the file `part_2_page_file.lance`. ranges_to_files: Option>>, frag_reuse_index: Option>, + + /// The raw lookup batch this index was built from (the contents of + /// `page_lookup.lance`). Retained so the index can be serialized into a + /// cache as a [`BTreeIndexState`] without re-reading it from storage. + /// + /// TODO: this duplicates the min/max values already held in `page_lookup`. + /// A follow-up could rewrite `BTreeLookup` to query this batch directly + /// (binary search on the sorted `min` column + linear scan, type-dispatched + /// per column type), eliminating the duplication and making this batch the + /// single source of truth. + lookup_batch: RecordBatch, } impl DeepSizeOf for BTreeIndex { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { // We don't include the index cache, or anything stored in it. For example: // sub_index and fri. - self.page_lookup.deep_size_of_children(context) + self.store.deep_size_of_children(context) + self.page_lookup.deep_size_of_children(context) + + self.store.deep_size_of_children(context) + + self.lookup_batch.get_array_memory_size() } } @@ -1060,6 +1241,7 @@ impl BTreeIndex { batch_size: u64, ranges_to_files: Option>>, frag_reuse_index: Option>, + lookup_batch: RecordBatch, ) -> Self { Self { page_lookup, @@ -1069,6 +1251,7 @@ impl BTreeIndex { batch_size, ranges_to_files, frag_reuse_index, + lookup_batch, } } @@ -1158,6 +1341,7 @@ impl BTreeIndex { batch_size, ranges_to_files, frag_reuse_index, + data, )); } @@ -1199,18 +1383,19 @@ impl BTreeIndex { let last_max = ScalarValue::try_from_array(&maxs, data.num_rows() - 1)?; map.entry(OrderableScalarValue(last_max)).or_default(); - let data_type = mins.data_type(); + let data_type = mins.data_type().clone(); let page_lookup = Arc::new(BTreeLookup::new(map, null_pages, all_null_pages)); Ok(Self::new( page_lookup, store, - data_type.clone(), + data_type, WeakLanceCache::from(index_cache), batch_size, ranges_to_files, frag_reuse_index, + data, )) } @@ -2753,6 +2938,37 @@ impl ScalarIndexPlugin for BTreeIndexPlugin { ) -> Result> { Ok(BTreeIndex::load(index_store, frag_reuse_index, cache).await? as Arc) } + + async fn get_from_cache( + &self, + index_store: Arc, + frag_reuse_index: Option>, + cache: &LanceCache, + ) -> Result>> { + let Some(state) = cache.get_with_key(&BTreeIndexStateKey).await else { + return Ok(None); + }; + Ok(Some(state.reconstruct( + index_store, + cache, + frag_reuse_index, + )?)) + } + + async fn put_in_cache(&self, cache: &LanceCache, index: Arc) -> Result<()> { + let btree = index.as_any().downcast_ref::().ok_or_else(|| { + Error::internal("BTreeIndexPlugin::put_in_cache called with a non-BTree index") + })?; + let state = BTreeIndexState { + lookup_batch: btree.lookup_batch.clone(), + batch_size: btree.batch_size, + ranges_to_files: btree.ranges_to_files.clone(), + }; + cache + .insert_with_key(&BTreeIndexStateKey, Arc::new(state)) + .await; + Ok(()) + } } #[cfg(test)] @@ -2791,9 +3007,13 @@ mod tests { }; use super::{ - DEFAULT_BTREE_BATCH_SIZE, OrderableScalarValue, part_lookup_file_path, - part_page_data_file_path, train_btree_index, + BTreeIndexPlugin, BTreeIndexState, BTreePageKey, DEFAULT_BTREE_BATCH_SIZE, + OrderableScalarValue, part_lookup_file_path, part_page_data_file_path, train_btree_index, }; + use crate::scalar::registry::ScalarIndexPlugin; + use arrow_array::RecordBatch; + use lance_core::cache::{CacheCodecImpl, CacheKey}; + use rangemap::RangeInclusiveMap; lance_testing::define_stage_event_progress!( RecordingProgress, @@ -4769,4 +4989,288 @@ mod tests { _ => panic!("BTree search should return Exact"), } } + + fn sample_lookup_batch() -> RecordBatch { + record_batch!( + ("min", Int32, [Some(0), Some(10), Some(20)]), + ("max", Int32, [Some(9), Some(19), Some(29)]), + ("null_count", UInt32, [0, 2, 0]), + ("page_idx", UInt32, [0, 1, 2]) + ) + .unwrap() + } + + fn assert_state_roundtrips(state: &BTreeIndexState) { + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + assert_eq!(restored.lookup_batch, state.lookup_batch); + assert_eq!(restored.batch_size, state.batch_size); + assert_eq!(restored.ranges_to_files, state.ranges_to_files); + } + + #[test] + fn test_btree_page_key_codec() { + // FlatIndex pages can be serialized by a persistent cache backend. + assert!(BTreePageKey::codec().is_some()); + } + + #[test] + fn test_btree_index_state_roundtrip() { + // Not range-partitioned. + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: sample_lookup_batch(), + batch_size: DEFAULT_BTREE_BATCH_SIZE, + ranges_to_files: None, + }); + + // Range-partitioned across multiple files. + let ranges: RangeInclusiveMap = [ + (0..=99, ("part_0_page_file.lance".to_string(), 0)), + (100..=199, ("part_1_page_file.lance".to_string(), 100)), + ] + .into_iter() + .collect(); + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: sample_lookup_batch(), + batch_size: 8192, + ranges_to_files: Some(Arc::new(ranges)), + }); + + // Empty index. + assert_state_roundtrips(&BTreeIndexState { + lookup_batch: RecordBatch::new_empty(sample_lookup_batch().schema()), + batch_size: DEFAULT_BTREE_BATCH_SIZE, + ranges_to_files: None, + }); + } + + #[tokio::test] + async fn test_btree_index_state_reconstruct_and_plugin_cache() { + let tmpdir = TempObjDir::default(); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let stream = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(1000), BatchCount::from(5)); + train_btree_index(stream, test_store.as_ref(), 1000, None, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + + // Round-trip the state through the codec and reconstruct an index from it. + let state = BTreeIndexState { + lookup_batch: index.lookup_batch.clone(), + batch_size: index.batch_size, + ranges_to_files: index.ranges_to_files.clone(), + }; + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + let restored = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap(); + let reconstructed = restored + .reconstruct(test_store.clone(), &LanceCache::no_cache(), None) + .unwrap(); + assert_eq!( + reconstructed + .as_any() + .downcast_ref::() + .unwrap() + .page_lookup, + index.page_lookup + ); + + // The plugin's put/get hooks round-trip through a real cache + the codec. + let cache = LanceCache::with_capacity(64 * 1024 * 1024); + let plugin = BTreeIndexPlugin; + plugin.put_in_cache(&cache, index.clone()).await.unwrap(); + let from_cache = plugin + .get_from_cache(test_store.clone(), None, &cache) + .await + .unwrap() + .expect("index should be served from the cache"); + + // Searches against the cached index match the original. + let query = SargableQuery::Range( + std::ops::Bound::Included(ScalarValue::Int32(Some(100))), + std::ops::Bound::Excluded(ScalarValue::Int32(Some(200))), + ); + let expected = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + let actual = from_cache + .search(&query, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(expected, actual); + } + + #[test] + fn test_btree_index_state_rejects_invalid_has_ranges_tag() { + // u64 batch_size (any) then a bad has_ranges tag. + let mut buf = Vec::new(); + buf.extend_from_slice(&1000u64.to_le_bytes()); + buf.push(7u8); + let err = BTreeIndexState::deserialize(&bytes::Bytes::from(buf)).unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("has_ranges") && msg.contains("7"), + "expected error to mention the bad has_ranges tag, got: {msg}" + ); + } + + #[tokio::test] + async fn test_btree_index_state_reconstruct_applies_frag_reuse_index() { + use crate::frag_reuse::{FragReuseIndex, FragReuseIndexDetails}; + use std::collections::HashMap; + use uuid::Uuid; + + let tmpdir = TempObjDir::default(); + let test_store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // value == _rowid for all rows in [0, 1000). + let stream = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(1000), BatchCount::from(1)); + train_btree_index(stream, test_store.as_ref(), 1000, None, None) + .await + .unwrap(); + + let index = BTreeIndex::load(test_store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + let state = BTreeIndexState { + lookup_batch: index.lookup_batch.clone(), + batch_size: index.batch_size, + ranges_to_files: index.ranges_to_files.clone(), + }; + + // Remap row 0 -> row 5000 (outside the original [0, 1000) range so no collision). + // Querying for value == 0 should now return row 5000, confirming reconstruct threaded + // the FragReuseIndex through to the rebuilt BTreeIndex. + let frag_reuse_index = Arc::new(FragReuseIndex::new( + Uuid::new_v4(), + vec![HashMap::from([(0u64, Some(5000u64))])], + FragReuseIndexDetails { versions: vec![] }, + )); + let reconstructed = state + .reconstruct( + test_store.clone(), + &LanceCache::no_cache(), + Some(frag_reuse_index), + ) + .unwrap(); + + let result = reconstructed + .search( + &SargableQuery::Equals(ScalarValue::Int32(Some(0))), + &NoOpMetricsCollector, + ) + .await + .unwrap(); + let row_ids: Vec = match &result { + SearchResult::Exact(set) => set + .true_rows() + .row_addrs() + .unwrap() + .map(u64::from) + .collect(), + other => panic!("expected Exact, got {other:?}"), + }; + assert_eq!( + row_ids, + vec![5000], + "frag_reuse_index remap was not applied" + ); + } + + #[tokio::test] + async fn test_btree_index_state_range_partitioned_plugin_cache_roundtrip() { + // Build a range-partitioned BTree (two range partitions merged into one index) and + // round-trip it through the plugin's cache hooks. This exercises the + // `ranges_to_files = Some` path end-to-end through serialize/deserialize/reconstruct. + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + Arc::new(ObjectStore::local()), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let half = DEFAULT_BTREE_BATCH_SIZE; + let total = (2 * half) as i32; + + // Partition 0: values/rowids [0, half). + let part0 = gen_batch() + .col("value", array::step::()) + .col("_rowid", array::step::()) + .into_df_stream(RowCount::from(half), BatchCount::from(1)); + train_btree_index(part0, store.as_ref(), half, None, Some(0u32)) + .await + .unwrap(); + + // Partition 1: values/rowids [half, 2*half). + let values: Vec = (half as i32..total).collect(); + let row_ids: Vec = (half..total as u64).collect(); + let part1 = gen_batch() + .col("value", array::cycle::(values)) + .col("_rowid", array::cycle::(row_ids)) + .into_df_stream(RowCount::from(half), BatchCount::from(1)); + train_btree_index(part1, store.as_ref(), half, None, Some(1u32)) + .await + .unwrap(); + + super::merge_metadata_files( + store.as_ref(), + &[ + part_page_data_file_path(0 << 32), + part_page_data_file_path(1 << 32), + ], + &[ + part_lookup_file_path(0 << 32), + part_lookup_file_path(1 << 32), + ], + Some(1usize), + noop_progress(), + ) + .await + .unwrap(); + + let index = BTreeIndex::load(store.clone(), None, &LanceCache::no_cache()) + .await + .unwrap(); + assert!( + index.ranges_to_files.is_some(), + "test setup should produce a range-partitioned index", + ); + + let cache = LanceCache::with_capacity(64 * 1024 * 1024); + let plugin = BTreeIndexPlugin; + plugin.put_in_cache(&cache, index.clone()).await.unwrap(); + let from_cache = plugin + .get_from_cache(store.clone(), None, &cache) + .await + .unwrap() + .expect("index should be served from the cache"); + + // Search a value from each range partition and confirm both paths agree. + for value in [0i32, total - 1] { + let query = SargableQuery::Equals(ScalarValue::Int32(Some(value))); + let expected = index.search(&query, &NoOpMetricsCollector).await.unwrap(); + let actual = from_cache + .search(&query, &NoOpMetricsCollector) + .await + .unwrap(); + assert_eq!(expected, actual, "value {value}"); + } + } } diff --git a/rust/lance-index/src/scalar/btree/flat.rs b/rust/lance-index/src/scalar/btree/flat.rs index 113a850315b..10f0b1ad339 100644 --- a/rust/lance-index/src/scalar/btree/flat.rs +++ b/rust/lance-index/src/scalar/btree/flat.rs @@ -14,7 +14,9 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_physical_expr::create_physical_expr; use deepsize::DeepSizeOf; use lance_arrow::RecordBatchExt; +use lance_arrow::ipc::{read_ipc_stream_single_at, read_len_prefixed_bytes_at, write_ipc_stream}; use lance_core::Result; +use lance_core::cache::CacheCodecImpl; use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; use roaring::RoaringBitmap; @@ -233,6 +235,45 @@ impl FlatIndex { } } +impl CacheCodecImpl for FlatIndex { + fn serialize(&self, writer: &mut dyn std::io::Write) -> Result<()> { + // Format: + // [len-prefixed all_addrs_map][len-prefixed null_addrs_map][batch IPC stream] + writer.write_all(&(self.all_addrs_map.serialized_size() as u64).to_le_bytes())?; + self.all_addrs_map.serialize_into(&mut *writer)?; + + writer.write_all(&(self.null_addrs_map.serialized_size() as u64).to_le_bytes())?; + self.null_addrs_map.serialize_into(&mut *writer)?; + + write_ipc_stream(self.data.as_ref(), writer)?; + + Ok(()) + } + + fn deserialize(data: &bytes::Bytes) -> Result + where + Self: Sized, + { + let mut offset = 0; + let all_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let all_addrs_map = RowAddrTreeMap::deserialize_from(all_addrs_bytes.as_ref())?; + + let null_addrs_bytes = read_len_prefixed_bytes_at(data, &mut offset)?; + let null_addrs_map = RowAddrTreeMap::deserialize_from(null_addrs_bytes.as_ref())?; + + let batch = read_ipc_stream_single_at(data, &mut offset)?; + + let df_schema = DFSchema::try_from(batch.schema())?; + + Ok(Self { + data: Arc::new(batch), + all_addrs_map, + null_addrs_map, + df_schema, + }) + } +} + #[cfg(test)] mod tests { use crate::{ @@ -266,6 +307,34 @@ mod tests { assert_eq!(actual, expected); } + fn assert_roundtrips(index: &FlatIndex) { + let mut buf = Vec::new(); + index.serialize(&mut buf).unwrap(); + let restored = FlatIndex::deserialize(&bytes::Bytes::from(buf)).unwrap(); + + assert_eq!(restored.data, index.data); + assert_eq!(restored.all_addrs_map, index.all_addrs_map); + assert_eq!(restored.null_addrs_map, index.null_addrs_map); + } + + #[test] + fn test_cache_codec_roundtrip() { + // No nulls + assert_roundtrips(&example_index()); + + // With nulls in the values column + let batch = record_batch!( + (BTREE_VALUES_COLUMN, Int32, [None, Some(0), Some(5)]), + (BTREE_IDS_COLUMN, UInt64, [0, 1, 2]) + ) + .unwrap(); + assert_roundtrips(&FlatIndex::try_new(batch).unwrap()); + + // Empty index + let empty = RecordBatch::new_empty(example_index().data.schema()); + assert_roundtrips(&FlatIndex::try_new(empty).unwrap()); + } + #[tokio::test] async fn test_equality() { check_index(&SargableQuery::Equals(ScalarValue::from(100)), &[0]).await; diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 52e09864c14..d9c675b40ec 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -33,6 +33,7 @@ use arrow_array::{ use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::metrics::Time; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use deepsize::DeepSizeOf; use fst::{Automaton, IntoStreamer, Streamer}; @@ -4062,6 +4063,7 @@ async fn tokenize_and_count( tokenizer: Box, query_tokens: Arc, doc_col_idx: usize, + elapsed_compute: Option