diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java
index 2223cb9b3..3cb835d42 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/LanceSparkReadOptions.java
@@ -59,12 +59,48 @@ public class LanceSparkReadOptions implements Serializable {
public static final String CONFIG_TOP_N_PUSH_DOWN = "topN_push_down";
public static final String CONFIG_NEAREST = "nearest";
+
+ /**
+ * Whether executors should rebuild the namespace client and re-fetch storage options via {@code
+ * namespace.describeTable()} when opening a dataset for fragment scans.
+ *
+ *
When {@code true} (the default), executors reconstruct the namespace client and route the
+ * dataset open through the namespace path. This keeps the Rust-side storage-options provider
+ * attached so that short-lived vended credentials returned by {@code describeTable()} (e.g. STS
+ * tokens from Iceberg REST, Polaris, Unity) can be refreshed mid-scan.
+ *
+ *
When {@code false}, executors open the dataset directly by URI using the storage options the
+ * driver already obtained (passed in via {@code initialStorageOptions}). This skips the eager
+ * {@code describeTable()} RPC on every fragment scan, which is required for catalogs whose
+ * backing service authenticates per-call (e.g. Hive Metastore over Kerberos): executors typically
+ * do not have a Kerberos TGT and the call would otherwise fail with {@code GSS initiate failed}.
+ *
+ *
Whether disabling this option actually costs anything depends on the namespace impl:
+ *
+ *
+ *
{@code Hive2Namespace} / {@code Hive3Namespace}: {@code describeTable()} returns only the
+ * table location, never storage options. The refresh callback is a no-op, so setting this
+ * option to {@code false} has no downside. The underlying object-store credentials (e.g.
+ * IAM-role / {@code hive-site.xml} / env-vars on the executor) are rotated by the storage
+ * client SDK independently of Lance.
+ *
{@code GlueNamespace}: storage options come from a static {@code
+ * config.getStorageOptions()} and are typically not time-bound; setting {@code false} is
+ * usually safe unless you rely on LakeFormation-vended temporary credentials.
+ *
{@code IcebergNamespace} (REST), {@code PolarisNamespace}, {@code UnityNamespace}: {@code
+ * describeTable()} commonly returns vended temporary credentials. Leave this option at the
+ * default ({@code true}) unless every scan is guaranteed to finish within the credential
+ * TTL.
+ *
+ */
+ public static final String CONFIG_EXECUTOR_CREDENTIAL_REFRESH = "executor_credential_refresh";
+
public static final String LANCE_FILE_SUFFIX = ".lance";
private static final boolean DEFAULT_PUSH_DOWN_FILTERS = true;
// Changed from 512 to 8192 for better OLAP scan performance (33x improvement)
private static final int DEFAULT_BATCH_SIZE = 8192;
private static final boolean DEFAULT_TOP_N_PUSH_DOWN = true;
+ private static final boolean DEFAULT_EXECUTOR_CREDENTIAL_REFRESH = true;
private final String datasetUri;
private final String dbPath;
@@ -88,6 +124,12 @@ public class LanceSparkReadOptions implements Serializable {
/** The catalog name for cache isolation when multiple catalogs are configured. */
private final String catalogName;
+ /**
+ * Whether executors should rebuild the namespace client for credential refresh. See {@link
+ * #CONFIG_EXECUTOR_CREDENTIAL_REFRESH} for details.
+ */
+ private final boolean executorCredentialRefresh;
+
private LanceSparkReadOptions(Builder builder) {
this.datasetUri = builder.datasetUri;
String[] paths = extractDbPathAndDatasetName(datasetUri);
@@ -105,6 +147,7 @@ private LanceSparkReadOptions(Builder builder) {
this.namespace = builder.namespace;
this.tableId = builder.tableId;
this.catalogName = builder.catalogName;
+ this.executorCredentialRefresh = builder.executorCredentialRefresh;
}
/** Creates a new builder for LanceSparkReadOptions. */
@@ -239,6 +282,15 @@ public String getCatalogName() {
return catalogName;
}
+ /**
+ * Returns whether executors should rebuild the namespace client and route the dataset open
+ * through the namespace path (for credential refresh). See {@link
+ * #CONFIG_EXECUTOR_CREDENTIAL_REFRESH}.
+ */
+ public boolean isExecutorCredentialRefresh() {
+ return executorCredentialRefresh;
+ }
+
public boolean hasNamespace() {
return namespace != null && tableId != null;
}
@@ -275,6 +327,7 @@ public LanceSparkReadOptions withVersion(int newVersion) {
.namespace(this.namespace)
.tableId(this.tableId)
.catalogName(this.catalogName)
+ .executorCredentialRefresh(this.executorCredentialRefresh)
.build();
}
@@ -324,6 +377,7 @@ public boolean equals(Object o) {
return pushDownFilters == that.pushDownFilters
&& batchSize == that.batchSize
&& topNPushDown == that.topNPushDown
+ && executorCredentialRefresh == that.executorCredentialRefresh
&& Objects.equals(nearest, that.nearest)
&& Objects.equals(datasetUri, that.datasetUri)
&& Objects.equals(blockSize, that.blockSize)
@@ -347,7 +401,8 @@ public int hashCode() {
nearest,
topNPushDown,
storageOptions,
- tableId);
+ tableId,
+ executorCredentialRefresh);
}
/** Builder for creating LanceSparkReadOptions instances. */
@@ -365,6 +420,7 @@ public static class Builder {
private LanceNamespace namespace;
private List tableId;
private String catalogName;
+ private boolean executorCredentialRefresh = DEFAULT_EXECUTOR_CREDENTIAL_REFRESH;
private Builder() {}
@@ -442,6 +498,11 @@ public Builder catalogName(String catalogName) {
return this;
}
+ public Builder executorCredentialRefresh(boolean executorCredentialRefresh) {
+ this.executorCredentialRefresh = executorCredentialRefresh;
+ return this;
+ }
+
/**
* Parses options from a map, extracting read-specific settings.
*
@@ -450,39 +511,18 @@ public Builder catalogName(String catalogName) {
*/
public Builder fromOptions(Map options) {
this.storageOptions = new HashMap<>(options);
- if (options.containsKey(CONFIG_PUSH_DOWN_FILTERS)) {
- this.pushDownFilters = Boolean.parseBoolean(options.get(CONFIG_PUSH_DOWN_FILTERS));
- }
- if (options.containsKey(CONFIG_BLOCK_SIZE)) {
- this.blockSize = Integer.parseInt(options.get(CONFIG_BLOCK_SIZE));
- }
- if (options.containsKey(CONFIG_VERSION)) {
- this.version = Integer.parseInt(options.get(CONFIG_VERSION));
- }
- if (options.containsKey(CONFIG_INDEX_CACHE_SIZE)) {
- this.indexCacheSize = Integer.parseInt(options.get(CONFIG_INDEX_CACHE_SIZE));
- }
- if (options.containsKey(CONFIG_METADATA_CACHE_SIZE)) {
- this.metadataCacheSize = Integer.parseInt(options.get(CONFIG_METADATA_CACHE_SIZE));
- }
- if (options.containsKey(CONFIG_BATCH_SIZE)) {
- int parsedBatchSize = Integer.parseInt(options.get(CONFIG_BATCH_SIZE));
- Preconditions.checkArgument(parsedBatchSize > 0, "batch_size must be positive");
- this.batchSize = parsedBatchSize;
- }
- if (options.containsKey(CONFIG_TOP_N_PUSH_DOWN)) {
- this.topNPushDown = Boolean.parseBoolean(options.get(CONFIG_TOP_N_PUSH_DOWN));
- }
- if (options.containsKey(CONFIG_NEAREST)) {
- String json = options.get(CONFIG_NEAREST);
- nearest(json);
- }
+ parseTypedFlags(options);
return this;
}
/**
* Merges catalog config options as defaults (read options override).
*
+ *
Also promotes recognized typed flags from the catalog config into their corresponding
+ * Builder fields so that catalog-level settings (e.g. {@code spark.sql.catalog..})
+ * take effect on paths that do not later go through {@link #fromOptions(Map)} — notably SQL DML
+ * (DELETE / UPDATE / MERGE INTO) and plain SELECT without per-read {@code .option(...)}.
+ *
* @param catalogConfig the catalog config
* @return this builder
*/
@@ -490,8 +530,45 @@ public Builder withCatalogDefaults(LanceSparkCatalogConfig catalogConfig) {
// Merge storage options: catalog options are defaults, current options override
Map merged = new HashMap<>(catalogConfig.getStorageOptions());
merged.putAll(this.storageOptions);
- this.storageOptions = merged;
- return this;
+ return fromOptions(merged);
+ }
+
+ /**
+ * Applies typed-flag parsing for every known read option present in {@code opts}. Shared by
+ * {@link #fromOptions(Map)} and {@link #withCatalogDefaults(LanceSparkCatalogConfig)} so that
+ * both call sites stay in sync and catalog-level configs reach the typed fields.
+ */
+ private void parseTypedFlags(Map opts) {
+ if (opts.containsKey(CONFIG_PUSH_DOWN_FILTERS)) {
+ this.pushDownFilters = Boolean.parseBoolean(opts.get(CONFIG_PUSH_DOWN_FILTERS));
+ }
+ if (opts.containsKey(CONFIG_BLOCK_SIZE)) {
+ this.blockSize = Integer.parseInt(opts.get(CONFIG_BLOCK_SIZE));
+ }
+ if (opts.containsKey(CONFIG_VERSION)) {
+ this.version = Integer.parseInt(opts.get(CONFIG_VERSION));
+ }
+ if (opts.containsKey(CONFIG_INDEX_CACHE_SIZE)) {
+ this.indexCacheSize = Integer.parseInt(opts.get(CONFIG_INDEX_CACHE_SIZE));
+ }
+ if (opts.containsKey(CONFIG_METADATA_CACHE_SIZE)) {
+ this.metadataCacheSize = Integer.parseInt(opts.get(CONFIG_METADATA_CACHE_SIZE));
+ }
+ if (opts.containsKey(CONFIG_BATCH_SIZE)) {
+ int parsedBatchSize = Integer.parseInt(opts.get(CONFIG_BATCH_SIZE));
+ Preconditions.checkArgument(parsedBatchSize > 0, "batch_size must be positive");
+ this.batchSize = parsedBatchSize;
+ }
+ if (opts.containsKey(CONFIG_TOP_N_PUSH_DOWN)) {
+ this.topNPushDown = Boolean.parseBoolean(opts.get(CONFIG_TOP_N_PUSH_DOWN));
+ }
+ if (opts.containsKey(CONFIG_NEAREST)) {
+ nearest(opts.get(CONFIG_NEAREST));
+ }
+ if (opts.containsKey(CONFIG_EXECUTOR_CREDENTIAL_REFRESH)) {
+ this.executorCredentialRefresh =
+ Boolean.parseBoolean(opts.get(CONFIG_EXECUTOR_CREDENTIAL_REFRESH));
+ }
}
public LanceSparkReadOptions build() {
diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
index c302d6218..d7889ae80 100644
--- a/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
+++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/internal/LanceFragmentScanner.java
@@ -56,7 +56,18 @@ public static LanceFragmentScanner create(int fragmentId, LanceInputPartition in
Dataset dataset = null;
try {
LanceSparkReadOptions readOptions = inputPartition.getReadOptions();
- if (inputPartition.getNamespaceImpl() != null) {
+ // Optionally rebuild the namespace client on the executor so the dataset open routes through
+ // Utils.OpenDatasetBuilder's namespaceClient branch. This preserves the storage options
+ // provider on the Rust side, which refreshes short-lived vended credentials (e.g. STS
+ // tokens) during long-running scans. The price is an eager describeTable() RPC against the
+ // namespace on every fragment open.
+ //
+ // For catalogs whose backing service authenticates per-call (e.g. Hive Metastore over
+ // Kerberos) executors typically lack a TGT and that RPC fails with "GSS initiate failed".
+ // Setting LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH=false makes executors
+ // skip the rebuild and open the dataset by URI using the initialStorageOptions the driver
+ // already obtained, at the cost of losing the Rust-side credential refresh callback.
+ if (inputPartition.getNamespaceImpl() != null && readOptions.isExecutorCredentialRefresh()) {
if (LanceRuntime.useNamespaceOnWorkers(inputPartition.getNamespaceImpl())) {
readOptions.setNamespace(
LanceRuntime.getOrCreateNamespace(
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java
index 8336824d5..6567db776 100644
--- a/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java
+++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/LanceSparkReadOptionsSerializationTest.java
@@ -23,6 +23,9 @@
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
public class LanceSparkReadOptionsSerializationTest {
@@ -106,4 +109,121 @@ public void testUseIndexSerialization() throws IOException, ClassNotFoundExcepti
deserializedOptionsTrue.getNearest().isUseIndex(),
"useIndex should remain true after serialization/deserialization");
}
+
+ @Test
+ public void testExecutorCredentialRefreshDefaultsToTrue() {
+ LanceSparkReadOptions options =
+ LanceSparkReadOptions.builder().datasetUri("s3://bucket/path").build();
+ Assertions.assertTrue(
+ options.isExecutorCredentialRefresh(),
+ "executor_credential_refresh must default to true to preserve existing behavior");
+ }
+
+ @Test
+ public void testExecutorCredentialRefreshParsedFromOptions() {
+ LanceSparkReadOptions optionsFalse =
+ LanceSparkReadOptions.from(
+ Collections.singletonMap(
+ LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "false"),
+ "s3://bucket/path");
+ Assertions.assertFalse(optionsFalse.isExecutorCredentialRefresh());
+
+ LanceSparkReadOptions optionsTrue =
+ LanceSparkReadOptions.from(
+ Collections.singletonMap(
+ LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "true"),
+ "s3://bucket/path");
+ Assertions.assertTrue(optionsTrue.isExecutorCredentialRefresh());
+ }
+
+ @Test
+ public void testExecutorCredentialRefreshSurvivesSerialization()
+ throws IOException, ClassNotFoundException {
+ LanceSparkReadOptions options =
+ LanceSparkReadOptions.builder()
+ .datasetUri("s3://bucket/path")
+ .executorCredentialRefresh(false)
+ .build();
+ Assertions.assertFalse(options.isExecutorCredentialRefresh());
+
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
+ oos.writeObject(options);
+ }
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ LanceSparkReadOptions deserialized;
+ try (ObjectInputStream ois = new ObjectInputStream(bais)) {
+ deserialized = (LanceSparkReadOptions) ois.readObject();
+ }
+
+ Assertions.assertFalse(
+ deserialized.isExecutorCredentialRefresh(),
+ "executor_credential_refresh must survive Java serialization (driver -> executor handoff)");
+ }
+
+ @Test
+ public void testExecutorCredentialRefreshPreservedByWithVersion() {
+ LanceSparkReadOptions options =
+ LanceSparkReadOptions.builder()
+ .datasetUri("s3://bucket/path")
+ .executorCredentialRefresh(false)
+ .build();
+
+ LanceSparkReadOptions pinned = options.withVersion(7);
+ Assertions.assertFalse(
+ pinned.isExecutorCredentialRefresh(),
+ "withVersion() must propagate the executor_credential_refresh flag");
+ }
+
+ /**
+ * Catalog-level config (set via {@code --conf spark.sql.catalog..}) is the only route
+ * available to SQL DML (DELETE / UPDATE / MERGE INTO), which has no per-statement {@code
+ * .option(...)} attach point. This test guards the catalog-conf path.
+ */
+ @Test
+ public void testExecutorCredentialRefreshFromCatalogDefaults() {
+ Map catalogOpts = new HashMap<>();
+ catalogOpts.put(LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "false");
+ LanceSparkCatalogConfig catalogConfig = LanceSparkCatalogConfig.from(catalogOpts);
+
+ LanceSparkReadOptions options =
+ LanceSparkReadOptions.builder()
+ .datasetUri("s3://bucket/path")
+ .withCatalogDefaults(catalogConfig)
+ .build();
+
+ Assertions.assertFalse(
+ options.isExecutorCredentialRefresh(),
+ "executor_credential_refresh set at catalog level must land in the typed field "
+ + "so it takes effect for SELECT without .option(...) and for SQL DML");
+ }
+
+ /**
+ * Spark's scan-time options (via {@code spark.read.option(...)}) go through a second {@code
+ * fromOptions(mergedMap)} rebuild in {@code LanceDataset.newScanBuilder}. Per-read settings must
+ * win over catalog-level defaults.
+ */
+ @Test
+ public void testPerReadOptionOverridesCatalogDefaults() {
+ Map catalogOpts = new HashMap<>();
+ catalogOpts.put(LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "false");
+ LanceSparkCatalogConfig catalogConfig = LanceSparkCatalogConfig.from(catalogOpts);
+
+ // Simulate the rebuild path in LanceDataset.newScanBuilder: the builder starts by applying
+ // the catalog defaults, then fromOptions() replays against the merged (catalog + per-read)
+ // map where the per-read value wins.
+ Map merged = new HashMap<>(catalogConfig.getStorageOptions());
+ merged.put(LanceSparkReadOptions.CONFIG_EXECUTOR_CREDENTIAL_REFRESH, "true");
+
+ LanceSparkReadOptions options =
+ LanceSparkReadOptions.builder()
+ .datasetUri("s3://bucket/path")
+ .withCatalogDefaults(catalogConfig)
+ .fromOptions(merged)
+ .build();
+
+ Assertions.assertTrue(
+ options.isExecutorCredentialRefresh(),
+ "per-read .option(...) must override the catalog-level default");
+ }
}
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/internal/LanceFragmentScannerTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/internal/LanceFragmentScannerTest.java
index 13a33be7e..edeeb51f1 100644
--- a/lance-spark-base_2.12/src/test/java/org/lance/spark/internal/LanceFragmentScannerTest.java
+++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/internal/LanceFragmentScannerTest.java
@@ -13,8 +13,13 @@
*/
package org.lance.spark.internal;
+import org.lance.namespace.LanceNamespace;
import org.lance.spark.LanceConstant;
+import org.lance.spark.LanceSparkReadOptions;
+import org.lance.spark.read.LanceInputPartition;
+import org.lance.spark.utils.Optional;
+import org.apache.arrow.memory.BufferAllocator;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -23,9 +28,14 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
public class LanceFragmentScannerTest {
@@ -199,4 +209,77 @@ public void testGetColumnNamesWithFragmentId() throws Exception {
List expected = Arrays.asList("id");
assertEquals(expected, result);
}
+
+ /**
+ * Locks down the executor-branch contract for {@code executor_credential_refresh=false}: when an
+ * executor opens a fragment for a namespace-backed table with the flag disabled, {@link
+ * LanceFragmentScanner#create} must not reconstruct the namespace client. Without this
+ * gate, executors of Kerberized HMS catalogs hit {@code GSS initiate failed} because they lack a
+ * TGT for the eager {@code describeTable()} RPC.
+ *
+ *
Strategy: hand a real, loadable {@link LanceNamespace} impl to the partition. If the gate
+ * regresses (rebuild not skipped), {@code LanceNamespace.connect} would succeed via {@code
+ * Class.forName}, {@link RecordingNamespace#initialize} would run, and {@code
+ * readOptions.setNamespace(...)} would fire — all observable here. The bogus dataset URI lets the
+ * outer {@code Utils.openDatasetBuilder().build()} call fail predictably, since the gate runs
+ * before the dataset is opened. No real Lance dataset is required.
+ */
+ @Test
+ public void testCreateSkipsNamespaceRebuildWhenExecutorCredentialRefreshDisabled() {
+ RecordingNamespace.INITIALIZE_CALLS.set(0);
+
+ LanceSparkReadOptions readOptions =
+ LanceSparkReadOptions.builder()
+ .datasetUri("file:///tmp/__lance_nonexistent_for_executor_gate_test__")
+ .executorCredentialRefresh(false)
+ .build();
+
+ LanceInputPartition partition =
+ new LanceInputPartition(
+ new StructType(),
+ 0,
+ null,
+ readOptions,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ "test-scan",
+ Collections.emptyMap(),
+ RecordingNamespace.class.getName(),
+ Collections.emptyMap(),
+ null);
+
+ assertThrows(RuntimeException.class, () -> LanceFragmentScanner.create(0, partition));
+
+ assertNull(
+ readOptions.getNamespace(),
+ "executor_credential_refresh=false must skip namespace rebuild on the executor");
+ assertEquals(
+ 0,
+ RecordingNamespace.INITIALIZE_CALLS.get(),
+ "executor_credential_refresh=false must not load or initialize the namespace impl");
+ }
+
+ /**
+ * Public, top-level-by-FQCN, no-arg {@link LanceNamespace} so that {@link
+ * LanceNamespace#connect(String, Map, BufferAllocator)} can resolve it via {@code Class.forName}
+ * if the executor branch is (incorrectly) taken.
+ */
+ public static class RecordingNamespace implements LanceNamespace {
+ static final AtomicInteger INITIALIZE_CALLS = new AtomicInteger();
+
+ public RecordingNamespace() {}
+
+ @Override
+ public void initialize(Map properties, BufferAllocator allocator) {
+ INITIALIZE_CALLS.incrementAndGet();
+ }
+
+ @Override
+ public String namespaceId() {
+ return "recording";
+ }
+ }
}
diff --git a/lance-spark-knn-4.2_2.13/pom.xml b/lance-spark-knn-4.2_2.13/pom.xml
new file mode 100644
index 000000000..282715f75
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/pom.xml
@@ -0,0 +1,165 @@
+
+
+ 4.0.0
+
+
+ org.lance
+ lance-spark-root
+ 0.4.0-beta.4
+ ../pom.xml
+
+
+ lance-spark-knn-4.2_2.13
+ ${project.artifactId}
+
+ Phase 2 Catalyst integration for the indexed nearest-by join. Pattern-matches Spark 4.2's
+ NearestByJoin operator and rewrites it to call the staged probe/merge/materialize
+ pipeline from lance-spark-knn. Spark 4.2-SNAPSHOT-only (NearestByJoin was added in
+ SPARK-56395, post-4.1).
+
+ jar
+
+
+ ${scala213.version}
+ ${scala213.compat.version}
+ ${java17.release}
+
+ 4.2.0-SNAPSHOT
+
+ 19.0.0
+
+
+
+
+ org.lance
+ lance-spark-knn_${scala.compat.version}
+ ${project.version}
+
+
+
+ org.apache.arrow
+ arrow-memory-netty-buffer-patch
+
+
+
+
+ org.apache.spark
+ spark-sql_${scala.compat.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-catalyst_${scala.compat.version}
+ ${spark.version}
+ provided
+
+
+
+ org.junit.jupiter
+ junit-jupiter
+ ${junit.version}
+ test
+
+
+ org.lance
+ lance-spark-knn_${scala.compat.version}
+ ${project.version}
+ test-jar
+ test
+
+
+
+ org.lance
+ lance-spark-4.1_${scala.compat.version}
+ ${project.version}
+ test
+
+
+
+
+
+
+
+ org.codehaus.mojo
+ exec-maven-plugin
+ 3.1.0
+
+ test
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ ${scala-maven-plugin.version}
+
+
+ scala-compile-first
+ process-resources
+
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ -feature
+ -release
+ ${java.release}
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ ${maven-compiler-plugin.version}
+
+ ${java.release}
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+
+ test-jar
+
+
+
+
+
+
+
diff --git a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala
new file mode 100644
index 000000000..c40c074bd
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRule.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.catalyst
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, VectorCosineSimilarity, VectorInnerProduct, VectorL2Distance}
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, NearestByDirection, NearestByDistance, NearestBySimilarity}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, NearestByJoin, Project, SubqueryAlias}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric}
+import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan, ProbedLeftCodec}
+
+/**
+ * Catalyst rule that rewrites a Spark [[NearestByJoin]] (`approx = true`) over a Lance scan with
+ * a recognized vector-distance ranking expression into the 3-plan staged tree
+ * (`LanceProbeLogicalPlan → LanceMergeLogicalPlan → LanceMaterializeLogicalPlan`), wrapped in a
+ * top-level [[Project]] that restores `NearestByJoin.output` exactly. The shared
+ * [[org.lance.spark.knn.internal.staged.LanceKnnStagedStrategy]] then lowers that tree to the
+ * matching probe/merge/materialize execs — identical shape to the DataFrame API path.
+ *
+ * == Why this rule must be a `postHocResolutionRule`, not an optimizer rule ==
+ *
+ * Spark's built-in [[org.apache.spark.sql.catalyst.optimizer.RewriteNearestByJoin]] rule runs in
+ * the optimizer's `FinishAnalysis` batch — the very first batch. `injectOptimizerRule` adds
+ * rules to `operatorOptimizationBatch`, which runs AFTER `FinishAnalysis`. By the time an
+ * injected optimizer rule fires, the `NearestByJoin` operator has already been replaced with the
+ * cross-product + `MaxMinByK` rewrite, and we have nothing to pattern-match.
+ *
+ * `injectPostHocResolutionRule` runs after analysis but before any optimizer batch — it is the
+ * only injection point that sees the unrewritten `NearestByJoin`. The same constraint applies to
+ * any future engine wanting to substitute a different physical strategy for `NearestByJoin`.
+ *
+ * == Pattern match ==
+ *
+ * The rule fires on the conjunction of:
+ * - `NearestByJoin(_, right, joinType, approx = true, k, rankingExpression, direction)`
+ * - `right` resolves to a Lance DSv2 relation (immediate or under a `SubqueryAlias`)
+ * - `rankingExpression` is one of three recognized vector functions, AND its direction matches
+ * the direction declared on `NearestByJoin`:
+ *
+ * | Spark expression | direction | metric |
+ * |---------------------------------|------------------------|----------------|
+ * | `VectorL2Distance(L, R)` | `NearestByDistance` | `Metric.L2` |
+ * | `VectorCosineSimilarity(L, R)` | `NearestBySimilarity` | `Metric.Cosine`|
+ * | `VectorInnerProduct(L, R)` | `NearestBySimilarity` | `Metric.Dot` |
+ *
+ * Any other shape is left alone — Spark's default cross-product rewrite handles it.
+ *
+ * The two arguments of the ranking function must each resolve to an [[Attribute]] from one
+ * specific side of the join. Mixed-side compounds (e.g. `l2_distance(left.vec, left.vec)`) and
+ * derived expressions (e.g. `l2_distance(left.vec, slice(right.vec, ...))`) are out of scope and
+ * fall through to the cross-product rewrite. Phase 3 may extend this.
+ *
+ * == Lance scan detection ==
+ *
+ * Class-name match: `getClass.getName.contains("Lance")`. The probe / materialize stages
+ * drive Lance's Java API directly, so the indexed-path executor is Lance-specific by
+ * construction — there's no general "any vector-capable backend" extension point here. URI
+ * extracted from the standard `path` / `datasetUri` option. The rule is opt-in via
+ * `spark.lance.knn.indexedNearestByJoin.enabled`, so a false positive can only fire when
+ * the user explicitly enabled the feature against a non-Lance backend; the runtime probe
+ * would surface the mismatch.
+ *
+ * == Prefilter pushdown ==
+ *
+ * If the right side is a `Filter(cond, lance)` (a `WHERE` clause on the indexed table), the
+ * rule translates the predicate to a Lance SQL filter string and threads it through to the
+ * probe. Lance applies the filter BEFORE the index lookup (we always pass `prefilter = true`),
+ * so the top-K is computed over only the rows matching the filter — the only correct behavior
+ * for `right WHERE p APPROX NEAREST K`.
+ *
+ * Translation is conservative: it handles binary comparisons (=, !=, <, <=, >, >=), `IN`,
+ * `IS [NOT] NULL`, `AND`/`OR`/`NOT` over right-side attributes vs. literals. Anything else
+ * (UDFs, subqueries, computed expressions) means the rule REFUSES the rewrite and returns the
+ * original `NearestByJoin`, falling through to Spark's brute-force cross-product. Refusal — not
+ * "push what we can, drop the rest" — because dropping a residual would silently change result
+ * semantics. The job becomes slow rather than wrong.
+ *
+ * Filter pushdown into the V2 relation does NOT happen at this point: this rule runs as a
+ * `postHocResolutionRule` (before the optimizer), so the right side is still the freshly
+ * analyzed `Filter` over `DataSourceV2Relation` — the V2 `SupportsPushDownFilters` step has not
+ * yet run. After we rewrite, the right side is absorbed into our plan, so V2 pushdown never
+ * gets a chance to drop the filter on the floor.
+ */
+object IndexedNearestByJoinRule extends Rule[LogicalPlan] {
+
+ /** Configuration key that gates the rule. Off by default to keep the rule opt-in for now. */
+ val EnabledConfKey: String = "spark.lance.knn.indexedNearestByJoin.enabled"
+
+ /**
+ * IVF cluster count to visit per query. Higher = better recall, more compute. Default
+ * (None) leaves Lance's index-default (typically 1).
+ */
+ val NprobesConfKey: String = "spark.lance.knn.nprobes"
+
+ /**
+ * IVF-PQ refine factor — Lance fetches `K * refineFactor` PQ candidates and re-ranks them
+ * with exact distance using full vectors. Highest-leverage recall knob for IVF-PQ. Default
+ * (None) leaves Lance's index-default (= 1, no re-rank).
+ */
+ val RefineFactorConfKey: String = "spark.lance.knn.refineFactor"
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!conf.getConfString(EnabledConfKey, "false").toBoolean) {
+ return plan
+ }
+ val nprobes = optInt(NprobesConfKey)
+ val refineFactor = optInt(RefineFactorConfKey)
+ plan.transformDown {
+ case j @ NearestByJoin(left, right, joinType, true, k, rankingExpr, direction) =>
+ rewriteIfApplicable(
+ j,
+ left,
+ right,
+ joinType,
+ k,
+ rankingExpr,
+ direction,
+ nprobes,
+ refineFactor).getOrElse(j)
+ }
+ }
+
+ private def optInt(key: String): Option[Int] =
+ Option(conf.getConfString(key, null)).map(_.toInt)
+
+ /**
+ * Rewrite `NearestByJoin` into the 3-exec staged logical-plan tree — the same tree
+ * `IndexedNearestJoin.apply` produces on the DataFrame API path.
+ *
+ * {{{
+ * Project(j.output, drop __score)
+ * +- LanceMaterializeLogicalPlan output = left ++ right ++ __score
+ * +- LanceMergeLogicalPlan output = [_leftId, leftFields, _refs]
+ * +- LanceProbeLogicalPlan output = [_leftId, leftFields, _refs]
+ * +- left
+ * }}}
+ *
+ * We add a top-level `Project` because `NearestByJoin.output` is `left ++ right` (no
+ * score), but the materialize stage emits `left ++ right ++ __score`. The Project
+ * slices the trailing score attribute — Catalyst's ColumnPruning won't interfere because
+ * `LanceMaterializeLogicalPlan` overrides `references = child.outputSet`.
+ */
+ private def rewriteIfApplicable(
+ j: NearestByJoin,
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ k: Int,
+ rankingExpr: Expression,
+ direction: NearestByDirection,
+ nprobes: Option[Int],
+ refineFactor: Option[Int]): Option[LogicalPlan] = {
+ for {
+ lance <- unwrapLanceScan(right)
+ (metric, leftVecAttr, rightVecCol) <- recognizeRanking(rankingExpr, direction, left, right)
+ } yield {
+ val leftVecIdx = left.output.indexWhere(_.exprId == leftVecAttr.exprId)
+ require(leftVecIdx >= 0, s"left vector attr not found in left.output: $leftVecAttr")
+
+ val internalK = k // no overfetch on the SQL path
+ val probeConf = LanceProbeStage.Conf(
+ datasetUri = lance.uri,
+ fragmentIds = None,
+ vectorColumn = rightVecCol,
+ version = lance.version,
+ metric = metric,
+ k = internalK,
+ nprobes = nprobes,
+ leftVecIdx = leftVecIdx,
+ refineFactor = refineFactor,
+ prefilter = lance.prefilter)
+
+ val mergeConf = LanceMergeStage.Conf(finalK = k, smallerIsBetter = metric.smallerIsBetter)
+
+ val rightSchema = lance.output
+ val rightFields: Seq[StructField] = rightSchema.map(a =>
+ StructField(a.name, a.dataType, nullable = true))
+ val rightProjection: Seq[String] = rightSchema.map(_.name)
+
+ val materializeConf = LanceMaterializeStage.Conf(
+ datasetUri = lance.uri,
+ version = lance.version,
+ rightProjection = rightProjection,
+ rightFields = rightFields,
+ leftFieldCount = left.output.size,
+ outerJoin = joinType == LeftOuter)
+
+ // Build the three logical plans. Attributes must be created once and shared between
+ // probe + merge so Catalyst's attribute resolution is happy at the inter-stage
+ // boundary (exprId equality, not just name).
+ val leftSchemaStruct = StructType(left.output.map(a =>
+ StructField(a.name, a.dataType, a.nullable)))
+ val interStageAttrs = ProbedLeftCodec.interStageAttributes(leftSchemaStruct)
+
+ val probeLogical = LanceProbeLogicalPlan(
+ child = left,
+ stageConf = probeConf,
+ fragmentGroups = None, // SQL path uses single-task probe; Phase 1.5 is DataFrame-only
+ leftSchema = leftSchemaStruct,
+ interStageOutput = interStageAttrs)
+
+ val mergeLogical = LanceMergeLogicalPlan(
+ child = probeLogical,
+ stageConf = mergeConf,
+ leftSchema = leftSchemaStruct,
+ interStageOutput = interStageAttrs)
+
+ // `LanceMaterializeLogicalPlan.finalSchema` is left ++ right ++ __score. The SQL output
+ // is j.output (= left ++ right, no score). Match finalOutput to `j.output ++ scoreAttr`
+ // so the inner node's output is stable; then Project drops __score at the top.
+ //
+ // `NearestByJoin.output` widens every left+right attribute to `nullable = true` — a
+ // contract the base Spark rewrite also honors (see `NearestByJoin.output` scaladoc).
+ // `finalSchema` feeds `ExpressionEncoder` in `LanceMaterializeExec.doExecute`; if we
+ // leave left fields with the raw `nullable = false` but the logical output declares
+ // them nullable, the encoder's binary layout drifts from what downstream consumers
+ // expect. Widen left here to keep the encoder consistent with the declared output.
+ val scoreAttr = AttributeReference("__score", FloatType, nullable = true)()
+ val finalSchema = StructType(
+ leftSchemaStruct.fields.map(_.copy(nullable = true)) ++
+ rightFields.map(f => f.copy(nullable = true)) :+
+ StructField("__score", FloatType, nullable = true))
+ val finalOutput: Seq[Attribute] = j.output :+ scoreAttr
+
+ val materializeLogical = LanceMaterializeLogicalPlan(
+ child = mergeLogical,
+ stageConf = materializeConf,
+ leftSchema = leftSchemaStruct,
+ finalSchema = finalSchema,
+ finalOutput = finalOutput)
+
+ // Top-level Project drops the __score attr so the plan's external output matches
+ // NearestByJoin.output exactly.
+ Project(j.output, materializeLogical)
+ }
+ }
+
+ /** Lance scan info extracted from a DSv2 relation, optionally with a translated prefilter. */
+ final private case class LanceScanInfo(
+ uri: String,
+ version: Option[Long],
+ output: Seq[Attribute],
+ prefilter: Option[String])
+
+ private def unwrapLanceScan(plan: LogicalPlan): Option[LanceScanInfo] = plan match {
+ case SubqueryAlias(_, child) => unwrapLanceScan(child)
+ case v: org.apache.spark.sql.catalyst.plans.logical.View =>
+ // SQL `createOrReplaceTempView` + `spark.sql(... FROM ...)` wraps the underlying
+ // DataSourceV2Relation in a `View`. Unwrap to find the actual relation underneath.
+ unwrapLanceScan(v.children.head)
+ case Filter(cond, child) =>
+ // Right-side `WHERE` clause. Recurse first so we have the relation's output to validate
+ // attribute references against, then translate the predicate. If translation fails we
+ // bail entirely (return None, no rewrite) — pushing only PART of a `WHERE` would silently
+ // change query semantics. The user's filter must be pushed in full or not at all.
+ unwrapLanceScan(child).flatMap { info =>
+ translateFilter(cond, AttributeSet(info.output)).map { sql =>
+ val combined = info.prefilter match {
+ case Some(prev) => Some(s"($prev) AND ($sql)")
+ case None => Some(sql)
+ }
+ info.copy(prefilter = combined)
+ }
+ }
+ case Project(projectList, child) if isPassthroughProject(projectList, child) =>
+ // `SELECT * FROM lance` analyzes to `Project(, lance)` — a pass-through
+ // that preserves attrs and exprIds. Unwrap it. Non-pass-through Projects (renames, drops,
+ // computed columns) would change the schema we rely on for `j.output` mapping, so we
+ // refuse those by falling through to the default `_ => None` case.
+ unwrapLanceScan(child)
+ case rel: DataSourceV2Relation if isLanceTable(rel.table) =>
+ // The probe / materialize stages drive Lance's Java API directly, so this rule is
+ // Lance-specific by construction — there's no plug-in point for a non-Lance backend
+ // here. We detect Lance via class-name match and pull the URI from the standard `path`
+ // / `datasetUri` option. If neither is present we fall through (returning None lets
+ // Spark's brute-force rewrite handle the query).
+ val opts = rel.options
+ val uri = Option(opts.get("path")).orElse(Option(opts.get("datasetUri")))
+ uri.map { u =>
+ LanceScanInfo(
+ uri = u,
+ version = Option(opts.get("version")).map(_.toLong),
+ output = rel.output,
+ prefilter = None)
+ }
+ case _ => None
+ }
+
+ /**
+ * Translate a Spark `Filter` predicate into a Lance SQL filter string. Returns `None` if any
+ * sub-expression isn't supported — refusal, not partial pushdown.
+ *
+ * Supported shapes (over right-side attributes only):
+ * - `attr literal` and `literal attr` for `=`, `!=`, `<`, `<=`, `>`, `>=`
+ * - `attr IS NULL` / `attr IS NOT NULL`
+ * - `attr IN (lit, lit, ...)` (the IN list must be all foldable literals)
+ * - `AND` / `OR` over supported children
+ * - `NOT` over supported child
+ *
+ * Anything else — UDFs, joins, subqueries, expressions on both sides referencing the LEFT
+ * input, computed sub-expressions on the right (e.g. `year(ts) = 2025`) — returns `None`.
+ * Lance's SQL dialect is DataFusion-flavored; the constructs above all parse identically
+ * there, so we don't need to translate operator names beyond literal serialization.
+ */
+ private[catalyst] def translateFilter(
+ expr: Expression,
+ rightAttrs: AttributeSet): Option[String] = expr match {
+ case And(l, r) =>
+ for {
+ a <- translateFilter(l, rightAttrs)
+ b <- translateFilter(r, rightAttrs)
+ } yield s"($a) AND ($b)"
+ case Or(l, r) =>
+ for {
+ a <- translateFilter(l, rightAttrs)
+ b <- translateFilter(r, rightAttrs)
+ } yield s"($a) OR ($b)"
+ case Not(EqualTo(l, r)) =>
+ // Render `NOT (a = b)` as `(a != b)` so it's the natural Lance form.
+ binaryOp(l, r, rightAttrs, "!=")
+ case Not(child) =>
+ translateFilter(child, rightAttrs).map(s => s"NOT ($s)")
+ case IsNull(c) =>
+ asRightColumn(c, rightAttrs).map(name => s"$name IS NULL")
+ case IsNotNull(c) =>
+ asRightColumn(c, rightAttrs).map(name => s"$name IS NOT NULL")
+ case EqualTo(l, r) => binaryOp(l, r, rightAttrs, "=")
+ case GreaterThan(l, r) => binaryOp(l, r, rightAttrs, ">")
+ case GreaterThanOrEqual(l, r) => binaryOp(l, r, rightAttrs, ">=")
+ case LessThan(l, r) => binaryOp(l, r, rightAttrs, "<")
+ case LessThanOrEqual(l, r) => binaryOp(l, r, rightAttrs, "<=")
+ case In(value, list) if list.nonEmpty =>
+ for {
+ col <- asRightColumn(value, rightAttrs)
+ lits <- list.foldLeft(Option(Vector.empty[String])) { (accOpt, e) =>
+ accOpt.flatMap(acc => asLiteral(e).map(acc :+ _))
+ }
+ } yield s"$col IN (${lits.mkString(", ")})"
+ case _ => None
+ }
+
+ private def binaryOp(
+ l: Expression,
+ r: Expression,
+ rightAttrs: AttributeSet,
+ op: String): Option[String] = {
+ // attr literal — the natural shape
+ val attrLit = for {
+ col <- asRightColumn(l, rightAttrs)
+ lit <- asLiteral(r)
+ } yield s"$col $op $lit"
+ // literal attr — flip when the parser/optimizer emitted args in this order. Renders
+ // as `lit op col`, which DataFusion also accepts.
+ attrLit.orElse {
+ for {
+ col <- asRightColumn(r, rightAttrs)
+ lit <- asLiteral(l)
+ } yield s"$lit $op $col"
+ }
+ }
+
+ private def asRightColumn(e: Expression, rightAttrs: AttributeSet): Option[String] = e match {
+ case a: Attribute if rightAttrs.contains(a) => Some(a.name)
+ case _ => None
+ }
+
+ /**
+ * Render a Spark literal as a Lance SQL literal. Dispatch is by `dataType`, NOT by the boxed
+ * value class — Catalyst stores e.g. `Literal(0, DateType)` with the value as a plain `Int`,
+ * so a value-class match would silently let a date literal through as the integer "0", a
+ * recall-corrupting mistranslation.
+ *
+ * Supports nulls, booleans, numeric primitives, and strings (with `'`-escaped quoting). Bails
+ * on dates, timestamps, decimals, binary, arrays, structs — those have non-trivial cross-
+ * dialect renderings and we'd rather refuse pushdown than risk a wrong filter.
+ */
+ private def asLiteral(e: Expression): Option[String] = e match {
+ case Literal(null, _) => Some("NULL")
+ case Literal(v, BooleanType) => Some(v.toString)
+ case Literal(v, ByteType) => Some(v.toString)
+ case Literal(v, ShortType) => Some(v.toString)
+ case Literal(v, IntegerType) => Some(v.toString)
+ case Literal(v, LongType) => Some(v.toString)
+ case Literal(v, FloatType) => Some(v.toString)
+ case Literal(v, DoubleType) => Some(v.toString)
+ case Literal(v: UTF8String, StringType) => Some(quoteString(v.toString))
+ case Literal(v: String, StringType) => Some(quoteString(v))
+ case _ => None
+ }
+
+ private def quoteString(s: String): String = "'" + s.replace("'", "''") + "'"
+
+ /**
+ * True iff the Project is the canonical `SELECT *` form: same number of outputs as the child,
+ * each entry a bare `AttributeReference` whose `exprId` matches the child's output in order.
+ * Any aliasing, reordering, dropping, or computed column — return false and refuse to
+ * unwrap, since those change the schema we'd surface as the join's right-side output.
+ */
+ private def isPassthroughProject(
+ projectList: Seq[org.apache.spark.sql.catalyst.expressions.NamedExpression],
+ child: LogicalPlan): Boolean = {
+ val childOut = child.output
+ if (projectList.size != childOut.size) return false
+ projectList.zip(childOut).forall {
+ case (a: Attribute, c) => a.exprId == c.exprId
+ case _ => false
+ }
+ }
+
+ private def isLanceTable(table: Table): Boolean = {
+ val cls = table.getClass.getName
+ // Loose by design — the rule is opt-in via spark.lance.knn.indexedNearestByJoin.enabled, so
+ // a false positive here would only fire when the user explicitly turned the feature on
+ // against a non-Lance backend, and the runtime probe would surface the mismatch.
+ cls.contains("Lance") || cls.contains("lance")
+ }
+
+ /**
+ * Recognize `rankingExpr` as one of the supported vector-distance functions, AND verify the
+ * declared `direction` on `NearestByJoin` matches the function's natural ordering.
+ *
+ * Returns `(metric, leftVecAttr, rightVecColName)` on success.
+ */
+ private def recognizeRanking(
+ rankingExpr: Expression,
+ direction: NearestByDirection,
+ left: LogicalPlan,
+ right: LogicalPlan): Option[(Metric, Attribute, String)] = {
+ val (metric, lhs, rhs) = rankingExpr match {
+ case VectorL2Distance(l, r) if direction == NearestByDistance => (Metric.L2, l, r)
+ case VectorCosineSimilarity(l, r) if direction == NearestBySimilarity => (Metric.Cosine, l, r)
+ case VectorInnerProduct(l, r) if direction == NearestBySimilarity => (Metric.Dot, l, r)
+ case _ => return None
+ }
+ // Each argument must be a bare attribute from one side of the join.
+ (asAttr(lhs), asAttr(rhs)) match {
+ case (Some(la), Some(ra)) =>
+ val leftAttrIds = left.outputSet
+ val rightAttrIds = right.outputSet
+ if (leftAttrIds.contains(la) && rightAttrIds.contains(ra)) {
+ Some((metric, la, ra.name))
+ } else if (leftAttrIds.contains(ra) && rightAttrIds.contains(la)) {
+ // Argument order swapped — still valid for symmetric metrics. All three of L2/Cosine/Dot
+ // are symmetric so we don't have to retain the original orientation.
+ Some((metric, ra, la.name))
+ } else {
+ None
+ }
+ case _ => None
+ }
+ }
+
+ private def asAttr(e: Expression): Option[Attribute] = e match {
+ case a: Attribute => Some(a)
+ case _ => None
+ }
+}
diff --git a/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/extensions/LanceKnnSparkSessionExtensions.scala b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/extensions/LanceKnnSparkSessionExtensions.scala
new file mode 100644
index 000000000..37181a17d
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/src/main/scala/org/lance/spark/knn/extensions/LanceKnnSparkSessionExtensions.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.extensions
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.lance.spark.knn.catalyst.IndexedNearestByJoinRule
+import org.lance.spark.knn.internal.staged.LanceKnnStagedStrategy
+
+/**
+ * Registers Phase 2 Catalyst integration for the indexed nearest-by join.
+ *
+ * Wire this into a SparkSession with:
+ *
+ * {{{
+ * SparkSession.builder()
+ * .config("spark.sql.extensions",
+ * "org.lance.spark.knn.extensions.LanceKnnSparkSessionExtensions")
+ * .config("spark.lance.knn.indexedNearestByJoin.enabled", "true")
+ * ...
+ * }}}
+ *
+ * The `enabled` flag gates the rule itself — see [[IndexedNearestByJoinRule.EnabledConfKey]].
+ * Off by default to keep the integration opt-in until the cost gate (Phase 3) is in place.
+ *
+ * == Injection point: postHocResolutionRule, NOT optimizerRule ==
+ *
+ * Spark's `RewriteNearestByJoin` runs in `FinishAnalysis`, which precedes the
+ * `operatorOptimizationBatch` that `injectOptimizerRule` adds rules to. By the time an injected
+ * optimizer rule fires, the `NearestByJoin` operator has already been replaced with the
+ * cross-product + `MaxMinByK` rewrite. `injectPostHocResolutionRule` runs after analysis but
+ * before any optimizer batch — this is the only injection point that sees the unrewritten
+ * `NearestByJoin`. See `IndexedNearestByJoinRule`'s class doc for the full rationale.
+ *
+ * Coexistence: this extension does not replace `LanceSparkSessionExtensions` from the connector
+ * modules; both can be wired together in a comma-separated `spark.sql.extensions` value.
+ */
+class LanceKnnSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ extensions.injectPostHocResolutionRule(_ => IndexedNearestByJoinRule)
+ // Shared with the DataFrame API path: `LanceKnnStagedStrategy` lowers the three
+ // logical plans (`LanceProbeLogicalPlan` / `LanceMergeLogicalPlan` /
+ // `LanceMaterializeLogicalPlan`) to the matching physical execs.
+ // `IndexedNearestJoin.apply` also installs this strategy via
+ // `experimentalMethods.extraStrategies` at first call; wiring it into the session
+ // extension makes the SQL path self-sufficient without depending on a prior DataFrame
+ // call having run.
+ extensions.injectPlannerStrategy(_ => LanceKnnStagedStrategy)
+ }
+}
diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/benchmark/IndexedNearestByJoinSqlBenchmark.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/benchmark/IndexedNearestByJoinSqlBenchmark.scala
new file mode 100644
index 000000000..753cf0116
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/benchmark/IndexedNearestByJoinSqlBenchmark.scala
@@ -0,0 +1,606 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{DataFrame, RowFactory, SparkSession}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.catalyst.IndexedNearestByJoinRule
+import org.lance.spark.knn.internal.LanceVectorIndexBuilder
+
+import java.nio.file.{Files, Paths}
+import java.util.{Locale, Random}
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+/**
+ * SQL-level benchmark for the Phase 2 Catalyst integration. Same `APPROX NEAREST k BY DISTANCE
+ * vector_l2_distance(...)` SQL run with the rule ON vs OFF — measuring the speedup at the SQL
+ * level a user would observe.
+ *
+ * Rule OFF lowers to Spark's built-in `RewriteNearestByJoin`, which is the optimizer rule that
+ * rewrites `NearestByJoin` to `Generate(Inline(Aggregate(min_by_k(...))))` over a
+ * `BroadcastNestedLoopJoin`. Cross-product semantics. This is what every user would get today
+ * out of vanilla Spark when they write `APPROX NEAREST` against any data source.
+ *
+ * Rule ON intercepts the same SQL pre-optimizer and emits the same 3-plan staged tree
+ * (`LanceProbe → LanceMerge → LanceMaterialize`, with a Catalyst-inserted Exchange on the
+ * merge side) that the DataFrame API path builds — single-task probe per partition (Phase
+ * 0/1 default) or fragment-grouped if `probeParallelism > 1`. The SQL path can't expose
+ * `probeParallelism` directly via SQL — that's a session-config or rule-side choice (Phase
+ * 3.x). Defaults to 1.
+ *
+ * Requires Spark 4.2-SNAPSHOT runtime AND lance-spark-4.1 connector recompiled against it. See
+ * `IndexedNearestByJoinE2ETest` class doc for the setup commands.
+ *
+ * Invocation:
+ *
+ * {{{
+ * MAVEN_OPTS="-Xmx12g " \
+ * ./mvnw -pl lance-spark-knn-4.2_2.13 -q exec:java \
+ * -Dexec.classpathScope=test \
+ * -Dexec.mainClass=org.lance.spark.knn.benchmark.IndexedNearestByJoinSqlBenchmark
+ * }}}
+ *
+ * Override scale via `BENCHMARK_SCALE`: `small`, `medium`, or `both` (default).
+ */
+object IndexedNearestByJoinSqlBenchmark {
+
+ private val Dim: Int = 128
+ private val K: Int = 10
+ private val Seed: Long = 1337L
+
+ /**
+ * Data distribution selector. `uniform` = independent floats over [0, 1]^Dim — the IVF worst
+ * case (k-means has no cluster structure to latch onto, recall on indexed paths is poor).
+ * `clustered` = unit-sphere-normalized Gaussian-mixture, the geometry of typical
+ * sentence-transformer / image-feature embeddings. Override via `BENCHMARK_DATA` env var.
+ */
+ sealed private trait DataMode { def label: String }
+ private object DataUniform extends DataMode { val label = "uniform" }
+ private object DataClustered extends DataMode { val label = "clustered" }
+ private val NumClusters: Int = 64
+
+ /**
+ * Index-type selector. `flat` = IVF_FLAT, exact distances within visited clusters (the
+ * workaround we used to get >0.9 recall on uniform-random dim-128 data, where PQ noise
+ * dominates). `pq` = IVF-PQ with `Dim / 16` sub-vectors and 8 bits — the production-realistic
+ * compressed index. PQ is what actually gets used at scale in production deployments because
+ * IVF_FLAT's per-cluster full-vector storage doesn't fit, but PQ's recall is highly sensitive
+ * to data distribution. Override via `BENCHMARK_INDEX` env var.
+ */
+ sealed private trait IndexMode { def label: String }
+ private object IndexFlat extends IndexMode { val label = "ivf_flat" }
+ private object IndexPq extends IndexMode { val label = "ivf_pq" }
+
+ /** runRuleOff: skip the brute-force baseline at scales where it'd take >10 min. */
+ private case class Scale(
+ name: String,
+ numRight: Int,
+ numLeft: Int,
+ numFragments: Int,
+ runRuleOff: Boolean) {
+ override def toString: String = s"$name (|R|=$numRight, |L|=$numLeft, frags=$numFragments)"
+ }
+
+ private val Small =
+ Scale("small", numRight = 100000, numLeft = 100, numFragments = 4, runRuleOff = true)
+ private val Medium =
+ Scale("medium", numRight = 1000000, numLeft = 1000, numFragments = 8, runRuleOff = false)
+
+ private case class Result(scale: String, config: String, medianMs: Long, runs: Seq[Long])
+
+ def main(args: Array[String]): Unit = {
+ val scales = sys.env.getOrElse("BENCHMARK_SCALE", "both").toLowerCase(Locale.ROOT) match {
+ case "small" => Seq(Small)
+ case "medium" => Seq(Medium)
+ case _ => Seq(Small, Medium)
+ }
+ val dataMode: DataMode =
+ sys.env.getOrElse("BENCHMARK_DATA", "uniform").toLowerCase(Locale.ROOT) match {
+ case "clustered" => DataClustered
+ case _ => DataUniform
+ }
+ val indexMode: IndexMode =
+ sys.env.getOrElse("BENCHMARK_INDEX", "flat").toLowerCase(Locale.ROOT) match {
+ case "pq" | "ivf_pq" => IndexPq
+ case _ => IndexFlat
+ }
+
+ println(banner("Phase 2 SQL Benchmark — APPROX NEAREST with rule ON vs OFF"))
+ println(s"Spark: 4.2-SNAPSHOT, master=local[*] Dim: $Dim K: $K Seed: $Seed")
+ println(
+ s"Scales: ${scales.map(_.name).mkString(", ")} Data: ${dataMode.label} " +
+ s"Index: ${indexMode.label}")
+ println()
+
+ val spark = SparkSession
+ .builder()
+ .appName("indexed-nearest-by-join-sql-benchmark")
+ .master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config(
+ "spark.sql.extensions",
+ "org.lance.spark.knn.extensions.LanceKnnSparkSessionExtensions")
+ .config("spark.sql.crossJoin.enabled", "true")
+ .config("spark.sql.shuffle.partitions", "32")
+ .getOrCreate()
+ spark.sparkContext.setLogLevel("WARN")
+
+ val tmpRoot = Files.createTempDirectory("knn-sql-bench-")
+ val results = scala.collection.mutable.ArrayBuffer.empty[Result]
+ try {
+ scales.foreach { scale =>
+ println(banner(s"Scale: $scale Data: ${dataMode.label}"))
+ val rng = new Random(Seed)
+ println(s" Generating ${scale.numLeft} left rows × dim $Dim (${dataMode.label}) ...")
+ val leftVecs = generateVectors(dataMode, scale.numLeft, Dim, Seed)
+ val leftRows = (0 until scale.numLeft).map { i =>
+ RowFactory.create(Integer.valueOf(i), leftVecs(i))
+ }
+ spark.createDataFrame(leftRows.asJava, leftSchema()).createOrReplaceTempView("queries")
+
+ val rightUri = Paths.get(tmpRoot.toString, s"right_${scale.name}").toString
+ println(s" Writing ${scale.numRight} right rows × dim $Dim across ${scale.numFragments} " +
+ s"Spark partitions to $rightUri (${dataMode.label}) ...")
+ val rightVecs = generateVectors(dataMode, scale.numRight, Dim, Seed + 1)
+ val rightRows = (0 until scale.numRight).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000000), rightVecs(i))
+ }
+ spark.createDataFrame(rightRows.asJava, rightSchema())
+ .repartition(scale.numFragments)
+ .write.format("lance").save(rightUri)
+ spark.read.format("lance").load(rightUri).createOrReplaceTempView("docs")
+
+ val sql =
+ s"""SELECT q.lid, d.rid
+ |FROM queries q INNER JOIN docs d
+ |APPROX NEAREST $K BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin
+
+ // Cross-config validation: confirm rule ON returns the same top-K row IDs as rule OFF
+ // (= Spark's RewriteNearestByJoin = brute-force ground truth) on a 16-row left
+ // subset. Run before timing so the measured speedup is on equivalent results, not
+ // on two paths that disagree on output. The check uses a separate `queries_small`
+ // view (16 rows) so the rule-OFF cross-product is 16 × |R| (sub-second), not |L| × |R|.
+ verifyRuleOnMatchesRuleOff(spark, leftRows)
+
+ if (scale.runRuleOff) {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "false")
+ val r =
+ timeIt(scale.name, "A: rule OFF (Spark RewriteNearestByJoin)", () => spark.sql(sql))
+ results += r
+ println(formatResult(r))
+ }
+
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val r =
+ timeIt(scale.name, "B: rule ON (no index, Lance brute-force scan)", () => spark.sql(sql))
+ results += r
+ println(formatResult(r))
+
+ // Build the chosen vector index on the right dataset and time the rule-ON path again.
+ // Lance's nearest-search auto-detects the index and switches to the indexed-scan code
+ // path. Recall < 1.0 in general (approximate index), so we report recall@K rather than
+ // strict-equality validation. numFragments is used as numPartitions so each partition
+ // roughly maps to one Spark fragment.
+ //
+ // IVF_FLAT vs. IVF-PQ:
+ // - IVF_FLAT stores full vectors per cluster — exact distances within visited
+ // clusters, no PQ noise. Higher disk/memory cost but recovers high recall on
+ // high-dim or random workloads. The workaround when PQ noise dominates.
+ // - IVF-PQ compresses each vector into `Dim / 16` 8-bit sub-vectors. Smaller index
+ // and faster scan, but recall is highly sensitive to data distribution. On
+ // uniform-random high-dim data PQ collapses (~3% recall at defaults); on
+ // production-shaped clustered embeddings PQ codebook training latches onto natural
+ // structure and recall recovers. This is what most production deployments use.
+ println(
+ s" Building ${indexMode.label} index (numPartitions=${scale.numFragments}, " +
+ s"dim=$Dim) ...")
+ val tIdx = System.nanoTime()
+ indexMode match {
+ case IndexFlat =>
+ LanceVectorIndexBuilder.buildIvfFlat(
+ datasetUri = rightUri,
+ vectorColumn = "rvec",
+ numPartitions = scale.numFragments)
+ case IndexPq =>
+ // numSubVectors trades index size + scan speed for code precision:
+ // - Dim/16 (8 at Dim=128): coarse PQ, 16 dims per sub-vector. Compact but very
+ // lossy — recall ~5–10% on dim-128 data even with clustered distribution.
+ // This is the default because it's the only setting Lance can train at our
+ // test scales (100K–1M rows). At 1M rows uniform, Lance rejects
+ // `numSubVectors=32` with "needs 4.3B training samples" — production
+ // deployments at much larger N can train fine PQ but we can't here.
+ // - Dim/4 (32 at Dim=128): fine PQ, 4 dims per sub-vector. The production-
+ // realistic setting; needs > a few B rows to train. Override via
+ // BENCHMARK_PQ_SUBVEC if you have a real dataset that supports it.
+ val pqSubVec = sys.env
+ .getOrElse("BENCHMARK_PQ_SUBVEC", math.max(1, Dim / 16).toString)
+ .toInt
+ println(s" (PQ: numSubVectors=$pqSubVec, numBits=8 — ${Dim / pqSubVec} dims/code)")
+ LanceVectorIndexBuilder.buildIvfPq(
+ datasetUri = rightUri,
+ vectorColumn = "rvec",
+ numPartitions = scale.numFragments,
+ numSubVectors = pqSubVec,
+ numBits = 8)
+ }
+ println(s" ... done in ${TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - tIdx)}s")
+
+ // Default-tuning indexed run: nprobes default (1), refineFactor default (no re-rank).
+ // For 100K rows × dim 128 with 4 IVF partitions this gives terrible recall (~3%) — PQ
+ // compression noise dominates and only 1/4 of the dataset is visited. Demonstrates the
+ // raw indexed-scan speedup but is unusable for real workloads.
+ spark.conf.unset(IndexedNearestByJoinRule.NprobesConfKey)
+ spark.conf.unset(IndexedNearestByJoinRule.RefineFactorConfKey)
+ val rIndexed =
+ timeIt(
+ scale.name,
+ s"C: rule ON (${indexMode.label} indexed, defaults)",
+ () => spark.sql(sql))
+ results += rIndexed
+ println(formatResult(rIndexed))
+ val recallC = computeIndexedRecall(spark, leftRows)
+ println(f" -> recall@$K (indexed defaults, sample 16): $recallC%.3f")
+
+ // Tuned indexed run #1: refineFactor = 8 alone. Fetches K * 8 PQ candidates *from the
+ // probed clusters*, re-ranks with exact distance, returns top K. Sidesteps PQ noise but
+ // can't recover true neighbors that live in unprobed clusters.
+ spark.conf.set(IndexedNearestByJoinRule.RefineFactorConfKey, "8")
+ val rTuned =
+ timeIt(
+ scale.name,
+ s"D: rule ON (${indexMode.label} + refineFactor=64)",
+ () => spark.sql(sql))
+ results += rTuned
+ println(formatResult(rTuned))
+ val recallD = computeIndexedRecall(spark, leftRows)
+ println(f" -> recall@$K (refineFactor=64, sample 16): $recallD%.3f")
+ spark.conf.unset(IndexedNearestByJoinRule.RefineFactorConfKey)
+
+ // Tuned indexed run #2: nprobes = numFragments. Visits every IVF cluster, recovering
+ // true neighbors that the default nprobes=1 cuts away. Speedup degrades because we're
+ // back to scanning roughly the whole dataset (just with extra IVF overhead), but
+ // recall should approach 1.0.
+ spark.conf.set(IndexedNearestByJoinRule.NprobesConfKey, scale.numFragments.toString)
+ val rNprobes =
+ timeIt(
+ scale.name,
+ s"E: rule ON (${indexMode.label} + nprobes=${scale.numFragments})",
+ () => spark.sql(sql))
+ results += rNprobes
+ println(formatResult(rNprobes))
+ val recallE = computeIndexedRecall(spark, leftRows)
+ println(f" -> recall@$K (nprobes=${scale.numFragments}, sample 16): $recallE%.3f")
+
+ // Tuned indexed run #3: nprobes = full + refineFactor = 8. Visits every cluster AND
+ // re-ranks with exact distance. The high-recall configuration; the most expensive of
+ // the indexed paths but still typically faster than rule OFF (Spark's brute-force
+ // crossJoin) because Lance's native scan beats Catalyst per-pair overhead even when
+ // the data volume is the same.
+ spark.conf.set(IndexedNearestByJoinRule.RefineFactorConfKey, "64")
+ val rFull = timeIt(
+ scale.name,
+ s"F: rule ON (${indexMode.label} + nprobes=${scale.numFragments} + refineFactor=64)",
+ () => spark.sql(sql))
+ results += rFull
+ println(formatResult(rFull))
+ val recallF = computeIndexedRecall(spark, leftRows)
+ println(
+ f" -> recall@$K (nprobes=${scale.numFragments} + refineFactor=64, sample 16): $recallF%.3f")
+ spark.conf.unset(IndexedNearestByJoinRule.NprobesConfKey)
+ spark.conf.unset(IndexedNearestByJoinRule.RefineFactorConfKey)
+
+ // Drop the temp views so the next scale starts clean.
+ spark.catalog.dropTempView("queries")
+ spark.catalog.dropTempView("docs")
+ println()
+ }
+
+ println(banner("Summary"))
+ printSummaryTable(results.toSeq)
+ } finally {
+ spark.stop()
+ deleteRecursively(tmpRoot.toFile)
+ }
+ }
+
+ // -- timing harness ----------------------------------------------------------------------
+
+ private val WarmupRuns = 1
+ private val MeasurementRuns = 3
+
+ /**
+ * Execute the plan fully and discard output — Spark's canonical benchmark sink. Same
+ * shape as the other two benchmarks. `count()` would skip result-row materialization
+ * unequally (the crossJoin path skips it entirely; the indexed path still runs
+ * `LanceMaterialize` due to the `references = child.outputSet` override), biasing the
+ * speedup comparison. `noop` sink closes that gap.
+ */
+ private def runFull(df: DataFrame): Unit =
+ df.write.format("noop").mode("overwrite").save()
+
+ private def timeIt(scale: String, config: String, f: () => DataFrame): Result = {
+ print(s" $config ... ")
+ System.out.flush()
+ var i = 0
+ while (i < WarmupRuns) { runFull(f()); i += 1 }
+ val runs = (0 until MeasurementRuns).map { _ =>
+ val t0 = System.nanoTime()
+ runFull(f())
+ TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0)
+ }
+ val median = runs.sorted.apply(runs.size / 2)
+ println(s"runs=${runs.mkString("[", ",", "]")} ms, median=$median ms")
+ Result(scale, config, median, runs)
+ }
+
+ // -- recall measurement -------------------------------------------------------------
+
+ /**
+ * Measure recall@K of the indexed rule-ON path on a 16-row left subset, with rule-OFF
+ * (Spark `RewriteNearestByJoin` = brute-force) as the ground truth. Recall is the average
+ * fraction of the brute-force top-K row IDs that the indexed path also returned.
+ */
+ private def computeIndexedRecall(
+ spark: SparkSession,
+ allLeftRows: Seq[org.apache.spark.sql.Row]): Double = {
+ val sample = allLeftRows.take(16)
+ spark.createDataFrame(sample.asJava, leftSchema()).createOrReplaceTempView("queries_recall")
+ val sampleLids = sample.map(_.getInt(0)).toSet
+ val sql =
+ s"""SELECT q.lid, d.rid
+ |FROM queries_recall q INNER JOIN docs d
+ |APPROX NEAREST $K BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin
+
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "false")
+ val truthRows = spark.sql(sql).collect()
+ val truth = truthRows.groupBy(_.getAs[Int]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Int]("rid")).toSet }
+
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val indexedRows = spark.sql(sql).collect()
+ val indexed = indexedRows.groupBy(_.getAs[Int]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Int]("rid")).toSet }
+
+ val perLidRecall = sampleLids.toSeq.map { lid =>
+ val truthSet = truth.getOrElse(lid, Set.empty[Int])
+ val indexedSet = indexed.getOrElse(lid, Set.empty[Int])
+ if (truthSet.isEmpty) 0.0
+ else (truthSet intersect indexedSet).size.toDouble / truthSet.size
+ }
+ spark.catalog.dropTempView("queries_recall")
+ perLidRecall.sum / perLidRecall.size
+ }
+
+ // -- cross-config validation -------------------------------------------------------------
+
+ /**
+ * Run the same SQL twice on a 16-row left subset — once with rule OFF (Spark's
+ * RewriteNearestByJoin = brute-force cross-product + min_by_k = exact ground truth on
+ * no-index Lance), once with rule ON (our 3-exec staged chain). Compare top-K row IDs
+ * per left row. Bail if they disagree.
+ *
+ * Uses a separate `queries_small` view so the rule-OFF cross-product is 16 × |R|
+ * (sub-second), not |L| × |R| which would dominate wall-clock at medium scale.
+ * Compared as Sets to tolerate tied-distance ordering.
+ */
+ private def verifyRuleOnMatchesRuleOff(
+ spark: SparkSession,
+ allLeftRows: Seq[org.apache.spark.sql.Row]): Unit = {
+ println(" Sanity check: rule ON top-K matches rule OFF on a 16-row left subset ...")
+ val sample = allLeftRows.take(16)
+ spark.createDataFrame(sample.asJava, leftSchema()).createOrReplaceTempView("queries_small")
+ // RowFactory.create() makes schema-less Rows, so getAs[String] doesn't work; use positional.
+ val sampleLids = sample.map(_.getInt(0)).toSet
+ val verifySql =
+ s"""SELECT q.lid, d.rid
+ |FROM queries_small q INNER JOIN docs d
+ |APPROX NEAREST $K BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin
+
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "false")
+ val offRows = spark.sql(verifySql).collect()
+ val offMap = offRows.groupBy(_.getAs[Int]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Int]("rid")).toSet }
+
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val onRows = spark.sql(verifySql).collect()
+ val onMap = onRows.groupBy(_.getAs[Int]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Int]("rid")).toSet }
+
+ sampleLids.foreach { lid =>
+ val offSet = offMap.getOrElse(lid, Set.empty[Int])
+ val onSet = onMap.getOrElse(lid, Set.empty[Int])
+ if (offSet != onSet) {
+ sys.error(
+ s"RULE ON/OFF MISMATCH at lid=$lid:\n rule OFF: $offSet\n rule ON: $onSet")
+ }
+ }
+ spark.catalog.dropTempView("queries_small")
+ println(s" ... rule ON and rule OFF agree on top-K (sample size: ${sampleLids.size}).")
+ }
+
+ // -- output formatting ------------------------------------------------------------------
+
+ private def banner(s: String): String = s"\n=== $s " + ("=" * (76 - s.length - 5))
+
+ private def formatResult(r: Result): String =
+ f" -> ${r.config}%-44s median=${r.medianMs}%6d ms"
+
+ private def printSummaryTable(results: Seq[Result]): Unit = {
+ val byScale = results.groupBy(_.scale)
+ val scaleOrder = Seq("small", "medium").filter(byScale.contains)
+ val configWidth = 44
+ val numWidth = 13
+ val divider = "-" * (configWidth + scaleOrder.size * numWidth)
+ val header = s"%-${configWidth}s" + scaleOrder.map(_ => s"%${numWidth}s").mkString
+ println(divider)
+ println(header.format(("Configuration" +: scaleOrder.map(s => s"$s (ms)")): _*))
+ println(header.format(("" +: scaleOrder.map(_ => "speedup ×")): _*))
+ println(divider)
+ val configs = results.map(_.config).distinct
+ val baselineByScale = scaleOrder.flatMap { s =>
+ byScale(s).find(_.config.startsWith("A:")).map(b => s -> b.medianMs)
+ }.toMap
+ configs.foreach { config =>
+ val cellsMs = scaleOrder.map { s =>
+ byScale(s).find(_.config == config).map(_.medianMs.toString).getOrElse("-")
+ }
+ println(header.format((config +: cellsMs): _*))
+ val cellsSpeedup = scaleOrder.map { s =>
+ val mineMs = byScale(s).find(_.config == config).map(_.medianMs).getOrElse(0L)
+ val baseMs = baselineByScale.getOrElse(s, 0L)
+ if (config.startsWith("A:")) "1.00x"
+ else if (mineMs <= 0 || baseMs <= 0) "(no base)"
+ else f"${baseMs.toDouble / mineMs}%.2fx"
+ }
+ println(header.format(("" +: cellsSpeedup): _*))
+ }
+ println(divider)
+ println("Same SQL: APPROX NEAREST K BY DISTANCE vector_l2_distance(q.lvec, d.rvec).")
+ println("Rule OFF lowers to Spark's RewriteNearestByJoin (cross-product + min_by_k).")
+ println("Rule ON routes through the 3-exec staged pipeline (LanceProbe → LanceMerge → LanceMaterialize).")
+ }
+
+ // -- schemas + helpers ------------------------------------------------------------------
+
+ private def leftSchema(): StructType = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ private def rightSchema(): StructType = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ /**
+ * Vector generator dispatch. Uniform mode keeps the existing baseline (independent
+ * `nextFloat`s — IVF's worst case at high dim because pairwise distances cluster around
+ * a narrow range). Clustered mode draws unit-sphere-normalized Gaussian-mixture vectors
+ * around `NumClusters` centers — the geometry of typical sentence-transformer / image-
+ * feature embeddings that IVF was actually designed for.
+ */
+ private def generateVectors(
+ mode: DataMode,
+ n: Int,
+ dim: Int,
+ seed: Long): Array[Array[Float]] = mode match {
+ case DataUniform =>
+ val rng = new Random(seed)
+ Array.fill(n)(randomVector(rng, dim))
+ case DataClustered =>
+ // sigma is in units of inter-cluster spacing. Override via BENCHMARK_SIGMA — tighter
+ // clusters (e.g. 0.05) approximate real semantic embeddings where intra-cluster variance
+ // is small relative to inter-cluster separation. Default 0.15 is the moderate setting.
+ val sigma = sys.env.getOrElse("BENCHMARK_SIGMA", "0.15").toDouble
+ generateClusteredVectors(n, dim, NumClusters, sigma = sigma, seed = seed)
+ }
+
+ /**
+ * Inlined clustered-Gaussian-mixture generator. Lives here (rather than reused from
+ * `lance-spark-knn_2.12`'s test util) because that helper is in test scope of a different
+ * module and isn't visible across module test scopes. Logic mirrors
+ * `org.lance.spark.knn.testutil.ClusteredEmbeddings`:
+ *
+ * 1. Pick `numClusters` centers uniformly on [0, 1]^dim.
+ * 2. For each row, round-robin a cluster, sample N(center, sigma * sep) where `sep`
+ * is the median pairwise distance between centers (so sigma is in units of
+ * inter-cluster spacing — stable across (dim, numClusters) settings).
+ * 3. L2-normalize. Production embeddings live on the unit sphere; that's the geometry
+ * IVF expects.
+ */
+ private def generateClusteredVectors(
+ n: Int,
+ dim: Int,
+ numClusters: Int,
+ sigma: Double,
+ seed: Long): Array[Array[Float]] = {
+ val rng = new Random(seed)
+ val centers = Array.fill(numClusters)(Array.fill(dim)(rng.nextDouble()))
+ val sep = medianPairwiseDistance(centers)
+ val scaledSigma = sigma * sep
+ val out = new Array[Array[Float]](n)
+ var i = 0
+ while (i < n) {
+ val center = centers(i % numClusters)
+ val v = new Array[Float](dim)
+ var d = 0
+ while (d < dim) {
+ v(d) = (center(d) + rng.nextGaussian() * scaledSigma).toFloat
+ d += 1
+ }
+ l2Normalize(v)
+ out(i) = v
+ i += 1
+ }
+ out
+ }
+
+ private def medianPairwiseDistance(centers: Array[Array[Double]]): Double = {
+ val k = centers.length
+ if (k < 2) return 1.0
+ val rng = new Random(0L)
+ val numPairs = math.min(1024, k * (k - 1) / 2)
+ val dists = new Array[Double](numPairs)
+ var p = 0
+ while (p < numPairs) {
+ var i = rng.nextInt(k)
+ var j = rng.nextInt(k)
+ while (j == i) j = rng.nextInt(k)
+ var s = 0.0
+ var d = 0
+ while (d < centers(i).length) {
+ val diff = centers(i)(d) - centers(j)(d)
+ s += diff * diff
+ d += 1
+ }
+ dists(p) = math.sqrt(s)
+ p += 1
+ }
+ java.util.Arrays.sort(dists)
+ dists(dists.length / 2)
+ }
+
+ private def l2Normalize(v: Array[Float]): Unit = {
+ var s = 0.0
+ var i = 0
+ while (i < v.length) { s += v(i) * v(i); i += 1 }
+ val norm = math.sqrt(s).toFloat
+ if (norm > 0f) {
+ i = 0
+ while (i < v.length) { v(i) = v(i) / norm; i += 1 }
+ }
+ }
+
+ private def deleteRecursively(f: java.io.File): Unit = {
+ if (f.isDirectory) f.listFiles().foreach(deleteRecursively)
+ f.delete()
+ }
+}
diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinE2ETest.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinE2ETest.scala
new file mode 100644
index 000000000..c8ad8693c
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinE2ETest.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.catalyst
+
+import org.apache.spark.sql.{RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+import org.lance.spark.knn.internal.staged.LanceProbeLogicalPlan
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * End-to-end SQL test for the Phase 2 Catalyst integration. Drives the full path:
+ *
+ * ANTLR parser ─▶ Analyzer ─▶ IndexedNearestByJoinRule (our postHoc) ─▶
+ * Optimizer ─▶ LanceKnnStagedStrategy ─▶
+ * LanceProbeExec ─▶ ShuffleExchangeExec ─▶ LanceMergeExec ─▶ LanceMaterializeExec ─▶
+ * Lance brute-force per-fragment scan ─▶ Rows
+ *
+ * Requires Spark 4.2-SNAPSHOT (the only release where `NearestByJoin` exists today, added by
+ * SPARK-56395) AND the lance-spark-4.1_2.13 connector built against the same Spark version. To
+ * set up:
+ *
+ * {{{
+ * cd /path/to/spark/master
+ * ./build/mvn install -DskipTests -DskipChecks -pl sql/core -am
+ * cd /path/to/lance-spark
+ * ./mvnw install -pl lance-spark-4.1_2.13 -am -DskipTests \
+ * -Dspark41.version=4.2.0-SNAPSHOT -Darrow183.version=19.0.0
+ * ./mvnw install -pl lance-spark-knn_2.13 -am -DskipTests
+ * ./mvnw -pl lance-spark-knn-4.2_2.13 test
+ * }}}
+ *
+ * Coverage:
+ * - SQL `APPROX NEAREST k BY DISTANCE vector_l2_distance(...)` parses, the rule rewrites
+ * to the 3-logical-plan tree (probe/merge/materialize), the strategy lowers it to the
+ * matching exec chain, which executes against a real Lance dataset; results match the
+ * brute-force oracle.
+ * - With the gating config disabled, the same SQL falls through to Spark's
+ * `RewriteNearestByJoin` (cross-product + `MaxMinByK`) — proves the rule's opt-in
+ * behavior at the SQL level.
+ */
+class IndexedNearestByJoinE2ETest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 16
+ private val NumRight = 64
+ private val NumLeft = 8
+ private val Seed = 4242L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-by-join-e2e")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config(
+ "spark.sql.extensions",
+ "org.lance.spark.knn.extensions.LanceKnnSparkSessionExtensions")
+ .config("spark.sql.crossJoin.enabled", "true")
+ .getOrCreate()
+ spark.sparkContext.setLogLevel("WARN")
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * Full SQL path with the rule enabled. The physical plan must contain
+ * [[LanceProbeExec]] AND the result must match the brute-force oracle on every left row.
+ * With no vector index built, Lance does an exact per-fragment scan, so any disagreement
+ * with brute force is a bug.
+ */
+ @Test def testSqlApproxNearestRoutesThroughIndexedPathAndMatchesOracle(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+
+ val (leftRows, leftVectors, leftIds) = generateLeft(NumLeft, Dim, Seed)
+ val (rightRows, rightVectors, rightIds) = generateRight(NumRight, Dim, Seed + 1)
+ val rightUri = writeRightLance(rightRows)
+
+ spark.createDataFrame(leftRows.asJava, leftSchema()).createOrReplaceTempView("queries")
+ spark.read.format("lance").load(rightUri).createOrReplaceTempView("docs")
+
+ val k = 5
+ val sql =
+ s"""SELECT q.lid, d.rid
+ |FROM queries q INNER JOIN docs d
+ |APPROX NEAREST $k BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin
+ val df = spark.sql(sql)
+
+ // Plan-shape: confirm the rule fired AND the strategy lowered all three execs with a
+ // Catalyst-inserted Exchange between probe and merge. Use `optimizedPlan` for the
+ // logical assertion (AQE-independent) and the physical `treeString` for the exec
+ // assertion. Checking all three nodes + the hashpartitioning exchange is the SQL-side
+ // analogue of `IndexedNearestJoinAqeVisibilityTest.testAllThreeCustomExecsInTree` on
+ // the DataFrame path — the strategy is shared, but the rule wiring is SQL-specific.
+ val optimized = df.queryExecution.optimizedPlan
+ val probeLogicals = optimized.collect { case p: LanceProbeLogicalPlan => p }
+ assertTrue(
+ probeLogicals.nonEmpty,
+ s"expected LanceProbeLogicalPlan in optimized plan; got:\n$optimized")
+ val executed = df.queryExecution.executedPlan
+ val tree = executed.treeString
+ assertTrue(tree.contains("LanceProbe"), s"expected LanceProbe exec in tree:\n$tree")
+ assertTrue(tree.contains("LanceMerge"), s"expected LanceMerge exec in tree:\n$tree")
+ assertTrue(
+ tree.contains("LanceMaterialize"),
+ s"expected LanceMaterialize exec in tree:\n$tree")
+ assertTrue(
+ tree.contains("hashpartitioning(_leftId"),
+ s"expected hashpartitioning(_leftId) Exchange in tree:\n$tree")
+
+ // Correctness: oracle equivalence.
+ val rows = df.collect()
+ assertEquals(NumLeft * k, rows.length, "expected k results per left row")
+ val byLid = rows.groupBy(_.getAs[Int]("lid"))
+ leftIds.zip(leftVectors).foreach { case (lid, lvec) =>
+ val oracle = rightVectors.indices
+ .map(i => (rightIds(i), l2(lvec, rightVectors(i))))
+ .sortBy(_._2)
+ .take(k)
+ .map(_._1)
+ .toSet
+ val actual = byLid(lid).map(_.getAs[Int]("rid")).toSet
+ assertEquals(
+ oracle,
+ actual,
+ s"top-K mismatch for lid=$lid (rule on, brute-force oracle)")
+ }
+ }
+
+ /**
+ * Right-side `WHERE` clause must round-trip through the prefilter pushdown — Lance computes
+ * top-K only over rows matching the filter, so the result must equal the brute-force oracle
+ * computed AFTER applying the same filter. If the rule pushed the filter wrong (or dropped
+ * it), this test would diverge from the oracle.
+ *
+ * Two right-side rows in this test share each `category` value, so a `WHERE category = 'A'`
+ * shrinks the candidate pool meaningfully without zeroing it out.
+ */
+ @Test def testSqlWherePushdownMatchesFilteredOracle(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+
+ val (leftRows, leftVectors, leftIds) = generateLeft(NumLeft, Dim, Seed + 200)
+ val (rightRows, rightVectors, rightIds, rightCategories) =
+ generateRightWithCategories(NumRight, Dim, Seed + 201)
+ val rightUri = writeRightLanceWithCategories(rightRows)
+
+ spark.createDataFrame(leftRows.asJava, leftSchema()).createOrReplaceTempView("queries")
+ spark.read.format("lance").load(rightUri).createOrReplaceTempView("docs")
+
+ val k = 4
+ val targetCat = "A"
+ val sql =
+ s"""SELECT q.lid, d.rid
+ |FROM queries q INNER JOIN (SELECT * FROM docs WHERE category = '$targetCat') d
+ |APPROX NEAREST $k BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin
+ val df = spark.sql(sql)
+
+ val optimized = df.queryExecution.optimizedPlan
+ val probeLogicals = optimized.collect { case p: LanceProbeLogicalPlan => p }
+ assertTrue(
+ probeLogicals.nonEmpty,
+ s"expected LanceProbeLogicalPlan; optimized plan was:\n$optimized")
+ val prefilter = probeLogicals.head.stageConf.prefilter
+ assertTrue(
+ prefilter.exists(_.contains(s"'$targetCat'")),
+ s"expected prefilter to carry category='$targetCat'; got: $prefilter")
+
+ // Oracle: brute-force top-K computed AFTER applying the same filter on the right side.
+ val filteredIdxs = rightCategories.indices.filter(rightCategories(_) == targetCat)
+ val rows = df.collect()
+ val byLid = rows.groupBy(_.getAs[Int]("lid"))
+ leftIds.zip(leftVectors).foreach { case (lid, lvec) =>
+ val oracle = filteredIdxs
+ .map(i => (rightIds(i), l2(lvec, rightVectors(i))))
+ .sortBy(_._2)
+ .take(k)
+ .map(_._1)
+ .toSet
+ assertTrue(
+ oracle.nonEmpty,
+ s"oracle is empty for lid=$lid — test setup didn't produce filterable rows")
+ val actual = byLid(lid).map(_.getAs[Int]("rid")).toSet
+ assertEquals(
+ oracle,
+ actual,
+ s"top-K mismatch under WHERE pushdown for lid=$lid (filtered brute-force oracle)")
+ }
+ }
+
+ /**
+ * With the gating config disabled, the SAME SQL falls through to Spark's
+ * `RewriteNearestByJoin` (cross-product + `MaxMinByK`). The plan contains NO
+ * `LanceProbeLogicalPlan` and (importantly) results still match the oracle. This proves
+ * the rule's opt-in behavior at the SQL level: turning it off doesn't break correctness.
+ */
+ @Test def testSqlFallsThroughToBruteForceWhenRuleDisabled(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "false")
+
+ val (leftRows, leftVectors, leftIds) = generateLeft(NumLeft, Dim, Seed + 100)
+ val (rightRows, rightVectors, rightIds) = generateRight(NumRight, Dim, Seed + 101)
+ val rightUri = writeRightLance(rightRows)
+
+ spark.createDataFrame(leftRows.asJava, leftSchema()).createOrReplaceTempView("queries")
+ spark.read.format("lance").load(rightUri).createOrReplaceTempView("docs")
+
+ val k = 4
+ val df = spark.sql(
+ s"""SELECT q.lid, d.rid
+ |FROM queries q INNER JOIN docs d
+ |APPROX NEAREST $k BY DISTANCE vector_l2_distance(q.lvec, d.rvec)""".stripMargin)
+
+ val optimized = df.queryExecution.optimizedPlan
+ val probeLogicals = optimized.collect { case p: LanceProbeLogicalPlan => p }
+ assertTrue(
+ probeLogicals.isEmpty,
+ s"rule disabled — expected NO LanceProbeLogicalPlan; got:\n$optimized")
+
+ val rows = df.collect()
+ assertEquals(NumLeft * k, rows.length)
+ val byLid = rows.groupBy(_.getAs[Int]("lid"))
+ leftIds.zip(leftVectors).foreach { case (lid, lvec) =>
+ val oracle = rightVectors.indices
+ .map(i => (rightIds(i), l2(lvec, rightVectors(i))))
+ .sortBy(_._2)
+ .take(k)
+ .map(_._1)
+ .toSet
+ val actual = byLid(lid).map(_.getAs[Int]("rid")).toSet
+ assertEquals(oracle, actual, s"top-K mismatch for lid=$lid (rule off, brute-force fallback)")
+ }
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def leftSchema(): StructType = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ private def rightSchema(): StructType = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ private def generateLeft(
+ n: Int,
+ dim: Int,
+ seed: Long): (Seq[org.apache.spark.sql.Row], Array[Array[Float]], Array[Int]) = {
+ val rng = new Random(seed)
+ val vectors = (0 until n).map(_ => randomVector(rng, dim)).toArray
+ val ids = (0 until n).toArray
+ val rows = ids.zip(vectors).map { case (id, v) => RowFactory.create(Integer.valueOf(id), v) }
+ (rows.toSeq, vectors, ids)
+ }
+
+ private def generateRight(
+ n: Int,
+ dim: Int,
+ seed: Long): (Seq[org.apache.spark.sql.Row], Array[Array[Float]], Array[Int]) = {
+ val rng = new Random(seed)
+ val vectors = (0 until n).map(_ => randomVector(rng, dim)).toArray
+ val ids = (0 until n).map(_ + 1000).toArray
+ val rows = ids.zip(vectors).map { case (id, v) => RowFactory.create(Integer.valueOf(id), v) }
+ (rows.toSeq, vectors, ids)
+ }
+
+ private def writeRightLance(rows: Seq[org.apache.spark.sql.Row]): String = {
+ val df = spark.createDataFrame(rows.asJava, rightSchema())
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def rightSchemaWithCategories(): StructType = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField("category", StringType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ /**
+ * Like `generateRight`, but every row also carries a category drawn from a small alphabet so
+ * the e2e WHERE-pushdown test has a non-trivial filter to apply.
+ */
+ private def generateRightWithCategories(n: Int, dim: Int, seed: Long): (
+ Seq[org.apache.spark.sql.Row],
+ Array[Array[Float]],
+ Array[Int],
+ Array[String]) = {
+ val rng = new Random(seed)
+ val vectors = (0 until n).map(_ => randomVector(rng, dim)).toArray
+ val ids = (0 until n).map(_ + 2000).toArray
+ val alphabet = Array("A", "B", "C", "D")
+ val categories = (0 until n).map(i => alphabet(i % alphabet.length)).toArray
+ val rows = ids.zip(vectors).zip(categories).map { case ((id, v), cat) =>
+ RowFactory.create(Integer.valueOf(id), cat, v)
+ }
+ (rows.toSeq, vectors, ids, categories)
+ }
+
+ private def writeRightLanceWithCategories(rows: Seq[org.apache.spark.sql.Row]): String = {
+ val df = spark.createDataFrame(rows.asJava, rightSchemaWithCategories())
+ val out = tempDir.resolve(s"right_cat_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ }
+}
diff --git a/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala
new file mode 100644
index 000000000..a9bf8c509
--- /dev/null
+++ b/lance-spark-knn-4.2_2.13/src/test/scala/org/lance/spark/knn/catalyst/IndexedNearestByJoinRuleTest.scala
@@ -0,0 +1,468 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.catalyst
+
+import org.apache.spark.sql.{RowFactory, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeSet, EqualTo, Expression, GreaterThan, In, IsNotNull, IsNull, LessThanOrEqual, Literal, Not, Or, VectorCosineSimilarity, VectorInnerProduct, VectorL2Distance}
+import org.apache.spark.sql.catalyst.plans.{Inner, NearestByDirection, NearestByDistance, NearestBySimilarity}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, NearestByJoin, Project, SubqueryAlias}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+import org.lance.spark.knn.internal.Metric
+import org.lance.spark.knn.internal.staged.{LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan}
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Unit tests for [[IndexedNearestByJoinRule]]. The rule's responsibility is purely Catalyst-side
+ * pattern-matching — we don't need a Lance backend to exercise it. Each test constructs a small
+ * resolved plan and runs the rule, asserting either a rewrite to the 3-exec staged logical-plan
+ * tree (`Project(... LanceMaterializeLogicalPlan(LanceMergeLogicalPlan(LanceProbeLogicalPlan)))`)
+ * or a no-op fallthrough.
+ *
+ * Coverage:
+ * - Happy path: VectorL2Distance + NearestByDistance over a Lance DSv2 relation rewrites.
+ * - Direction mismatch (e.g. L2 distance with NearestBySimilarity) does NOT rewrite.
+ * - EXACT (`approx = false`) does NOT rewrite — Spark's brute-force keeps owning that path.
+ * - Non-Lance right side does NOT rewrite (duck-type check via class name).
+ * - Disabled by default — fires only when the gating config is set.
+ *
+ * The rule's runtime behavior beyond the rewrite (probe execution against real Lance) is covered
+ * by the Phase 0/1 oracle tests in lance-spark-knn_2.12.
+ */
+class IndexedNearestByJoinRuleTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-by-join-rule-test")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /** L2 + NearestByDistance + Lance scan + enabled config → rewrite. */
+ @Test def testL2RewritesToIndexedPlan(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val join = NearestByJoin(
+ left = left,
+ right = right,
+ joinType = Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ val plan = expectRewritten(rewritten)
+ assertEquals(Metric.L2, plan.metric)
+ assertEquals(5, plan.k)
+ assertEquals(rightVec.name, plan.rightVecCol)
+ assertEquals(leftVec.exprId, plan.leftVecAttr.exprId)
+ }
+
+ /** Cosine similarity + NearestBySimilarity → rewrite. */
+ @Test def testCosineRewrites(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "cosine")
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = true,
+ numResults = 3,
+ rankingExpression = VectorCosineSimilarity(leftVec, rightVec),
+ direction = NearestBySimilarity)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertEquals(Metric.Cosine, expectRewritten(rewritten).metric)
+ }
+
+ /** Inner product + NearestBySimilarity → rewrite as Dot. */
+ @Test def testDotRewrites(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "dot")
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = true,
+ numResults = 4,
+ rankingExpression = VectorInnerProduct(leftVec, rightVec),
+ direction = NearestBySimilarity)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertEquals(Metric.Dot, expectRewritten(rewritten).metric)
+ }
+
+ /** L2 distance with NearestBySimilarity is inconsistent — rule should NOT fire. */
+ @Test def testDirectionMismatchDoesNotRewrite(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestBySimilarity)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(join, rewritten, "rule should not fire on direction/metric mismatch")
+ }
+
+ /** EXACT mode (approx = false) is owned by Spark's brute-force rewrite. */
+ @Test def testExactModeDoesNotRewrite(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = false,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(join, rewritten, "EXACT queries must not be intercepted")
+ }
+
+ /** Disabled flag (default) → no rewrite even when otherwise applicable. */
+ @Test def testDisabledByDefault(): Unit = {
+ spark.conf.unset(IndexedNearestByJoinRule.EnabledConfKey)
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(join, rewritten, "rule must be opt-in")
+ }
+
+ /** Non-Lance right side (regular DataFrame as Project, no DSv2 relation) → no rewrite. */
+ @Test def testNonLanceRightDoesNotRewrite(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val left = trivialPlan("lid", "lvec")
+ val right = trivialPlan("rid", "rvec")
+ val leftVec = left.output.find(_.name == "lvec").get
+ val rightVec = right.output.find(_.name == "rvec").get
+ val join = NearestByJoin(
+ left,
+ right,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(join, rewritten, "non-Lance right must fall through")
+ }
+
+ /** Right side wrapped in SubqueryAlias still rewrites — alias unwrapping happens in the rule. */
+ @Test def testSubqueryAliasOnRightStillRewrites(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val aliased = SubqueryAlias("d", right)
+ val join = NearestByJoin(
+ left,
+ aliased,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ // Rule emits `Project(j.output, LanceMaterializeLogicalPlan(LanceMergeLogicalPlan(
+ // LanceProbeLogicalPlan(left, ...), ...), ...))`. Asserting on the top Project is
+ // enough for the "did the rule fire" check.
+ assertTrue(
+ rewritten.isInstanceOf[Project] &&
+ rewritten.asInstanceOf[Project].child.isInstanceOf[LanceMaterializeLogicalPlan],
+ s"expected Project(..., LanceMaterializeLogicalPlan(...)), got: " +
+ s"${rewritten.getClass.getSimpleName}")
+ }
+
+ // -- prefilter pushdown -------------------------------------------------------------------
+
+ /**
+ * Right side wrapped in `Filter(simple predicate)` rewrites AND the predicate lands on the
+ * indexed plan as a Lance SQL filter string. The filter must be pushed in full (not dropped)
+ * for the result to be semantically equivalent to the original plan.
+ */
+ @Test def testFilterOverLancePushesAsPrefilter(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val category = right.output.find(_.name == "category").get
+ val bucket = right.output.find(_.name == "bucket").get
+ val cond = And(
+ EqualTo(category, Literal(UTF8String.fromString("A"), StringType)),
+ GreaterThan(bucket, Literal(5, IntegerType)))
+ val filtered = Filter(cond, right)
+ val join = NearestByJoin(
+ left,
+ filtered,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ val plan = expectRewritten(rewritten)
+ assertTrue(plan.prefilter.isDefined, "prefilter should be populated")
+ val sql = plan.prefilter.get
+ assertTrue(sql.contains("category"), s"prefilter missing column ref: $sql")
+ assertTrue(sql.contains("'A'"), s"prefilter missing string literal: $sql")
+ assertTrue(sql.contains("bucket"), s"prefilter missing column ref: $sql")
+ assertTrue(sql.contains("> 5"), s"prefilter missing numeric comparison: $sql")
+ assertTrue(sql.contains("AND"), s"prefilter missing conjunction: $sql")
+ }
+
+ /**
+ * Predicate touches a left-side attribute — translator can't safely render that as a Lance
+ * SQL string (Lance only sees the right table's columns). Rule must REFUSE the rewrite, not
+ * drop the predicate. We verify the original `NearestByJoin` is returned unchanged.
+ */
+ @Test def testPredicateReferencingLeftAttrRefusesRewrite(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val lid = left.output.find(_.name == "lid").get
+ val cond = EqualTo(lid, Literal(0, IntegerType))
+ val filtered = Filter(cond, right)
+ val join = NearestByJoin(
+ left,
+ filtered,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(
+ join,
+ rewritten,
+ "predicate touching left side must refuse pushdown — not partial-push")
+ }
+
+ /**
+ * Predicate is a computed expression (e.g. `bucket + 1 = 6`), not a bare attr-vs-literal
+ * comparison. Translator returns None, rule refuses.
+ */
+ @Test def testComputedPredicateRefusesRewrite(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val bucket = right.output.find(_.name == "bucket").get
+ val cond = EqualTo(Add(bucket, Literal(1, IntegerType)), Literal(6, IntegerType))
+ val filtered = Filter(cond, right)
+ val join = NearestByJoin(
+ left,
+ filtered,
+ Inner,
+ approx = true,
+ numResults = 5,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ assertSame(join, rewritten, "computed expression must refuse pushdown")
+ }
+
+ /** Filter wrapped in SubqueryAlias still pushes — order of unwrap shouldn't matter. */
+ @Test def testFilterUnderSubqueryAliasPushes(): Unit = {
+ spark.conf.set(IndexedNearestByJoinRule.EnabledConfKey, "true")
+ val (left, leftVec, right, rightVec) = buildPlans(metricFunction = "l2")
+ val category = right.output.find(_.name == "category").get
+ val cond = EqualTo(category, Literal(UTF8String.fromString("X"), StringType))
+ val plan = SubqueryAlias("d", Filter(cond, right))
+ val join = NearestByJoin(
+ left,
+ plan,
+ Inner,
+ approx = true,
+ numResults = 3,
+ rankingExpression = VectorL2Distance(leftVec, rightVec),
+ direction = NearestByDistance)
+ val rewritten = IndexedNearestByJoinRule(join)
+ val p = expectRewritten(rewritten)
+ assertTrue(p.prefilter.isDefined, s"prefilter should be set; got ${p.prefilter}")
+ }
+
+ // -- predicate translator unit tests -----------------------------------------------------
+
+ /**
+ * Direct unit tests on `translateFilter` to lock in the supported shapes. Uses a synthetic
+ * AttributeSet so we don't need a logical plan.
+ */
+ @Test def testTranslatorHandlesSupportedShapes(): Unit = {
+ val rid = makeAttr("rid", IntegerType)
+ val category = makeAttr("category", StringType)
+ val bucket = makeAttr("bucket", IntegerType)
+ val attrs = AttributeSet(Seq(rid, category, bucket))
+
+ val cases: Seq[(Expression, String)] = Seq(
+ EqualTo(category, lit("A")) -> "category = 'A'",
+ Not(EqualTo(category, lit("A"))) -> "category != 'A'",
+ GreaterThan(bucket, lit(5)) -> "bucket > 5",
+ LessThanOrEqual(bucket, lit(5)) -> "bucket <= 5",
+ IsNull(category) -> "category IS NULL",
+ IsNotNull(category) -> "category IS NOT NULL",
+ In(bucket, Seq(lit(1), lit(2), lit(3))) -> "bucket IN (1, 2, 3)",
+ And(EqualTo(category, lit("A")), GreaterThan(bucket, lit(5))) ->
+ "(category = 'A') AND (bucket > 5)",
+ Or(EqualTo(category, lit("A")), EqualTo(category, lit("B"))) ->
+ "(category = 'A') OR (category = 'B')",
+ // String-literal escape — single quotes inside the value get doubled.
+ EqualTo(category, lit("O'Brien")) -> "category = 'O''Brien'",
+ // literal-on-left flip
+ EqualTo(lit(5), bucket) -> "5 = bucket")
+ cases.foreach { case (expr, expected) =>
+ val got = IndexedNearestByJoinRule.translateFilter(expr, attrs)
+ assertEquals(Some(expected), got, s"translation mismatch for: $expr")
+ }
+ }
+
+ /** Translator must return None for unsupported expressions so the rule refuses pushdown. */
+ @Test def testTranslatorRefusesUnsupportedShapes(): Unit = {
+ val rid = makeAttr("rid", IntegerType)
+ val ts = makeAttr("ts", DateType) // date literals not in our supported set
+ val attrs = AttributeSet(Seq(rid, ts))
+
+ val rejected: Seq[Expression] = Seq(
+ // Two attributes — no literal — translator can't render `attr op attr` safely (Lance can,
+ // but we don't promise it; refuse to keep the rule conservative).
+ EqualTo(rid, makeAttr("rid2", IntegerType)),
+ // Foreign attribute (not in `attrs`) — translator must reject.
+ EqualTo(makeAttr("foreign", IntegerType), lit(1)),
+ // Empty IN list.
+ In(rid, Seq.empty),
+ // Date literal — out of supported types.
+ EqualTo(ts, Literal(0, DateType)))
+ rejected.foreach { e =>
+ assertEquals(
+ None,
+ IndexedNearestByJoinRule.translateFilter(e, attrs),
+ s"expected refusal for: $e")
+ }
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ /**
+ * Construct a left-side regular plan and a right-side that resembles a Lance DSv2 scan via the
+ * duck-type check. Avoids the need for a real Lance reader.
+ */
+ private def buildPlans(metricFunction: String)
+ : (LogicalPlan, Attribute, LogicalPlan, Attribute) = {
+ val left = trivialPlan("lid", "lvec")
+ val rightLance = lanceLikeDsv2Relation()
+ val leftVec = left.output.find(_.name == "lvec").get
+ val rightVec = rightLance.output.find(_.name == "rvec").get
+ (left, leftVec, rightLance, rightVec)
+ }
+
+ private def trivialPlan(idCol: String, vecCol: String): LogicalPlan = {
+ val schema = new StructType(Array(
+ StructField(idCol, IntegerType, nullable = false),
+ StructField(vecCol, ArrayType(FloatType, containsNull = false), nullable = false)))
+ val rows = (0 until 4).map(i => RowFactory.create(Integer.valueOf(i), Array.fill(8)(0.0f)))
+ spark.createDataFrame(rows.asJava, schema).queryExecution.analyzed
+ }
+
+ /**
+ * Build a `DataSourceV2Relation` whose `table.getClass.getName.contains("Lance")` so the
+ * rule's duck-type check accepts it. We don't actually run any I/O. Includes a `category`
+ * (string) and `bucket` (int) column so prefilter-pushdown tests can build realistic
+ * filter predicates without needing to extend the schema separately.
+ */
+ private def lanceLikeDsv2Relation(): LogicalPlan = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField("category", StringType, nullable = true),
+ StructField("bucket", IntegerType, nullable = true),
+ StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false)))
+ val table = new FakeLanceTable(schema)
+ val opts = new java.util.HashMap[String, String]()
+ opts.put("path", tempDir.resolve("fake_lance").toString)
+ val cims = new org.apache.spark.sql.util.CaseInsensitiveStringMap(opts)
+ DataSourceV2Relation.create(table, None, None, cims)
+ }
+
+ /**
+ * Extract an assertion-friendly summary of the rule's rewrite output. The rule produces
+ * `Project(j.output, LanceMaterializeLogicalPlan(LanceMergeLogicalPlan(
+ * LanceProbeLogicalPlan(left, probeConf), mergeConf), materializeConf))`; this helper
+ * walks down and pulls out the fields the test cases want to check.
+ */
+ private case class RewriteSummary(
+ metric: Metric,
+ k: Int,
+ rightVecCol: String,
+ leftVecAttr: Attribute,
+ prefilter: Option[String])
+
+ private def expectRewritten(plan: LogicalPlan): RewriteSummary = plan match {
+ case Project(
+ _,
+ LanceMaterializeLogicalPlan(
+ LanceMergeLogicalPlan(
+ probe: LanceProbeLogicalPlan,
+ mergeConf,
+ _,
+ _),
+ _,
+ _,
+ _,
+ _)) =>
+ val probeConf = probe.stageConf
+ val leftVec = probe.child.output(probeConf.leftVecIdx)
+ RewriteSummary(
+ metric = probeConf.metric,
+ k = mergeConf.finalK,
+ rightVecCol = probeConf.vectorColumn,
+ leftVecAttr = leftVec,
+ prefilter = probeConf.prefilter)
+ case other =>
+ fail(s"expected Project(LanceMaterialize(LanceMerge(LanceProbe))), got: $other"); ???
+ }
+
+ private def makeAttr(name: String, dt: DataType): Attribute =
+ org.apache.spark.sql.catalyst.expressions.AttributeReference(name, dt, nullable = true)()
+
+ private def lit(v: Int): Literal = Literal(v, IntegerType)
+ private def lit(s: String): Literal = Literal(UTF8String.fromString(s), StringType)
+}
+
+/**
+ * Stub Table whose class name ends with "Lance" so the rule's duck-type check accepts it. No I/O
+ * — the rule only reads schema and options. Lives in the test source tree.
+ */
+class FakeLanceTable(_schema: StructType) extends org.apache.spark.sql.connector.catalog.Table {
+ override def name(): String = "fake_lance"
+ override def schema(): StructType = _schema
+ override def capabilities()
+ : java.util.Set[org.apache.spark.sql.connector.catalog.TableCapability] =
+ java.util.Collections.emptySet()
+}
diff --git a/lance-spark-knn_2.12/BENCHMARK_RESULTS.md b/lance-spark-knn_2.12/BENCHMARK_RESULTS.md
new file mode 100644
index 000000000..051febcbc
--- /dev/null
+++ b/lance-spark-knn_2.12/BENCHMARK_RESULTS.md
@@ -0,0 +1,768 @@
+# Benchmark results — local M5 Max
+
+Two benchmarks, two complementary headline numbers — **both validated** against an
+in-memory brute-force oracle on a 16-row left subset before timing:
+
+| Benchmark | What it compares | Headline (small scale, validated) |
+|---|---|---|
+| **DataFrame** (`IndexedNearestJoinBenchmark`, this dir) | Indexed staged pipeline vs. naive Spark `crossJoin + UDF + window` | **608×** |
+| **SQL** (`lance-spark-knn-4.2_2.13/.../IndexedNearestByJoinSqlBenchmark`) | Same `APPROX NEAREST` SQL with the Phase 2 rule ON vs. OFF (= Spark's `RewriteNearestByJoin` cross-product + `min_by_k`) | **17.4×** |
+
+The 608× is what users on Spark 3.5/4.0/4.1 (no `NearestByJoin` SQL yet) would observe vs. the natural workaround they write today. The 17.4× is the apples-to-apples SQL-level number on Spark 4.2+ where users can write `APPROX NEAREST` and Spark's optimized `min_by_k` heap aggregate handles the cross-product more efficiently than a naive crossJoin + window. Both wins are real; the audience determines which one to quote.
+
+## Validation methodology
+
+Both benchmarks run a **pre-timing oracle equivalence check**:
+
+1. Sample 16 rows from the left side.
+2. Compute the brute-force top-K row IDs for each sample using a plain-Scala loop (the ground truth).
+3. Run **every** config — including the slow Spark crossJoin baseline / rule-OFF path — on the same 16-row subset and collect its top-K row IDs per left row.
+4. Compare each config's result to the oracle. `sys.error` if any disagrees.
+
+Latest validation passes:
+
+- **DataFrame benchmark** (small scale, 5 configs): `all 5 configs match the oracle (sample size: 16)` — A: Spark crossJoin baseline, B: Phase 0/1, C: Phase 1.5 G=4, D: Phase 1.5 G=8, E: Phase 1.5 G=8 skew-balanced — all return identical top-K row IDs.
+- **SQL benchmark** (small scale, 2 configs): `rule ON and rule OFF agree on top-K (sample size: 16)` — Spark's `RewriteNearestByJoin` (cross-product + `min_by_k`) and our 3-exec staged chain (shared with the DataFrame path) return identical top-K row IDs.
+
+The 16-row subset keeps validation under a few seconds even though the slow baseline / rule-OFF path runs in it: `O(16 × |R|)` = 1.6M-16M cross-product evaluations, sub-second wall-clock. The full timed runs use the full left side; the speedup is on results that have been proven equivalent to the baseline on the same dataset.
+
+Hardware: Apple M5 Max, 18 cores (12 P + 6 E), 48 GB RAM. Spark `local[*]`.
+
+Run via:
+
+```sh
+cd /path/to/lance-spark
+./mvnw -pl lance-spark-knn_2.12 install -DskipTests
+MAVEN_OPTS="-Xmx12g \
+ --add-opens=java.base/java.lang=ALL-UNNAMED \
+ --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \
+ --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \
+ --add-opens=java.base/java.io=ALL-UNNAMED \
+ --add-opens=java.base/java.net=ALL-UNNAMED \
+ --add-opens=java.base/java.nio=ALL-UNNAMED \
+ --add-opens=java.base/java.util=ALL-UNNAMED \
+ --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \
+ --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \
+ --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \
+ --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \
+ --add-opens=java.base/sun.security.action=ALL-UNNAMED \
+ --add-opens=java.base/sun.util.calendar=ALL-UNNAMED" \
+./mvnw -pl lance-spark-knn_2.12 -q exec:java \
+ -Dexec.classpathScope=test \
+ -Dexec.mainClass=org.lance.spark.knn.benchmark.IndexedNearestJoinBenchmark
+```
+
+Median of 3 runs after 1 warmup. Dim = 128, K = 10, L2 distance.
+
+## DataFrame benchmark
+
+vs. naive Spark `crossJoin + array_distance UDF + row_number window` (`IndexedNearestJoinBenchmark`).
+
+| Config | Small (\|R\|=100K, \|L\|=100) | Medium (\|R\|=1M, \|L\|=1000) | Speedup vs. baseline (small) |
+|---|---:|---:|---:|
+| A: Spark crossJoin baseline | 109,373 ms | — | 1.00× |
+| B: Phase 0/1 (probeParallelism=1) | 180 ms | 11,351 ms | **608×** |
+| C: Phase 1.5 (probeParallelism=4) | 288 ms | 11,830 ms | 380× |
+| D: Phase 1.5 (probeParallelism=8) | 276 ms | 12,660 ms | 396× |
+| E: Phase 1.5 (G=8, skew-balanced) | 277 ms | 12,301 ms | 395× |
+
+**The win vs. naive Spark:** indexed staged pipeline is **608× faster** than `crossJoin +
+UDF + window` at 100K × 100. The medium-scale baseline isn't included — 1M × 1000 = 1B-pair
+crossJoin is measured in tens of minutes on this hardware, the small-scale number is
+already conclusive.
+
+## SQL benchmark — Phase 2 rule ON vs OFF
+
+Same `APPROX NEAREST K BY DISTANCE vector_l2_distance(q.lvec, d.rvec)` SQL run with the
+Phase 2 rule's gating config flipped (`IndexedNearestByJoinSqlBenchmark`, in
+`lance-spark-knn-4.2_2.13`). Spark 4.2-SNAPSHOT runtime + `lance-spark-4.1` connector
+recompiled against it.
+
+| Config | Small (\|R\|=100K, \|L\|=100) | Medium (\|R\|=1M, \|L\|=1000) | Speedup vs. rule-OFF (small) |
+|---|---:|---:|---:|
+| A: rule OFF (Spark `RewriteNearestByJoin`) | 3,739 ms | — (skipped) | 1.00× |
+| B: rule ON (3-exec staged + shared strategy) | 217 ms | 11,025 ms | **17.23×** |
+
+Re-measured after switching the identity column from `_rowaddr` to `_rowid` (the universal
+Lance row identifier — `_rowaddr` only materializes on non-indexed scan paths). Numbers are
+within noise of the prior un-rowid run (was 17.4×; now 17.23×) — JVM/Lance state variance,
+not a correctness change. Validation passes either way.
+
+**The win vs. Spark 4.2's built-in:** **17×** at 100K × 100. Smaller than the 608× headline
+because Spark's `RewriteNearestByJoin` rule is itself optimized — it lowers to a
+`min_by_k` heap aggregate over a `BroadcastNestedLoopJoin`, which avoids materializing all
+|L|×|R| pairs in memory. Medium baseline is skipped because 1B `min_by_k` evaluations
+is still impractical on this hardware.
+
+### Why Lance brute-force (no index) is 17× faster than Spark `RewriteNearestByJoin`
+
+Both paths do the same 10M pair evaluations on the small-scale benchmark. The 17× gap is
+constant-factor JVM/native overhead, in roughly this order of impact:
+
+1. **Native SIMD vs JVM expression evaluation.** Lance's distance kernel is hand-tuned Rust
+ with AVX-512 / NEON intrinsics — ~8 cycles per dim-128 L2 distance. Spark's
+ `vector_l2_distance` is a `RuntimeReplaceable` lowered to
+ `StaticInvoke(VectorFunctionImplUtils.vectorL2Distance)`; JVM bytecode through Catalyst
+ expression evaluation per row. JIT auto-vectorizes the inner loop but loses to
+ hand-written intrinsics by 5-10×.
+2. **Columnar Arrow arrays vs per-row deserialization.** Lance stores the vector column as
+ a contiguous `float32` array per fragment; the kernel iterates contiguous memory.
+ Spark's path goes through `UnsafeArrayData.toFloatArray()` per row → per-row malloc +
+ scattered loads.
+3. **Catalyst expression-evaluation overhead per pair.** Spark's `RewriteNearestByJoin`
+ lowers to roughly: `Generate(Inline(Aggregate(min_by_k(struct(right.*),
+ distance(L,R), K), BroadcastNestedLoopJoin(left tagged with __qid, right))))`. Each
+ (L, R) pair walks ~5 expression-evaluation layers: BNL row iterator → tag projection →
+ aggregate input projection → `vector_l2_distance` → `min_by_k` heap update. Lance's
+ path is just: scanner pulls a column batch, kernel iterates the float array, updates a
+ top-K heap. No JOIN, no struct serialization, no broadcast.
+
+Rough math sanity check (10M pair evals across 18 cores):
+
+```
+Spark: 3,676 ms / 10M = 370 ns/pair (≈80 cycles/core/pair @ 4 GHz)
+Lance: 223 ms / 10M = 22 ns/pair (≈ 5 cycles/core/pair @ 4 GHz)
+```
+
+The Lance number is consistent with AVX-512 doing 16 float ops/cycle on dim-128 vectors
+(128/16 ≈ 8 cycles for the math, plus heap-maintenance and Arrow pointer dereferences).
+
+**Implication.** The 17× speedup doesn't require a vector index. Lance's native-SIMD,
+columnar, no-JOIN path beats Catalyst's per-pair JVM overhead even when both do
+brute-force scan. The vector index is a multiplier on top, not a prerequisite.
+
+## Surprise: Phase 1.5 doesn't help at local-laptop scale
+
+Phase 1.5 fragment-grouped probing (configs C/D/E) is **slower** than Phase 0/1 single-task
+probing (config B) at every scale measured. This is genuine, not noise — the difference is
+50-90 ms small / ~500 ms medium across 3 runs each.
+
+### Why
+
+The Phase 1.5 design pays for two things and benefits from one:
+
+| Cost | Benefit |
+|---|---|
+| `flatMap` to replicate each left row across G groups | Each task probes only `\|R\|/G` rows |
+| `partitionBy(HashPartitioner(G))` shuffle | |
+| Catalyst-inserted `ShuffleExchangeExec` above `LanceMergeExec` to co-locate contributions | |
+
+At local-laptop scale on M5 Max:
+
+1. **Lance's cross-fragment merge is highly optimized.** A single `LanceProbe` instance
+ running with `fragmentIds = None` parallelizes internally across the dataset's
+ fragments via Lance's own scan kernels (vectorized, AVX-512 / NEON-tuned native code).
+ Phase 0/1 already gets fragment-level parallelism, just not at the Spark task boundary.
+2. **Two shuffles is two shuffles.** Each shuffle is bound by spill, network (or local
+ loopback), and serialization overhead. Local-mode Spark using shared memory still pays
+ the serialization cost.
+3. **Shared L2/L3 cache.** All 18 cores share the same on-chip caches and unified memory
+ on the M5 Max. Splitting work across Spark tasks vs. across Lance's internal threads
+ doesn't change the cache footprint.
+
+Net: Phase 0/1's single Spark task delegating fragment parallelism to Lance is the best
+shape for this hardware.
+
+### When Phase 1.5 *would* pay off
+
+The premise of fragment-grouping is "different probe tasks, on different machines, with
+disjoint network paths to disjoint fragment files." That premise holds in:
+
+- **True distributed cluster** with N executors, each owning local fragment shards. Lance's
+ internal threading is bounded by single-machine compute; spreading work across machines
+ needs Spark's task scheduler.
+- **Object-store-backed Lance** (S3, GCS) where per-fragment fetch latency is the bottleneck
+ and parallel fetches help. M5 Max with local NVMe doesn't have this problem.
+- **Right side too large for one machine's memory** — single-task probe would page or OOM.
+ Splitting across tasks lets each task open a manageable working set.
+- **Per-fragment vector indexes** that need to be built / loaded once per task. Sharing the
+ index across many probes amortizes the load cost.
+
+The local benchmark doesn't exercise any of these. The fragment-grouping plumbing is
+correct (oracle-equivalence test passes); it just doesn't *win* on this single-machine
+workload because Lance's internal scheduling already covers the relevant parallelism.
+
+### What this means for users
+
+Today, leave `probeParallelism = 1` (the default) on a single-machine setup. On a
+distributed cluster, set it to roughly `numExecutors × executorCores / k` and benchmark
+your own workload — the crossover point where shuffle overhead is overtaken by the
+per-task speedup is dataset-shape-specific.
+
+## SQL benchmark — indexed paths (IVF_FLAT vs IVF-PQ × uniform vs clustered)
+
+The configs above (A/B) measure Spark cross-product vs. Lance no-index brute-force. Once
+`IndexedNearestByJoinSqlBenchmark` builds a vector index on the right dataset, configs C–F
+exercise approximate paths with various tuning knobs. The four-cell matrix below crosses
+`BENCHMARK_INDEX={ivf_flat, ivf_pq}` with `BENCHMARK_DATA={uniform, clustered}` to expose
+how data distribution interacts with index choice. Same hardware (M5 Max), same params
+(small scale, dim=128, K=10, 4 IVF partitions, median of 3 runs).
+
+`clustered` data is a unit-sphere-normalized Gaussian-mixture sample with 64 cluster
+centers and `sigma = 0.15 × inter-cluster-spacing` — a synthetic stand-in for production
+sentence-transformer / image-feature embeddings. Real embeddings cluster around topic
+centroids and live on the unit sphere, both of which IVF was designed for. `uniform` is
+independent floats over `[0, 1]^Dim` — the IVF worst case.
+
+| Config | uniform+flat (ms / r@10) | uniform+pq (ms / r@10) | clustered+flat (ms / r@10) | clustered+pq (ms / r@10) |
+|---|---:|---:|---:|---:|
+| A: rule OFF (Spark cross-product) | 3611 | 3667 | 3675 | 3718 |
+| B: rule ON, no index | 217 / r=1.000 | 216 / r=1.000 | 215 / r=1.000 | 218 / r=1.000 |
+| C: defaults (nprobes=1) | 137 / **1.000** | 109 / 0.044 | 131 / **1.000** | 103 / 0.094 |
+| D: refineFactor=64 | 171 / **1.000** | 152 / 0.175 | 172 / **1.000** | 165 / 0.225 |
+| E: nprobes=4 (full) | 128 / **1.000** | 103 / 0.044 | 122 / **1.000** | 100 / 0.094 |
+| F: nprobes=4 + refineFactor=64 | 639 / **1.000** | 558 / 0.537 | 594 / **1.000** | 586 / 0.550 |
+
+Speedups vs. config A (rule OFF), small scale:
+
+| Config | uniform+flat | uniform+pq | clustered+flat | clustered+pq |
+|---|---:|---:|---:|---:|
+| C: defaults | 26.4× | **33.6×** (r=4%) | 28.1× | **36.1×** (r=9%) |
+| E: nprobes=4 | 28.2× | 35.6× (r=4%) | 30.1× | **37.2×** (r=9%) |
+| F: nprobes=4 + refineFactor=64 | 5.7× | 6.6× (r=54%) | 6.2× | 6.3× (r=55%) |
+
+### What this matrix says
+
+**IVF_FLAT is distribution-invariant at dim=128.** Both uniform and clustered hit
+recall@10 = 1.0 across every config because IVF_FLAT stores the full vectors per cluster —
+within a probed cluster, the distance computation is exact. Only `nprobes` coverage
+matters, not vector geometry. IVF_FLAT is the safe production choice when you have the
+disk/memory budget for full-vector storage.
+
+**IVF-PQ at dim=128 is genuinely hard regardless of distribution.** Clustered data lifts
+default-config recall from 4.4% → 9.4% (~2× improvement, statistically real) but neither
+distribution gets to production-quality recall at defaults. The high-recall config
+(F: nprobes=full + refineFactor=64) reaches ~55% on either — the refineFactor re-rank pass
+helps but can't recover true neighbors that landed in unprobed PQ clusters.
+
+**Why clustered doesn't unlock PQ as much as expected.** Two reasons hit this benchmark:
+
+1. **Sub-vector budget vs. dim.** At Dim=128 with `numSubVectors = Dim/16 = 8`, each
+ sub-vector quantizes 16 dims into 256 codes — extremely lossy. Production PQ usually
+ uses `numSubVectors = Dim/4` (4 dims per sub-vector, much finer codes). Trying that
+ here ran into Lance's PQ training-sample requirement: 32-sub-vec PQ asks for 4.3B rows
+ to train 256 codes per sub-vector cleanly, vs. our 100K. So at 100K-row scale we're
+ structurally stuck with coarse PQ. Production deployments at much-larger N can train
+ fine-PQ codebooks and recall recovers.
+2. **Cluster tightness has a sweet spot, not a monotone effect.** Tested `sigma = 0.05`
+ (much tighter clusters) expecting better PQ recall; got 3.8% vs. the 9.4% at
+ `sigma = 0.15`. With overly-tight clusters, the K nearest neighbors all live inside ONE
+ cluster — and PQ within a cluster has high quantization noise (many vectors map to the
+ same code). The index can't distinguish among them. Real semantic embeddings sit
+ somewhere between these extremes.
+
+**The honest summary for users**: at production scale (millions to billions of rows, fine
+PQ codebooks well-trained), IVF-PQ on real-shaped data hits 90%+ recall. At our
+local-laptop benchmark scale (100K rows, dim=128, coarse PQ forced by training-sample
+limits), IVF-PQ is a speed-vs-recall regression compared to IVF_FLAT. The benchmark is
+honest about both data distributions; the structural lesson — clustered helps PQ, but the
+gap depends much more on PQ codebook training than on cluster tightness — generalizes.
+
+### Medium-scale matrix (1M × 1000, dim=128, 8 IVF partitions)
+
+Same matrix re-run at production-shaped scale. Note that the rule-OFF (Spark cross-product)
+baseline is impractical here — 1B-pair `min_by_k` is measured in tens of minutes — so config
+A is skipped and speedups are reported against config B (Lance no-index brute-force).
+
+| Config | uniform+flat (ms / r@10) | uniform+pq (ms / r@10) | clustered+flat (ms / r@10) | clustered+pq (ms / r@10) |
+|---|---:|---:|---:|---:|
+| B: rule ON, no index | 10380 / r=1.000 | 10568 / r=1.000 | 10233 / r=1.000 | 10755 / r=1.000 |
+| C: defaults (nprobes=1) | 2800 / **1.000** | 684 / 0.025 | 2636 / **1.000** | 776 / 0.031 |
+| D: refineFactor=64 | 3097 / **1.000** | 2056 / 0.119 | 3191 / **1.000** | 2019 / 0.125 |
+| E: nprobes=8 (full) | 2805 / **1.000** | 767 / 0.025 | 2892 / **1.000** | 834 / 0.031 |
+| F: nprobes=8 + refineFactor=64 | 12094 / **1.000** | 10818 / 0.294 | 10724 / **1.000** | 11252 / 0.281 |
+
+Speedups vs. config B (no-index Lance baseline):
+
+| Config | uniform+flat | uniform+pq | clustered+flat | clustered+pq |
+|---|---:|---:|---:|---:|
+| C: defaults | 3.7× | **15.4×** (r=2.5%) | 3.9× | **13.9×** (r=3.1%) |
+| E: nprobes=full | 3.7× | 13.8× (r=2.5%) | 3.5× | 12.9× (r=3.1%) |
+
+### Medium-scale findings
+
+**IVF-PQ is genuinely faster than IVF_FLAT at scale.** At medium, PQ defaults run 684 ms vs
+IVF_FLAT's 2800 ms — 4.1× faster. PQ codes are tiny so per-query scan touches much less
+data; at 1M rows, that wins. At small (100K) the absolute times were too short for the
+ratio to matter. This is the "PQ for scale" argument that drives most production
+deployments to PQ.
+
+**Coarse PQ recall gets *worse* at scale at this config.** Uniform PQ defaults:
+small=4.4% → medium=2.5%. Two reasons:
+ 1. More IVF partitions (8 vs 4 at small) means `nprobes=1` cuts more data away.
+ 2. The K nearest neighbors are sparser per cluster at 1M.
+
+**The clustered uplift on PQ is much smaller at medium.** Small: 4% → 9% (2.1× lift).
+Medium: 2.5% → 3.1% (1.2× lift, near-noise). At our forced PQ sub-vec setting, the
+structural advantage of realistic data is too small to matter at this scale. A higher
+PQ sub-vec budget (production setting `numSubVectors = Dim/4 = 32`) would close the gap —
+Lance's PQ training rejects that at our scales (needs > 4B training samples), but
+production-scale deployments hit it routinely. The thing this benchmark *can* show: the
+direction of effect is consistent (clustered ≥ uniform on PQ); the *magnitude* needs
+production-scale data.
+
+**IVF_FLAT remains recall-perfect at every config + distribution at medium.** Distribution
+truly doesn't matter for IVF_FLAT; only `nprobes` coverage does.
+
+### Reproducing the matrix
+
+```sh
+for SCALE in small medium; do
+ for DATA in uniform clustered; do
+ for IDX in flat pq; do
+ BENCHMARK_SCALE=$SCALE BENCHMARK_DATA=$DATA BENCHMARK_INDEX=$IDX \
+ MAVEN_OPTS="" \
+ ./mvnw -pl lance-spark-knn-4.2_2.13 -q exec:java \
+ -Dexec.classpathScope=test \
+ -Dexec.mainClass=org.lance.spark.knn.benchmark.IndexedNearestByJoinSqlBenchmark
+ done
+ done
+done
+```
+
+Defaults: `BENCHMARK_DATA=uniform`, `BENCHMARK_INDEX=flat`, `BENCHMARK_SCALE=both`.
+Additional knobs: `BENCHMARK_SIGMA` (cluster tightness for `clustered`, default 0.15),
+`BENCHMARK_PQ_SUBVEC` (PQ sub-vector count, default `Dim/16 = 8` — Lance rejects 32+
+at our test scales due to PQ training-sample requirements; production at >>1M rows can
+override to 32 for finer codes).
+
+## Sanity check
+
+Before timing, the benchmark verifies oracle equivalence on a 16-row left subset against
+brute-force ground truth. Bails the run if any indexed path disagrees with the exact
+top-K. Output above includes:
+
+```
+Sanity check: indexed-path top-K matches brute-force oracle on a 16-row subset ...
+... oracle equivalence holds.
+```
+
+So the timing numbers are for paths that produce correct results.
+
+---
+
+# Cluster benchmarks — OSS Spark 3.5 on Kubernetes
+
+Cluster numbers complementing the local M5 Max headline, run on an OSS Spark 3.5.4
+engine (Scala 2.12, Linux x86_64, Gluten-bundled executors, each executor in its own
+Kubernetes pod, multi-tenant shared infrastructure).
+
+## Cluster shape
+
+- **Spark**: OSS Spark 3.5.4, standalone-per-app mode on Kubernetes.
+- **Executors**: 8 × 4 cores × 16 GB = 32 cores / 128 GB total. Each executor is a
+ separate Kubernetes pod.
+- **Driver**: 4 cores × 16 GB.
+- **Critical submit-time settings** (documented here because they're not obvious on
+ managed-Spark distributions):
+ - Executor count: some managed distributions ignore `spark.executor.instances` in
+ standalone-per-app mode and expect their own vendor-specific knob. Without setting
+ it, each app gets one worker pod regardless of the Spark conf. If you hit this,
+ check your distribution's template docs for the equivalent setting.
+ - `spark.driver.extraClassPath = ` + `spark.executor.extraClassPath = ` —
+ puts the benchmark fat JAR earlier on the classpath than any cluster-bundled Arrow.
+ Our fat JAR ships Arrow 15.0.2; clusters that bundle an older Arrow (e.g. Gluten
+ bundles) would otherwise shadow ours and cause IVF-PQ + Arrow-C DataFusion
+ interaction to throw
+ `NoSuchMethodError: ArrowArrayStream.allocateNew(BufferAllocator)`.
+ - `spark.rpc.message.maxSize=512` — needed for the medium-scale (1M-row) synthetic
+ benchmark's driver-side row shipment. Default 128 MB trips the serialized-task
+ limit at 1M rows × dim-128.
+- **Fat JAR**: `lance-spark-knn_2.12--benchmark.jar`, shaded to Linux-x86_64 natives
+ only (darwin-aarch64 + linux-aarch64 excluded) to stay under some clusters'
+ volume-upload ingress timeout (~5-minute hard cap on managed-Spark distributions).
+ Drops from 254 MB → 102 MB.
+
+## Synthetic benchmark (dim=128, `IndexedNearestJoinBenchmark`)
+
+This section covers the cross-cluster scaling sweep at synthetic dim=128. Two cluster
+shapes (8 × 8c/32g and 4 × 8c/32g) × five `|R|` scales (sampling from 10K to 1M with a
+ground-truth at |R|=1M, |L|=100). Methodology is detailed below the headline tables;
+read it before quoting any number — there are gotchas around baseline plan choice and
+multi-tenant variance.
+
+### Setup for first-time reviewers
+
+| Item | Value |
+|---|---|
+| Spark version | OSS 3.5.4, standalone-per-app, Kubernetes-deployed pods |
+| Vector dim | 128 (synthetic random; uniform over `Float`) |
+| K (top-K per query) | 10 |
+| Right-side fragments | 8 (Lance write `repartition(8)` → 8 fragments) |
+| Sink | `df.write.format("noop").save()` (Spark's canonical benchmark sink — full row materialization, no driver round-trip) |
+| Iterations | 1 warmup + 3 measurement, median reported |
+| Correctness gate | All configs run against in-memory brute-force oracle on 16-row left subset before timed runs; `sys.error` if any disagree |
+| AQE | **Disabled** for this sweep (`spark.sql.adaptive.enabled=false`) — see "Why AQE off" below |
+| Cross-product right-side repartition | Right side (post-Lance-read) repartitioned to `cores total` so the cross-join compute stage parallelizes; without this, 8-fragment Lance read caps the fused stage at 8 tasks |
+
+### Configurations
+
+Six configurations exercised in the sweep. **Naming change vs. earlier sections of this
+doc:** the older "A" was "row_number window over crossJoin" — that plan is structurally
+unfit for medium scale (no partial aggregation, single-task per shuffle partition runs
+hours). It's preserved as opt-in via `BENCHMARK_INCLUDE_BASELINE_A=true`. The default
+sweep uses **A2** as the headline baseline:
+
+| Config | Plan | Why it's the right baseline |
+|---|---|---|
+| A | crossJoin + L2 UDF + `row_number().over(Window.partitionBy(lid))` + filter rank≤K | **Off by default.** O(\|R\| log \|R\|) global sort per `lid`, no partial aggregation. Hours at \|R\| ≥ 100K. |
+| **A2** | crossJoin + L2 UDF + `groupBy(lid).agg(slice(sort_array(collect_list(struct(dist, rid))), 1, K))` + `inline()` | **Headline baseline.** Closest Spark 3.5 SQL expression of what 4.2's `RewriteNearestByJoin` lowers to. Spark applies partial aggregation per task (each task partial-sorts and trims to K-ish before shuffle), so wall-clock is bounded. |
+| B | `df.kNearestJoin(probeParallelism=1)` | Phase 0/1 single-task probe |
+| C | `df.kNearestJoin(probeParallelism=4)` | Phase 1.5 fragment-grouping at 4 |
+| D | `df.kNearestJoin(probeParallelism=8)` | Phase 1.5 fragment-grouping at 8 (= numFragments) |
+| E | `df.kNearestJoin(probeParallelism=8, balanceFragments=true)` | Phase 1.5 + LPT skew balancing |
+
+The closest 4.2-native plan would use `min_by(struct, expr, K)` (`MaxMinByK`,
+SPARK-55322) which does O(|R| log K) heap-K, asymptotically better than A2's
+O(|R| log |R|) per-group sort. That expression doesn't exist on Spark 3.5; A2 is the
+closest expressible shape. Quoted speedup is therefore conservative for what users will
+see on Spark 4.2 (the 4.2-native baseline runs slightly faster than A2, so the speedup
+ratio shrinks).
+
+### Why AQE off
+
+AQE's `CoalesceShufflePartitions` aggressively reduces partition count when the
+post-shuffle data is small. On A2 at small/medium |R|, the post-cross-join shuffle is
+a few hundred MB which AQE coalesces from `spark.sql.shuffle.partitions=128` down to
+~8 partitions. Combined with the Lance-read producing 8 fragment-partitions, this fuses
+the cross-join + UDF + groupBy into one stage of 8 tasks — capping parallelism at
+8 cores regardless of the cluster's total cores. With AQE off, post-shuffle partitions
+stay at the configured value (128 here) and all cluster cores get utilized. AQE remains
+on for indexed-path runs (it benefits the merge-side shuffle); the toggle is per-run via
+`BENCH_DISABLE_AQE=true`.
+
+### Sweep — big cluster (8 executors × 8c/32g = 64 cores, 8 pods)
+
+7-iteration baseline-sweep run. AQE off. Repartition right side to 64.
+
+| Scale | R × L | A2 (ms) | B (ms) | C (ms) | D (ms) | E (ms) | A2/B | A2/E |
+|---|---|---:|---:|---:|---:|---:|---:|---:|
+| sample_r10k | 10K × 1K = 10M pairs | 11,522 | 1,787 | 2,270 | 2,587 | 2,584 | **6.4×** | **4.5×** |
+| sample_r50k | 50K × 1K = 50M pairs | 60,678 | 3,596 | 4,114 | 4,245 | 4,303 | **17×** | **14×** |
+| sample_r100k | 100K × 1K = 100M pairs | 120,564 | 7,033 | 6,697 | 7,391 | 7,195 | **17×** | **17×** |
+| sample_r200k | 200K × 1K = 200M pairs | 218,071 | 12,898 | 13,457 | 13,923 | 14,180 | **17×** | **15×** |
+| **medium_l100** | **1M × 100 = 100M pairs** | **112,135** | **1,499** | **1,697** | **1,252** | **1,179** | **75×** | **95×** |
+
+**Linear scaling on A2 confirmed** — coefficient is ~1.05s per million pairs. Doubling
+|R| roughly doubles A2's wall-clock (50→100K = 2.0×; 100→200K = 1.81×).
+
+**Extrapolation to full medium (|R|=1M, |L|=1K = 1B pairs):** A2 ≈ 1.05 × 1000s ≈
+~1050s. **Validation by ground truth:** medium_l100 (|R|=1M, |L|=100, 100M pairs) =
+112s. Scaling that to |L|=1000 gives 1120s. **Extrapolation matches independent
+ground truth within 7%.** Indexed-path E at full medium ≈ 12s (linear from medium_l100's
+1.18s × 10), giving an extrapolated speedup of ~1100/12 = **~90×** at full medium scale.
+
+### Sweep — small cluster (4 executors × 8c/32g = 32 cores, 4 pods)
+
+Same sweep at half the cores. Goal: see how indexed-path scales when cluster shrinks.
+
+| Scale | A2 (ms) | B (ms) | C (ms) | D (ms) | E (ms) | A2/B | A2/E |
+|---|---:|---:|---:|---:|---:|---:|---:|
+| sample_r10k | 22,437 | 1,232 | 2,199 | 2,012 | 2,034 | **18×** | **11×** |
+| sample_r50k | 108,078 | 2,257 | 3,350 | 3,345 | 3,282 | **48×** | **33×** |
+| sample_r100k | 215,261 | 4,319 | 4,974 | 4,718 | 4,700 | **50×** | **46×** |
+| sample_r200k | 430,426 | 8,284 | 6,786 | (13,000)¹ | (12,952)¹ | **52×** | (33×)¹ |
+| **medium_l100** | **222,498** | **2,459** | **1,928** | **1,635** | **1,578** | **91×** | **141×** |
+
+¹ D and E at r200k inflated by an unrelated executor-network failure mid-run that
+forced task retries on a degraded cluster. Discount these two cells; B and C in the
+same row aren't affected.
+
+### Big vs. small cluster — what scales how
+
+The interesting question is how indexed-path numbers move when you halve cores. A2 is
+the reference for "cross-product baseline scaling" — purely compute-bound, should be
+roughly linear with cores.
+
+| Metric | Big (64c) | Small (32c) | Ratio (small/big) | Reading |
+|---|---:|---:|---:|---|
+| **A2 at r200k** | 218,071 | 430,426 | **1.97×** | Perfectly linear with cores. Cross-product baseline is purely compute-bound. |
+| **A2 at medium_l100** | 112,135 | 222,498 | **1.98×** | Same — confirms A2 scales linearly. |
+| **B at r200k** | 12,898 | 8,284 | **0.64×** | Small cluster is **faster**. Phase 0/1 has 8 fragment-parallel tasks; bigger cluster's wider merge shuffle is overhead with 1 contributor per leftId. |
+| **B at medium_l100** | 1,499 | 2,459 | **1.64×** | Small cluster slower, but sub-linear (vs A2's 1.98×). |
+| **E at medium_l100** | 1,179 | 1,578 | **1.34×** | Small cluster only 1.34× slower at half cores. Phase 1.5's fragment-grouping shuffle benefits from co-located executors. |
+
+**Headline finding: indexed-path scales sub-linearly with cores.** Phase 0/1 with
+8-fragment data sees the merge-side shuffle as pure overhead beyond ~8 cores; Phase 1.5
+also benefits from fewer pods (less network traffic). On the cross-product side, halving
+cores doubles wall-clock — exactly as theory predicts. **Speedup ratios therefore grow
+on smaller clusters** (51-141× small vs 17-95× big, depending on shape) because the
+indexed path is already CPU-saturated and the baseline isn't.
+
+### Variance / multi-tenant noise
+
+The OSS Spark 3.5 cluster is multi-tenant infrastructure. Two run-to-run effects show
+up across the sweep:
+
+1. **Run-to-run variance ±20% per config.** Same 7-iteration medians on identical jobs
+ land 10-30% apart depending on overall cluster load at the time. The headline
+ "100-200× speedup range" framing in the cross-cluster summary at the bottom of this
+ doc is the honest read for any single run.
+
+2. **Noisy-neighbor pods.** On some pod allocations, one executor runs 2-3× slower
+ than the rest (sustained, across every stage of the run, not data-skew). When a
+ stage's wall-clock = `max(per-executor task time)`, this multiplies that stage's
+ wall-clock by the same factor. A first attempt of the big-cluster sweep hit this
+ (executor 2 was 2.7× slower than the other 7 across every stage). Re-submitting
+ typically lands different pod hosts and clears it. **The numbers above are from
+ clean runs (max-skew ≤ 1.4×).** Detect via the Spark UI's stages → tasks page or
+ the REST API (`/api/v1/applications//stages///taskList`,
+ sorted by duration); a sustained ~2× ratio between the slowest and median executor's
+ per-task times across multiple stages indicates a noisy-neighbor pod rather than
+ data skew.
+
+3. **Executor death recovery inflates retried numbers.** One config (D at small
+ cluster, r200k) hit an executor network disassociation mid-stage; Spark recovers
+ by retrying tasks on remaining executors, which ~doubles the wall-clock for that
+ measurement. Marked with footnote ¹ in the table.
+
+For internal teammates: when sharing these numbers, note "single-run point estimate
+on multi-tenant cluster, ±20% noise envelope; speedup order-of-magnitude is robust,
+specific multiplier is not." The baseline-vs-indexed-path gap is large enough (≥1.5
+orders of magnitude) that no plausible noise envelope flips the conclusion.
+
+### Earlier two-run measurement (8 × 4c/16g, prior cluster sizing)
+
+Kept for historical comparison with what was published in earlier iterations of this
+doc. Sampled at full medium scale only:
+
+| Config | Run 1 (ms) | Run 3 (ms) | Stable signal |
+|---|---:|---:|---|
+| B: Phase 0/1 (probeParallelism=1) | 92,107 | 93,466 | ~92 s |
+| C: Phase 1.5 (probeParallelism=4) | 106,126 | 108,026 | ~107 s (slower than B) |
+| **D: Phase 1.5 (probeParallelism=8)** | **54,639** | **55,783** | **~55 s (1.69× faster than B)** |
+| **E: Phase 1.5 (G=8, skew-balanced)** | **54,236** | **56,341** | **~55 s** |
+
+**Key cluster finding (different from local):** Phase 1.5 D/E **wins** at medium scale
+on a true distributed cluster. Grain must match fragment count (probeParallelism=8 on
+8 fragments → 1 fragment per task). The local M5 Max "Phase 1.5 doesn't help" finding
+was single-machine specific — cross-machine parallelism (8 independent executor JVMs
+with independent memory buses) beats Lance-internal 8-thread execution on one machine.
+
+C (probeParallelism=4 on 8 fragments) is slower than B because the grain mismatch
+pays for shuffle overhead without enough work-partitioning to offset it. This matches
+the algebraic hypothesis: Phase 1.5 wins only when `probeParallelism == numFragments`.
+
+Note this older run did **not** use the `noop` sink (used `count()`) and did not have
+the AQE-off / right-side-repartition fixes that the baseline-sweep above applies. The
+indexed-path numbers are still directly comparable; the baseline numbers from this
+earlier cluster sizing are not quoted here because they used the row_number-window
+plan that doesn't run to completion at medium scale.
+
+## Production-shape perf (dim=1024, `WikipediaKnnPerfBenchmark`)
+
+Cohere Labs `wikipedia-2023-11-embed-multilingual-v3` English shard — 1024-dim
+multilingual-v3 embeddings, normalized for cosine (L2 used in benchmark; produces the
+same top-K ordering on unit vectors).
+
+### Measurement methodology
+
+Results below use Spark's `write.format("noop")` sink in the timing loop instead of
+`count()`. `count()` could in principle give the crossJoin baseline a small relative
+advantage (it can skip some per-row result materialization, while the indexed path's
+`LanceMaterializeLogicalPlan` forces materialize to run in full due to the
+`references = child.outputSet` override). In practice the dominant cost on both paths
+is upstream of result assembly — the crossJoin's `l2()` UDF over |L|×|R| pairs, and
+the indexed path's Lance native distance kernel — so the sink switch moves single-run
+wall-clock within the cluster's natural run-to-run variance envelope. The `noop` sink
+is still the right default: it's what Spark's internal benchmarks use, matches
+end-to-end execution, and avoids any ambiguity about what got skipped.
+
+Every run passes a **brute-force oracle check** on a 16-row left subset before the
+timed measurements: each config's top-K row IDs must match an in-memory O(|R|) oracle.
+Bails via `sys.error` if any config disagrees. Cardinality alone (what `count()` would
+check) isn't a correctness proof — a bug emitting |L|×K garbage rows would still pass
+a count-based gate.
+
+**Variance envelope.** The OSS Spark cluster used here is multi-tenant infrastructure; 3-iteration medians on
+jobs of this shape show roughly ±30% run-to-run variance per config on shared
+CPU/disk/network. Numbers below are single-run medians unless otherwise noted —
+don't over-interpret a single point estimate. The speedup vs the crossJoin baseline
+is large enough (≥100×) that it survives noise comfortably; precise speedup within
+the indexed-path configs (B vs C vs D vs E) should be read as approximate.
+
+### Speedup vs. Spark crossJoin (|R|=1K × |L|=50, dim=1024)
+
+7-iteration run on 8 × 4c/16g OSS Spark 3.5 executors, all configs oracle-verified at K=10.
+
+| Config | Median (ms) | Min–Max (ms) | Speedup × (median) |
+|---|---:|---:|---:|
+| A: Spark crossJoin (baseline) | 64,944 | 63,793–66,450 | 1.00× |
+| B: Phase 0/1 (probeParallelism=1) | 469 | 333–752 | 138× |
+| C: Phase 1.5 (probeParallelism=4) | 455 | 402–958 | 143× |
+| D: Phase 1.5 (probeParallelism=8) | 452 | 371–513 | 144× |
+| **E: Phase 1.5 (G=8, skew-balanced)** | **406** | 391–557 | **160×** |
+
+**Reading the numbers.** The baseline is tight: 7-run range is ±2% around 65s
+(crossJoin is purely CPU-bound JVM arithmetic, nothing cache-sensitive). The indexed
+path is noisier: per-run spikes of ±20–40% around the median appear at arbitrary
+iteration positions (not run 1, not end-of-sequence), consistent with cluster-level
+contention on the multi-tenant cluster — not with Lance-side cache warming (no monotonic
+drift). Quote the median, but treat the speedup as "100–200× range" rather than a
+crisp point estimate.
+
+**Why the baseline is at 1K, not 100K**: Spark's `crossJoin + L2 UDF + row_number
+window` at dim=1024 × 100K rows × 100 queries is ~20 minutes per run. The `O(|L|·|R|·dim)`
+JVM UDF evaluation is the bottleneck; Lance's native SIMD kernel on Arrow columnar
+batches avoids that entirely. At 1K scale the baseline is already ~70s — meaning a
+full 100K× baseline on this cluster would run 1-2 hours for a single timing. The
+indexed path at the same 1K scale is well under 1s. The 100–200× multiplier is on
+the correctness comparison — both paths produce the same top-K rows (oracle-gated).
+
+**Note on prior numbers from earlier runs.** Two earlier runs on the same shape
+produced 188× and 139× headlines; both were 3-iteration medians. The current 160×
+headline is a 7-iteration median with explicit per-run variance visible above. All
+three runs agree on the order of magnitude and disagree on the specific multiplier,
+which is exactly what ±20% cluster noise predicts. **The honest framing is
+"100–200× speedup on real Cohere embeddings at dim=1024 on this OSS Spark cluster" —
+the order-of-magnitude story is robust to cluster noise; the specific multiplier
+drifts with whichever run you quote.**
+
+**Lance caching hypothesis — refuted by the per-run data.** If Lance's Rust Session
+cache or JVM JIT warming were a factor, run 1 would be systematically slow and later
+runs consistently faster. Instead, the slowest measurement for each indexed config
+lands at an arbitrary iteration (run 3 for C/D/E, run 3 for B), and the distribution
+scatters rather than monotonically decreasing. The variance is cluster-side
+(multi-tenant CPU contention, GC pauses), not Lance-side cache warming.
+
+### Indexed-path scaling (|R|=90K × |L|=10K, no baseline, dim=1024)
+
+Production-scale run: 10,000 query vectors against 90,000 base vectors at dim=1024, with
+a `noop` sink (all 100,000 result rows assembled per measurement) and 16-row oracle
+check.
+
+| Config | Median (ms) |
+|---|---:|
+| **B: Phase 0/1 (probeParallelism=1)** | **120,565** |
+| C: Phase 1.5 (probeParallelism=4) | 479,577 |
+| D: Phase 1.5 (probeParallelism=8) | 273,601 |
+| E: Phase 1.5 (G=8, skew-balanced) | 273,880 |
+
+At ~90,000 queries/minute throughput for dim=1024 L2 on a Spark cluster, config B is
+the right default for single-shard production embeddings. C/D/E all regress vs B
+because once the per-task probe is already processing tens of thousands of rows
+(|R|/G × |L| per task), Lance's internal threading saturates the CPU; the
+fragment-group replication shuffle becomes pure overhead. Same finding as the smaller
+|R|=99.9K × |L|=100 scaling run published earlier — the fragment-grouping cost
+doesn't pay off once per-task probe work is already large.
+
+**Takeaway:** leave `probeParallelism = 1` for production embeddings at this scale.
+Phase 1.5 shines only at small-per-task + multi-fragment workloads (e.g., many small
+Lance shards spread across executors where each task processes a narrow slice).
+
+### Vs. dim=128 synthetic baseline
+
+| | dim=128 (synthetic) | dim=1024 (Cohere wiki) |
+|---|---:|---:|
+| Baseline speedup vs crossJoin, small scale | ~18× (SIFT-style, M5 Max) | **100–200×** (OSS Spark cluster, 7-iter median 160×, noop sink + oracle-verified) |
+| Best indexed config, 100K × 100 | B: 1,515 ms | B: 2,945 ms |
+
+The speedup **grows with dim** — the opposite of naive expectation. Lance's SIMD
+kernel processes 8–16 floats/cycle; the per-pair work increase from dim 128→1024 is
+linear in kernel time, but Spark's per-pair JVM overhead (BNL iterator + 5 expression
+layers + Catalyst boxing) is a near-constant ~300 ns that dominates at small dim. At
+large dim the native kernel advantage widens because the JVM can't vectorize the UDF's
+per-row `Seq[Float]` access pattern.
+
+## Production-shape recall (dim=1024, `CohereWikiRecallBenchmark`)
+
+Same dataset, same cluster shape. IVF-FLAT 256 partitions, 100K base / 100 held-out
+queries, ground truth computed by brute-force crossJoin (what's fast enough at 100K).
+
+| nprobes | recall@10 | mean ms/query |
+|---:|---:|---:|
+| 1 | 0.6620 | 41.10 |
+| 4 | 0.8380 | 21.97 |
+| **16** | **0.9490** | **9.88** |
+| 64 | 0.9910 | 16.19 |
+
+**Production operating points**:
+- `nprobes=16` — 95% recall at 10 ms/query. The sweet spot for most RAG workloads.
+- `nprobes=64` — 99% recall at 16 ms/query. Higher nprobes doesn't help linearly (the
+ ground truth is already captured at 64; beyond is diminishing returns).
+- IVF-FLAT build cost: 19s for 100K × 1024-dim on 8×4c/16g.
+
+IVF-PQ was NOT measured in this run — tested initially on SIFT1M and found that PQ's
+top-K results exactly matched IVF-FLAT's across every (nprobes, refineFactor) grid
+cell, which is a red flag that the query path may always select the first-built index
+rather than honoring the probe-time index choice. Separate investigation.
+
+## SIFT1M recall (dim=128, `SiftRecallBenchmark`)
+
+Mechanics / published-comparable validation against the canonical ANN-benchmark
+corpus. Same OSS Spark 3.5 cluster shape, 1M base vectors × 1000 queries, IVF-FLAT 256 partitions.
+
+| nprobes | recall@10 |
+|---:|---:|
+| 1 | 0.4719 |
+| 4 | 0.8161 |
+| **16** | **0.9831** |
+| **64** | **0.9994** |
+
+Within noise of published FAISS IVF-FLAT numbers on SIFT1M. Index build times:
+IVF-FLAT 35.7s, IVF-PQ 38.6s on 1M × 128-dim.
+
+## Sustained-load soak (`IndexedNearestJoinSoakTest`)
+
+Production-readiness validation #2 — run concurrent queries for N minutes and watch
+for memory growth / handle leaks / latency drift.
+
+**Smoke soak** (10 min, |R|=1M, 8 concurrent queries, QPS target 2, pP=8, dim=128):
+
+```
+completed queries: 492 (0.82 QPS observed; latency-bounded, not QPS-bounded)
+failed queries: 0 (0% during the 10-min load window)
+
+LATENCY (ms)
+ p50: 11,551 p95: 16,032 p99: 18,962 max: 23,057
+
+HEAP over time (MB, driver-side)
+ t=0s: 163 t=120s: 218 t=240s: 213 t=360s: 226
+ t=480s: 262 t=540s: 248 end: 227
+```
+
+Heap oscillates 163–266 MB with no upward trend. Zero failures during the load window.
+Post-deadline ~30 queued queries failed with "stopped SparkContext" — harness bug
+(pool drain races `spark.stop()`), not a production leak. Verdict: pipeline is
+memory-stable under sustained concurrent load at this scale.
+
+## How to reproduce
+
+On any OSS Spark 3.5 / Kubernetes cluster with 8 × 4c/16g executor pods, a mounted
+volume for the JAR + parquet data, and `spark-submit` access:
+
+```sh
+# Build the fat JAR (Linux-x86_64-only natives to stay under typical
+# managed-Spark volume-upload timeouts).
+./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+
+# Upload target/lance-spark-knn_2.12--benchmark.jar + Cohere parquet shards
+# to your cluster's mounted volume (mechanism is vendor-specific).
+
+# Cohere wiki perf (small shape: 1K base + baseline + oracle).
+spark-submit \
+ --class org.lance.spark.knn.benchmark.WikipediaKnnPerfBenchmark \
+ --driver-memory 16g --driver-cores 4 \
+ --executor-memory 16g --executor-cores 4 \
+ --conf spark.driver.extraClassPath= \
+ --conf spark.executor.extraClassPath= \
+ --conf spark.rpc.message.maxSize=512 \
+ --conf spark.sql.crossJoin.enabled=true \
+ \
+ # env vars:
+ BENCH_CLUSTER_MODE=true \
+ BENCH_DATA_PATH=file:///knn-bench-data \
+ WIKI_PARQUET='/wiki-*.parquet' \
+ WIKI_NUM_RIGHT=1000 WIKI_NUM_LEFT=50 \
+ WIKI_RUN_BASELINE=true WIKI_MEASURE_RUNS=7
+
+# Scaling shape (100K base, indexed only): set WIKI_NUM_RIGHT=100000,
+# WIKI_NUM_LEFT=10000, WIKI_RUN_BASELINE=false. Synthetic medium:
+# IndexedNearestJoinBenchmark with BENCHMARK_SCALE=medium.
+```
+
+If you're on a managed-Spark distribution that ignores `spark.executor.instances` in
+standalone-per-app mode, use your distribution's equivalent executor-count knob
+(e.g., `ae.spark.executor.count` on some deployments). The
+`spark.{driver,executor}.extraClassPath` entries put the benchmark fat JAR earlier on
+the classpath than any cluster-bundled Arrow that might otherwise shadow ours.
diff --git a/lance-spark-knn_2.12/DESIGN.md b/lance-spark-knn_2.12/DESIGN.md
new file mode 100644
index 000000000..c313d2cb7
--- /dev/null
+++ b/lance-spark-knn_2.12/DESIGN.md
@@ -0,0 +1,637 @@
+# Indexed nearest-by join — feature design
+
+> **Audience.** Human reviewers approaching this feature for the first time. Read this top-to-
+> bottom before reading diffs. Several of the design choices look arbitrary in isolation but
+> are forced by Spark's existing rule ordering or by Lance's index shape — the rationale is
+> here, not in the source.
+>
+> **PoC scope.** All phases (0 / 1 / 1.5 / 2 / 3 / 3.x) currently ship together on the
+> `knn-phase0` fork branch (this is that branch). For upstream delivery to
+> `lance-format/lance-spark:main`, the branch will be split into 7 smaller PRs — see
+> [`UPSTREAM_DELIVERY_PLAN.md`](UPSTREAM_DELIVERY_PLAN.md) for the split. The phase
+> labels here describe the design layers — the order in which a reviewer should read the
+> code — not the eventual per-PR release boundaries.
+
+## TL;DR
+
+For SQL of the shape
+
+```sql
+SELECT *
+FROM queries q
+LEFT OUTER JOIN documents d
+ APPROX NEAREST 10 BY DISTANCE vector_l2_distance(q.vec, d.vec)
+```
+
+— route execution through a per-fragment Lance index probe + Spark-side merge instead of the
+default `O(|L| × |R|)` cross-product. Fast where Lance has a vector index; falls back to
+Spark's brute-force rewrite when conditions don't hold. **No public API change** for users
+once Phase 2 ships: registering one Spark session extension is the entire opt-in.
+
+The work is split across two modules:
+
+| Module | What it owns | Spark version | Status |
+|---|---|---|---|
+| `lance-spark-knn_2.12` (+ `_2.13` cross-build) | Pipeline primitives, public DataFrame API (`IndexedNearestJoin.apply`, `df.kNearestJoin` extension), 3-exec Catalyst staged plan, Phase 1.5 fragment grouping, Phase 3 recall knobs. | 3.4 / 3.5 / 4.0 (reflection bridge in `LanceKnnDatasetBridge`); 2.12 and 2.13 cross-build. | Phase 0 / 1 / 1.5 / 3 done. |
+| `lance-spark-knn-4.2_2.13` | Catalyst rule + logical + physical operators that intercept SQL `NearestByJoin`. | 4.2-SNAPSHOT (SPARK-56395 `NearestByJoin` only exists in master as of writing; re-pin once 4.2.0 releases). | Phase 2 done, but held out of upstream delivery until Spark 4.2 ships. |
+
+## Why this works on Lance specifically
+
+- Lance vector indexes (IVF-PQ, HNSW) are **fragment-local**. Per-fragment probes are
+ independent, which makes parallelism trivial.
+- Lance's Java API exposes single-vector nearest search via
+ `org.lance.ipc.Query` + `LanceScanner.create(dataset, ScanOptions, allocator)`. We call into
+ this primitive from Spark tasks.
+- `_rowid` (Lance virtual column) makes late materialization cheap: the probe stage emits
+ row IDs only, the materialize stage point-fetches by ID. We use `_rowid` rather than
+ `_rowaddr` because Lance's INDEXED nearest-search path materializes `_rowid` but not
+ `_rowaddr` — using `_rowid` works on both indexed and non-indexed paths uniformly.
+- Lance versioning gives consistent snapshots across distributed tasks.
+
+## Architecture
+
+### The three-stage pipeline
+
+```
+left.logicalPlan
+ -- LanceProbeLogicalPlan --> [_leftId, leftRow fields..., _refs]
+ -- LanceMergeLogicalPlan --> same shape
+ -- LanceMaterializeLogicalPlan --> final join output schema
+
+lowered via LanceKnnStagedStrategy to:
+
+ LanceProbeExec
+ --> (per-task) open LanceProbe, nearest-search per left row, emit inter-stage rows
+ ShuffleExchangeExec hashpartitioning(_leftId)
+ --> inserted by EnsureRequirements, wrapped by AdaptiveSparkPlanExec when AQE is on
+ LanceMergeExec
+ --> per-partition group-by-leftId, TopKHeap.merge, re-emit merged rows
+ LanceMaterializeExec
+ --> open LanceProbe, point-fetch right rows by _rowid, assemble join Rows
+```
+
+`IndexedNearestJoin.apply` builds the three-logical-plan tree on top of the user's left
+analyzed plan, registers `LanceKnnStagedStrategy` on the session's
+`experimentalMethods.extraStrategies` (idempotent), and wraps the root via
+`LanceKnnDatasetBridge.asDataFrame` (a trampoline to `Dataset.ofRows` — that method is
+`private[sql]`).
+
+`df.explain()` shows four Catalyst nodes (`LanceProbe → Exchange → LanceMerge →
+LanceMaterialize`) wrapped by `AdaptiveSparkPlanExec`. With AQE enabled,
+`AQEShuffleRead coalesced` appears on the merge-side shuffle after the first collection.
+
+The pipeline objects live in `org.lance.spark.knn.internal`:
+
+- **`LanceProbeStage`** — the RDD-level primitive. Opens a `LanceProbe` per task, probes
+ Lance's nearest-search per left row, emits `(leftId, ProbedLeft)`. Map-side combine via
+ `TopKHeap` keeps state at exactly K entries when the task probes multiple fragments.
+ When `probeParallelism > 1`, `runWithFragmentGroups` replicates left rows across G
+ fragment groups via an internal RDD-level `partitionBy` so each task sees one group.
+ (The fragment-grouped probe path's internal shuffle remains AQE-invisible — tracked as
+ future work.)
+- **`LanceMergeStage`** — per-partition group-by-leftId + `TopKHeap.merge`. No shuffle
+ inside the stage itself — the shuffle is the `ShuffleExchangeExec` above it, inserted
+ by Catalyst from `LanceMergeExec.requiredChildDistribution = ClusteredDistribution(leftId)`.
+ `merge` is associative + commutative so per-partition aggregation is equivalent to the
+ prior `reduceByKey` formulation.
+- **`LanceMaterializeStage`** — opens `LanceProbe` again per task, calls
+ `materialize(rowIds)` which lowers to `_rowid IN (...)` (Lance's row-id lookup path),
+ assembles join rows from the carried left payload + materialized right rows.
+
+The custom Catalyst nodes live in `org.lance.spark.knn.internal.staged`:
+
+- **`StagedPlans.scala`** — three `LogicalPlan` nodes. Critical detail:
+ `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan` override
+ `lazy val references = child.outputSet`. This blocks Catalyst's `ColumnPruning` rule
+ from inserting `Project(Nil)` wrappers between the custom nodes when a downstream
+ consumer references no columns (`count(*)`, etc.) — an oversight in the initial 3-exec
+ implementation that caused `AssertionError` / SIGSEGV at runtime. See `IMPL_PLAN.md`
+ "3-exec staged split — root cause and fix" for the full post-mortem.
+- **`StagedExecs.scala`** — the three `SparkPlan` execs. Each `doExecute` decodes the
+ inter-stage `InternalRow`s via `ProbedLeftCodec.Decoder`, runs the stage primitive, and
+ re-encodes. `LanceMergeExec.requiredChildDistribution` is what triggers the Exchange.
+- **`ProbedLeftCodec.scala`** — flat inter-stage schema (`_leftId`, leftSchema fields
+ inlined, `_refs: array>`). Single `ExpressionEncoder` pass for
+ encode + direct `InternalRow` accessors for decode; earlier multi-pass codec attempts
+ introduced binary-layout issues.
+- **`LanceKnnStagedStrategy.scala`** — registered once per session; lowers each logical
+ plan to its matching exec.
+- **`LanceKnnDatasetBridge.scala`** (in `org.apache.spark.sql` package) — one-method
+ trampoline to the package-private `Dataset.ofRows`.
+
+Phase 2's `IndexedNearestByJoinRule` emits the same three logical plans described above.
+Both the DataFrame API path and the SQL path lower through `LanceKnnStagedStrategy` into
+the identical `LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec`
+chain — the Catalyst rule is the only SQL-specific piece.
+
+### Why `_rowid` not `_rowaddr`
+
+Lance has two virtual columns that identify a row:
+
+| Column | Encoding | Available on |
+|---|---|---|
+| `_rowaddr` | physical address `(frag_id << 32) \| row_in_frag` | non-indexed scans only |
+| `_rowid` | logical Lance-assigned ID | indexed AND non-indexed scans |
+
+The probe stage emits one row-identifier value per ranked refsult; the materialize stage
+filters by that same identifier (`
IN (rowIds...)`) for point-fetch. So the column
+choice has to work on both code paths the probe stage exercises.
+
+The Phase 0/1 prototype used `_rowaddr` because it's the natural physical pointer. That
+worked on the no-index path. When the IVF-PQ recall test built an actual vector index, the
+probe failed with:
+
+```
+LanceError(Schema): Schema error: No field named _rowaddr. Did you mean '_rowid'?
+```
+
+Lance's indexed nearest-search materializes `_rowid` but not `_rowaddr`. So the whole
+pipeline uses `_rowid` (shipped as part of the Phase 3 hardening commit):
+
+- `ScanOptions.Builder.withRowAddress(true)` → `withRowId(true)` in `LanceProbe.probe`.
+- `_rowaddr IN (...)` → `_rowid IN (...)` in `LanceProbe.materialize` filter.
+- `LanceProbe.RowAddressColumn` constant → renamed to `RowIdColumn`, sourced from
+ `LanceConstant.ROW_ID`.
+
+Behavior on the no-index path is identical (both columns work there); the indexed path now
+works at all. The existing oracle test still passes — `_rowid` lookups via the row-id index
+have the same point-fetch semantics as `_rowaddr` lookups did on the row-address index.
+
+Variable names elsewhere (`rowAddrs: Seq[Long]`, `extractRowAddr`, `ScoredRowRef.rowAddr`)
+are retained for source-compat — the field type and lookup semantics are unchanged, only
+the underlying virtual column the value is read from is different.
+
+### Bandwidth math
+
+The substantive performance argument for the staged design is shuffle bandwidth.
+
+```
+Brute-force rewrite (Spark default):
+ cross-product = |L| × |R| rows shipped through shuffle
+ payload = full right row ~hundreds of bytes to KBs
+
+Indexed staged pipeline:
+ shuffle volume = |L| × N × K refs N = number of probe tasks
+ payload per ref = ~24B (8B addr + 4B score + overhead)
+```
+
+For `|L| = 10⁶, |R| = 10⁹, N = 100, K = 10`:
+- brute-force = 10⁶ × 10⁹ = 10¹⁵ pair evaluations.
+- staged = 10⁶ × 100 × 10 = 10⁹ refs through shuffle.
+
+That's six orders of magnitude. The win comes from late materialization — the probe stage
+emits refs only, and the materialize stage fetches payloads after the merge has already
+narrowed to top-K.
+
+### Why Lance brute-force still beats Spark cross-product (no index needed)
+
+A subtle finding from the benchmark: even WITHOUT a vector index, Lance's per-fragment scan
+beats Spark's `RewriteNearestByJoin` (`min_by_k` + `BroadcastNestedLoopJoin`) by ~17× on
+the same |L|×|R| pair-evaluation workload. Both paths do 10M pair evaluations on the
+small-scale benchmark; Spark takes 3,700 ms, Lance takes 220 ms. Why:
+
+1. **Native SIMD vs JVM expression evaluation.** Lance's distance kernel is hand-tuned Rust
+ with AVX-512 / NEON intrinsics. One L2 distance over a dim-128 vector takes ~8 cycles in
+ the SIMD kernel. Spark's `vector_l2_distance` is a `RuntimeReplaceable` lowered to
+ `StaticInvoke(VectorFunctionImplUtils.vectorL2Distance)` — JVM bytecode through Catalyst
+ expression evaluation per row. JIT can auto-vectorize the inner loop but loses to
+ hand-written intrinsics by ~5-10×.
+2. **Columnar contiguous arrays vs per-row deserialization.** Lance stores the vector column
+ as a contiguous Arrow `float32` array per fragment; the kernel iterates contiguous memory
+ (cache-friendly, prefetchable). Spark feeds each right row through Catalyst's iterator —
+ each `ArrayType(FloatType)` cell goes through `UnsafeArrayData.toFloatArray()` per-row
+ (per-row malloc + scattered loads).
+3. **Catalyst expression-evaluation overhead.** Spark's `RewriteNearestByJoin` lowers to
+ `Generate(Inline(Aggregate(min_by_k(struct(right.*), distance(L, R), K),
+ BroadcastNestedLoopJoin(LeftOuter, leftWith__qid, right))))`. Each (L, R) pair passes
+ through ~5 layers of expression evaluation: BNL row iterator → tag projection →
+ aggregate input projection → `vector_l2_distance` evaluation → `min_by_k` heap update.
+ Lance's path is just: scanner pulls a column batch, kernel iterates the float array,
+ updates a top-K heap. No JOIN, no struct serialization, no broadcast.
+
+Rough math sanity check (small scale, 10M pair evaluations across 18 cores):
+
+```
+Spark: 3,676 ms / 10M = 370 ns/pair (≈80 cycles/core/pair @ 4 GHz)
+Lance: 223 ms / 10M = 22 ns/pair (≈ 5 cycles/core/pair @ 4 GHz)
+```
+
+The Lance number is consistent with AVX-512 doing 16 float ops/cycle on dim-128 vectors
+(128/16 ≈ 8 cycles for the math, plus heap-maintenance and Arrow pointer dereferences).
+The Spark number is consistent with ~5 layers of Catalyst expression evaluation overhead
+on top of the underlying SIMD math.
+
+**Implication for the design.** The 17× SQL speedup (608× DataFrame) is real on a
+*no-index* dataset. An index then adds another order of magnitude on top. So the staged
+pipeline's value isn't conditional on having a vector index — Lance's native scan plus
+fragment-local parallelism beats Catalyst's per-pair JVM overhead even on the brute-force
+path. The index is a multiplier, not a prerequisite.
+
+### Recall
+
+Lance's vector indexes are approximate (IVF-PQ probes a subset of the inverted file). Recall is
+tunable via `nprobes` and an overfetch ratio (probe `K × overfetch`, then trim to `K` after
+the merge). Without an index, Lance falls back to brute-force per-fragment scan, which is
+exact (recall = 1.0) — that's how the Phase 0 oracle test is constructed.
+
+## Phases
+
+### Phase 0 — pure DataFrame API
+**Module: `lance-spark-knn_2.12`**.
+
+`IndexedNearestJoin.apply(left, lanceUri, leftVecCol, rightVecCol, k, ...)` Scala function. Pure
+RDD primitives wrapped around the `LanceProbe` per-task primitive. **No shuffle** — probe and
+materialize run in the same `mapPartitions` block. Existed first to validate per-task Lance
+access works at all and to baseline recall against a brute-force oracle.
+
+### Phase 1 — staged RDD pipeline + 3-exec Catalyst split (production path)
+**Module: `lance-spark-knn_2.12`**.
+
+Phase 0's inline `mapPartitions` was split into three stage objects (probe / merge /
+materialize). Public API unchanged. The DataFrame API path then split further into three
+explicit `SparkPlan` operators (`LanceProbeExec` / `LanceMergeExec` / `LanceMaterializeExec`)
+with a Catalyst-inserted `ShuffleExchangeExec` between probe and merge — this is the
+production path today, AQE-visible merge shuffle.
+
+The 3-exec split had a noteworthy debugging history during development: an early
+implementation produced reproducible `AssertionError` / SIGSEGV on `count()`-style
+consumers. Initial diagnosis blamed a JVM-aarch64 + JIT C2 interaction; that was wrong.
+The real root cause was Catalyst's `ColumnPruning` rule inserting `Project(Nil)`
+wrappers between the custom nodes when downstream consumers referenced no columns; the
+project codegens to 0-field `UnsafeRow`s which crash `ProbedLeftCodec.Decoder` at
+`ir.getLong(0)`. The fix is a `references = child.outputSet` override on the Merge /
+Materialize logical plans, which short-circuits ColumnPruning's subset guard. See
+`IMPL_PLAN.md` "3-exec staged split — root cause and fix" for the full post-mortem.
+
+The shuffle is structurally present but degenerate when `probeParallelism = 1` — each
+`leftId` has one contributor, so the merge function never fires. The full bandwidth win
+lands when fragment-grouping arrives in Phase 1.5.
+
+### Phase 2 — Catalyst integration
+**Module: `lance-spark-knn-4.2_2.13`**. Spark 4.2-SNAPSHOT only.
+
+A `postHocResolutionRule` pattern-matches `NearestByJoin(approx = true, ...)` over a Lance scan
+with a recognized vector-distance ranking expression and rewrites it to the same
+three-logical-plan tree produced by the DataFrame API path
+(`LanceProbeLogicalPlan → LanceMergeLogicalPlan → LanceMaterializeLogicalPlan`). The shared
+`LanceKnnStagedStrategy` then lowers that tree to the Probe/Merge/Materialize execs — SQL
+and DataFrame paths converge on the same physical shape.
+
+After Phase 2, users get the indexed path automatically from `APPROX NEAREST` SQL queries.
+EXACT queries and unsupported shapes flow through to Spark's existing brute-force rewrite —
+no functional regression.
+
+#### Why `injectPostHocResolutionRule`, NOT `injectOptimizerRule`
+
+This is the single most important detail in Phase 2. Spark 4.2's optimizer:
+
+```
+Optimizer
+ ├ Batch "Finish Analysis" ← RewriteNearestByJoin lives in here
+ │ ├ ReplaceExpressions (RuntimeReplaceable → StaticInvoke)
+ │ ├ ...
+ │ ├ RewriteNearestByJoin (NearestByJoin → cross-product + MaxMinByK)
+ │ └ ...
+ ├ Batch "Operator optimization batch"
+ │ └ rules added by injectOptimizerRule fire HERE
+ └ ...
+```
+
+By the time `injectOptimizerRule` rules fire, `RewriteNearestByJoin` has already replaced
+`NearestByJoin` with a cross-product + `MaxMinByK` plan. Nothing left for us to pattern-match.
+
+`injectPostHocResolutionRule` runs immediately after analysis — *before* the optimizer starts.
+We see the unrewritten `NearestByJoin` and the unreplaced `VectorL2Distance` /
+`VectorCosineSimilarity` / `VectorInnerProduct` ranking expressions. This same constraint
+applies to *any* engine wanting to substitute a different physical strategy for `APPROX NEAREST`
+queries.
+
+#### Pattern-match preconditions
+
+The rule rewrites only when ALL of these hold:
+
+| Check | Why |
+|---|---|
+| `approx = true` | EXACT mode is contractually deterministic; brute-force keeps owning it. |
+| `right` resolves to a Lance DSv2 relation (under at most a `SubqueryAlias`) | We need a URI to probe. Detection is class-name-based (`getClass.getName.contains("Lance")`); the URI comes from `options.get("path")` / `options.get("datasetUri")`. The probe + materialize stages are Lance-specific by construction (they call Lance's Java API directly), so there's no general "any vector backend" plug-in point — keeping detection simple is consistent with that. |
+| Ranking is one of `VectorL2Distance` (with `NearestByDistance`), `VectorCosineSimilarity` (with `NearestBySimilarity`), `VectorInnerProduct` (with `NearestBySimilarity`) | Direction must match the function's natural ordering. Lance's index supports L2 / cosine / dot — anything else has no fast path. |
+| Both arguments of the ranking function are bare attributes, one from each side | Mixed-side or composed expressions have no clean mapping to a Lance probe. Phase 3 may extend. |
+| `spark.lance.knn.indexedNearestByJoin.enabled = true` | Opt-in until Phase 3 cost gating lands. |
+
+When ANY condition fails the rule returns the plan unchanged and Spark's brute-force rewrite
+handles the query — no regression.
+
+#### Logical plans
+
+The rule emits the same three logical plans described in the "Architecture" section
+(`LanceProbeLogicalPlan` / `LanceMergeLogicalPlan` / `LanceMaterializeLogicalPlan`), wrapped
+in a `Project` that restores the original `NearestByJoin.output` attribute-for-attribute
+(including `ExprId`s) so any parent operator's references stay resolved — same contract
+`RewriteNearestByJoin` honors.
+
+The right-side Lance scan is **absorbed into the probe plan's config**, not kept as a child.
+Why: if right were still a child, Catalyst would happily plan a separate scan of it,
+defeating the whole optimization. The trade-off is that column-pruning / filter pushdown
+that would normally happen on the right side no longer happens automatically; the rule
+captures the projection set (and, for `SELECT * FROM lance WHERE ...`, the filter predicate)
+at rewrite time instead.
+
+#### Physical plans
+
+Exactly the same physical chain as the DataFrame API path — `LanceProbeExec →
+ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec` under `AdaptiveSparkPlanExec`.
+`LanceKnnStagedStrategy` is shared between both paths; the Catalyst rule is the only
+SQL-specific piece.
+
+The Project-above-Materialize also strips the trailing `__score` column, since
+`NearestByJoin`'s output contract doesn't include it (Phase 1 emits it internally for the
+probe/merge aggregation).
+
+### Phase 1.5 — fragment-grouped probing
+**Module: `lance-spark-knn_2.12`** (same as Phase 0/1).
+
+Adds an opt-in `probeParallelism: Int = 1` parameter on `IndexedNearestJoin.apply`. When > 1:
+
+1. Driver enumerates Lance fragment IDs via `Dataset.getFragments()`
+ (`internal/LanceFragments.scala`).
+2. Round-robins into N groups (or LPT bin-packs by per-fragment row count when
+ `balanceFragmentsByRowCount = true`); broadcasts the assignment to every
+ `LanceProbeExec` task.
+3. `LanceProbeStage.runWithFragmentGroups` replicates each left row across the N groups
+ via `flatMap`, then `partitionBy(HashPartitioner(N))` so each task processes a single
+ group only. Each task opens `LanceProbe` with its group's `fragmentIds`.
+4. Output keyed by `leftId` with N contributions per leftId. The merge stage is a
+ `LanceMergeExec` with `requiredChildDistribution = ClusteredDistribution(leftId)` —
+ Catalyst's `EnsureRequirements` inserts a `ShuffleExchangeExec` above it, and the exec
+ then per-partition groups by leftId and applies `TopKHeap.merge`. (The `partitionBy`
+ shuffle inside step 3 remains RDD-level and is NOT AQE-visible — tracked as Phase 3.x
+ future work. The merge-side Exchange inserted above step 4 IS AQE-visible.)
+
+This is where the bandwidth win the rest of this doc promises actually lands — Phase 1
+had the staged shape but a degenerate shuffle (one contributor per leftId, merge function
+never fired). Phase 1.5 makes the merge stage do real work.
+
+**Edge case** — when `probeParallelism > numFragments`, only one group has fragments and
+the rule degenerates back to the Phase 1 single-task path, avoiding a replicate-shuffle
+for nothing.
+
+**Cost** — two shuffles (replicate + merge) instead of Phase 1's one. Justified by the
+bandwidth win at scale; for tiny datasets stick with `probeParallelism = 1`.
+
+### Phase 3 — hardening (partial)
+
+Done:
+
+- **`refineFactor` and `ef`** — IVF-PQ re-rank pass and HNSW search depth, plumbed through to
+ `Query.Builder` calls in `LanceProbe.probe`.
+- **Row-count-aware fragment grouping** — `balanceFragmentsByRowCount` flag uses LPT greedy
+ bin-packing (4/3-optimal-makespan approximation) on `FragmentMetadata.getNumRows`.
+- **Prefilter pushdown** — when the right side of `NearestByJoin` is a `Filter` over a Lance
+ scan, `IndexedNearestByJoinRule` translates the predicate to a Lance SQL filter string and
+ threads it through `LanceProbeStage.Conf.prefilter` → `ScanOptions.filter()`. Lance applies
+ it BEFORE the index lookup (`prefilter = true` is always set), so top-K is computed over
+ only matching rows. Critical for correctness, not just perf: without it, a vector probe could
+ return K rows that are all later filtered out, masking truly-nearest-but-also-matching
+ rows further down the index.
+
+ Translation is conservative: bare `attr literal` comparisons, `IN`, `IS [NOT] NULL`, and
+ `AND`/`OR`/`NOT` over right-side attrs only. Anything else (UDFs, computed sub-expressions,
+ predicates touching the LEFT input) makes the rule REFUSE the rewrite and fall through to
+ Spark's brute-force cross-product. Refusal — not partial pushdown — because dropping a
+ residual conjunct would silently change result semantics; a slow-but-correct query is the
+ acceptable failure mode.
+
+ Project unwrapping: `SELECT * FROM lance WHERE ...` analyzes to
+ `Project(, Filter(cond, lance))`; the rule unwraps both. Non-passthrough
+ Projects (renames, drops, computed columns) fail the unwrap check and fall through.
+
+- **3-stage explicit physical operators (DataFrame API path)** — **Done.** The production
+ path is now `LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec`
+ under `AdaptiveSparkPlanExec`. `df.explain()` shows all four nodes; AQE coalesces the
+ merge shuffle (`AQEShuffleRead coalesced`). An early SIGSEGV during development was
+ misattributed to JVM-aarch64; the real cause was Catalyst's `ColumnPruning` rule
+ inserting `Project(Nil)` wrappers between the custom nodes when downstream consumers
+ referenced no columns. Fix: `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan`
+ override `lazy val references = child.outputSet`, short-circuiting `ColumnPruning`'s
+ subset guard. See `IMPL_PLAN.md` "3-exec staged split — root cause and fix" for the
+ full post-mortem. 60 tests pass.
+
+- **`df.kNearestJoin` DataFrame extension** — `LanceKnnImplicits._` provides an extension
+ method that hangs off any `DataFrame`, mirrors `df.join(other, ...)`, and works on
+ Spark 3.5 / 4.0 / 4.1 / 4.2+. It extracts the Lance URI from the right DataFrame's
+ analyzed plan automatically; non-Lance right sides (parquet, in-memory, alias-wrapped
+ non-Lance, etc.) fail fast with `IllegalArgumentException` naming the constraint.
+
+Outstanding (Phase 3.x — see `IMPL_PLAN.md` for the full table):
+
+- Cost gate replaces opt-in flag.
+- Spark version CI matrix (compile+test verified on 3.5 and 4.0 via the
+ reflection bridge; formal CI job still TODO).
+- AQE-visible shuffle for the fragment-grouped probe path
+ (`runWithFragmentGroups`'s internal `partitionBy` remains RDD-level; the
+ merge-side shuffle IS AQE-visible).
+- Per-executor `LanceProbe` cache to amortize dataset-open across small
+ partitions.
+- Left-side skew handling (today only the right's fragment groups are
+ balanced).
+
+## Public surface
+
+### DataFrame API
+
+Two equivalent forms — pick whichever fits the call site.
+
+**Idiomatic extension method.** Lives on every `DataFrame`, hangs off the left side the
+same way `join` does, and works on every Spark version the connector supports (3.5, 4.0,
+4.1, 4.2+). The right side must be a Lance scan
+(`spark.read.format("lance").load(uri)`); the extension extracts the Lance URI from the
+right DataFrame's analyzed plan automatically.
+
+```scala
+import org.lance.spark.knn.LanceKnnImplicits._
+
+val docs = spark.read.format("lance").load("/path/to/lance")
+val joined = queries.kNearestJoin(
+ right = docs,
+ leftVecCol = "qvec",
+ rightVecCol = "vec",
+ k = 10,
+ metric = "l2") // l2 | cosine | dot
+```
+
+A `Filter` / `SubqueryAlias` / `Project(passthrough)` over the right Lance scan is
+unwrapped before URI extraction; passing a non-Lance DataFrame throws
+`IllegalArgumentException` with a message naming the constraint.
+
+**URI form.** When you don't have a `DataFrame` for the right side and just have a path
+string (e.g. early in a job before Spark sees the dataset), call the underlying
+`IndexedNearestJoin.apply` directly.
+
+```scala
+import org.lance.spark.knn.IndexedNearestJoin
+
+val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = "/path/to/lance",
+ leftVecCol = "qvec",
+ rightVecCol = "vec",
+ k = 10,
+ metric = "l2")
+```
+
+Both forms return a `DataFrame` with schema `left.* ++ right.* ++ __score`.
+
+### SQL (Phase 2 path)
+
+```scala
+SparkSession.builder()
+ .config("spark.sql.extensions",
+ "org.lance.spark.knn.extensions.LanceKnnSparkSessionExtensions")
+ .config("spark.lance.knn.indexedNearestByJoin.enabled", "true")
+ .getOrCreate()
+```
+
+Then any:
+
+```sql
+SELECT *
+FROM left l [INNER | LEFT OUTER] JOIN right_lance r
+ APPROX NEAREST k BY {DISTANCE | SIMILARITY} f(l.vec, r.vec)
+```
+
+— rewrites automatically, where `f` is `vector_l2_distance` (with `DISTANCE`),
+`vector_cosine_similarity` (with `SIMILARITY`), or `vector_inner_product` (with `SIMILARITY`).
+Output schema matches `NearestByJoin.output` (no score column — the user can compute that in a
+project if needed).
+
+The extension can coexist with the connector's `LanceSparkSessionExtensions` in a comma-
+separated `spark.sql.extensions` value.
+
+## Reviewer's reading order
+
+Reviewers should start with **[`REVIEWER_GUIDE.md`](REVIEWER_GUIDE.md)** —
+it's the up-to-date "start here → engine → primitives" reading path, with
+a test map and trust-but-verify checklist. The below lists the specific
+Catalyst-side files to read for a review of the SQL path
+(`lance-spark-knn-4.2_2.13`), in order:
+
+1. **`catalyst/IndexedNearestByJoinRule.scala`** — the load-bearing pattern
+ match. The `for-yield` in `rewriteIfApplicable` short-circuits on every
+ precondition, then builds the three logical plans (probe/merge/materialize)
+ wrapped in a `Project` that restores the original `NearestByJoin.output`.
+2. **`extensions/LanceKnnSparkSessionExtensions.scala`** — smallest file.
+ Confirm the wiring uses `injectPostHocResolutionRule` (not
+ `injectOptimizerRule`; see reasoning in the Phase 2 section above) and
+ registers the shared `LanceKnnStagedStrategy`.
+
+For the shared logical plans, physical execs, and strategy, see the DataFrame
+API path in `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/`.
+
+## Test coverage
+
+| Test class | What it covers |
+|---|---|
+| `internal/TopKHeapTest` | Metric-aware ordering, eviction, drain order, merge. Pure unit. |
+| `internal/LanceFragmentsTest` | Round-robin and LPT bin-packing math. |
+| `internal/LanceProbeValidationTest` | Real Lance dataset; brute-force-equivalence oracle. |
+| `IndexedNearestJoinTest` | Phase 0 e2e; brute-force-equivalence oracle. Refine-factor wiring. |
+| `IndexedNearestJoinPlanShapeTest` | 3-exec plan shape (LanceProbe / LanceMerge / LanceMaterialize / Exchange all present). |
+| `IndexedNearestJoinAqeVisibilityTest` | Exchange hashpartitioning on `_leftId`; AQE wrap; no `!` missingInput prefix. |
+| `IndexedNearestJoinConsumerShapeTest` | `count()`, `agg(count("*"))`, `select(lit(1))`, `collect()` all succeed — regression for the ColumnPruning crash. |
+| `IndexedNearestJoinCorrectnessTest` | Brute-force oracle at 1K × 100 × dim 16 × K 10. |
+| `IndexedNearestJoinJitStressTest` | crossJoin warmup + 20 iterations at 10K × 100 × dim 128 — durability at benchmark scale. |
+| `internal/staged/StagedPlansReferencesTest` | Structural pin on `references = child.outputSet` override. |
+| `IndexedNearestJoinFragmentGroupingTest` | Phase 1.5 oracle equivalence + plan-shape with `probeParallelism > 1`. |
+| `catalyst/IndexedNearestByJoinRuleTest` | Phase 2 rule pattern-match (positive + negative cases); asserts the emitted `Project(LanceMaterialize(LanceMerge(LanceProbe)))` tree shape. |
+| `catalyst/IndexedNearestByJoinE2ETest` | SQL `APPROX NEAREST` against real Lance, Spark 4.2-SNAPSHOT. Rule on/off plus WHERE-pushdown oracle equivalence. |
+
+## Benchmark validation
+
+Both benchmarks (`IndexedNearestJoinBenchmark` for DataFrame, `IndexedNearestByJoinSqlBenchmark`
+for SQL) run a **pre-timing validation step** that compares EVERY config — including the
+slow baseline / rule-OFF path — against an in-memory brute-force oracle on a 16-row left
+subset. The benchmark `sys.error`'s out before timing if any config disagrees with ground
+truth, so the quoted speedups are on equivalent results.
+
+The 16-row subset keeps the slow baseline tractable: `O(16 × |R|)` cross-product evaluations
+is sub-second even at medium scale (16 × 1M = 16M pair evaluations). The full benchmark is
+unchanged in scope; the validation step is small relative to the timed runs.
+
+### Latest validated numbers — Apple M5 Max, 18 cores, 48 GB
+
+DataFrame benchmark (small scale, |R|=100K, |L|=100, dim=128, K=10):
+
+```
+Sanity check: all 5 configs match brute-force oracle (sample size: 16) ✅
+A: Spark crossJoin baseline 109,373 ms 1.00×
+B: Phase 0/1 (probeParallelism=1) 180 ms 608×
+C: Phase 1.5 (probeParallelism=4) 288 ms 380×
+D: Phase 1.5 (probeParallelism=8) 276 ms 396×
+E: Phase 1.5 (G=8, skew-balanced) 277 ms 395×
+```
+
+SQL benchmark (small scale, same shape):
+
+```
+Sanity check: rule ON and rule OFF agree on top-K (sample size: 16) ✅
+A: rule OFF (Spark RewriteNearestByJoin) 3,728 ms 1.00×
+B: rule ON (3-exec staged + shared strategy) 214 ms 17.4×
+```
+
+Why the SQL number is smaller: Spark's `RewriteNearestByJoin` is itself optimized — it
+lowers to a `min_by_k` heap aggregate over a `BroadcastNestedLoopJoin`, which avoids
+materializing all `|L|×|R|` pairs in JVM memory. The remaining 17.4× comes from delegating
+per-fragment scans to Lance's native vector kernels (AVX-512 / NEON) instead of evaluating
+`vector_l2_distance` row-by-row in the JVM.
+
+See `BENCHMARK_RESULTS.md` for the medium-scale numbers, the honest "Phase 1.5 doesn't help
+locally" finding, and where fragment-grouping would win (distributed cluster, object-store-
+backed Lance, or right side too large for one machine).
+
+### Cluster validation — OSS Spark 3.5
+
+The local numbers above are Apple M5 Max single-machine. The indexed path has also been
+validated on a real distributed OSS Spark 3.5 cluster (8 × 4 core × 16 GB executor
+pods, multi-tenant infrastructure):
+
+- **CohereLabs `wikipedia-2023-11-embed-multilingual-v3` (dim=1024, real embeddings),
+ 1K base × 50 queries:** indexed path is **100–200× faster than Spark crossJoin**
+ (7-iter median: 64.9 s → 406 ms at E: Phase 1.5 G=8 skew-balanced, = 160×). Exact
+ multiplier varies ±20% across runs due to multi-tenant CPU contention; the
+ order-of-magnitude story is robust. Measured with `write.format("noop")` timing
+ sink and a 16-row brute-force oracle gating correctness before each run. The
+ speedup **grows** with dim (128 → 1024) because Lance's SIMD kernel advantage
+ widens vs Spark's JVM UDF overhead.
+
+- **Synthetic medium (|R|=1M, |L|=1000, dim=128) on the same cluster:** Phase 1.5 D/E
+ **win** at ~55 s vs Phase 0/1 B at ~92 s. This is the OPPOSITE of the local-laptop
+ finding — when `probeParallelism == numFragments`, cross-machine parallelism across 8
+ executor JVMs beats Lance-internal threading on one machine. Two independent runs agree
+ within 2%.
+
+- **SIFT1M IVF-FLAT recall@10:** 0.98 at nprobes=16, 0.999 at nprobes=64 — within noise of
+ published FAISS numbers.
+
+The cluster results also surfaced two production-grade constraints documented in
+`BENCHMARK_RESULTS.md` § "Cluster benchmarks": a vendor-specific executor-count knob
+(some managed Spark distributions ignore `spark.executor.instances` in
+standalone-per-app mode and require their own knob), and
+`spark.{driver,executor}.extraClassPath` (required when the cluster's bundled Arrow-C
+version is older than ours and would otherwise be first on the classpath).
+
+**Real-backend e2e** (`IndexedNearestByJoinE2ETest`) — Phase 2 module ships a SQL-level e2e
+test against an actual Lance dataset on Spark 4.2-SNAPSHOT. The trick: lance-spark-4.1's
+source compiles cleanly against 4.2-SNAPSHOT and the resulting jar runs on 4.2's DSv2 API,
+so we recompile it locally and use it as the runtime Lance reader. Two test cases:
+
+- Rule on → the 3-exec chain (`LanceProbe` / `LanceMerge` / `LanceMaterialize`) is in the
+ executed plan, results match brute-force oracle.
+- Rule off → falls through to Spark's `RewriteNearestByJoin` (cross-product + `min_by_k`),
+ results still match oracle. Confirms the opt-in fallback path doesn't break correctness.
+
+## What this is NOT
+
+- **Not a brute-force fallback.** That's `RewriteNearestByJoin` in Spark, kept for EXACT
+ queries and unindexed cases.
+- **Not a re-implementation of Lance's index.** We delegate every probe to Lance.
+- **Not a vector-DB-style serving layer.** This is for batch joins inside Spark pipelines.
diff --git a/lance-spark-knn_2.12/IMPL_PLAN.md b/lance-spark-knn_2.12/IMPL_PLAN.md
new file mode 100644
index 000000000..096000430
--- /dev/null
+++ b/lance-spark-knn_2.12/IMPL_PLAN.md
@@ -0,0 +1,174 @@
+# Indexed Nearest-By-Join — Implementation Plan
+
+This module adds an indexed approximate-nearest-neighbor (ANN) join strategy on top of Lance's vector indexes, exposed through Spark. It complements Spark's built-in `NearestByJoin` (Spark 4.x), which today only has a brute-force cross-product rewrite.
+
+## Goal
+
+For SQL like:
+
+```sql
+SELECT * FROM queries q
+LEFT OUTER JOIN documents d
+APPROX NEAREST 10 BY DISTANCE l2_distance(q.vec, d.vec)
+```
+
+— or its Scala/Python DataFrame equivalent — execute via per-fragment Lance index probes plus a Spark-side merge, instead of an `O(|L| * |R|)` cross-product.
+
+## Why this works on Lance specifically
+
+- Lance's vector indexes (IVF-PQ, HNSW) are **fragment-local**. Each fragment carries its own index files; per-fragment probes are independent.
+- Lance's Java API exposes single-vector nearest search via `org.lance.ipc.Query` + `LanceScanner.create(dataset, ScanOptions, allocator)`. lance-spark already uses this for single-query reads (`LanceFragmentScanner.create`).
+- Lance row IDs (`_rowid` virtual column) make late materialization cheap: the probe phase emits row references, the materialize phase fetches full rows by ID. (We initially used `_rowaddr` but switched to `_rowid` because the indexed nearest-search path materializes `_rowid` only — see DESIGN.md "Why `_rowid` not `_rowaddr`".)
+- Snapshot pinning (Lance versioning) gives consistent results across distributed tasks.
+
+## Three-phase distributed design
+
+```
+ProbeExec (one task per fragment-group)
+ ├── opens Lance dataset, restricted to assigned fragments
+ ├── per left row: maintains local top-K' heap across owned fragments (map-side combine)
+ └── emits (left_id, [row_addr, score]) ← refs only, no payload
+ │
+ Exchange (hash by left_id) ← lightweight: O(|L| × N × K)
+ │
+ ▼
+MergeExec
+ ├── K-way merge across N task contributions per left_id
+ └── emits (left_id, [chosen_row_addr, score])
+ │
+ ▼
+MaterializeExec
+ ├── for each chosen row_addr: fetch right row from Lance (random access)
+ └── emits full join rows
+```
+
+The merge is across at most `N = numTasks` contributions per left row (not across `F` fragments) because each task does a map-side combine — this is the difference between a manageable shuffle (~hundreds of GB at scale) and a catastrophic one (~tens of TB).
+
+## Rollout phases
+
+### Phase 0 — Pure Scala API (this module's first deliverable)
+
+A `IndexedNearestJoin.apply(left, lanceTablePath, ...)` Scala function that takes a left DataFrame and produces a result DataFrame, with no Catalyst integration whatsoever. Pure RDD primitives (`mapPartitions`, `reduceByKey`) on top of a `LanceProbe` JVM helper that wraps the Lance Java vector-search API.
+
+This phase proves:
+- Per-fragment Lance probes work from JVM tasks
+- The shuffle volume math holds up
+- Recall vs. brute-force is acceptable
+- The deferred-materialization pattern delivers the expected payload-size win
+
+Phase 0 ships **before** any Spark version dependency is needed (works on any Spark version lance-spark supports), making benchmarks possible without committing to Spark 4.x-only code paths.
+
+### Phase 1 — Refactor execution into Spark physical operators
+
+Same logic, but as proper `SparkPlan` subclasses (`IndexedNearestProbeExec`, `IndexedNearestMergeExec`, `IndexedNearestMaterializeExec`). The Phase 0 API becomes a thin wrapper that constructs the exec triple. Pure refactor — no new functionality, no new tests beyond plan-shape tests.
+
+### Phase 2 — Catalyst integration (Spark 4.x only)
+
+A `postHocResolutionRule` that pattern-matches `NearestByJoin(approx = true, ...)` over a qualifying Lance scan with a vector index, replacing it with an `IndexedNearestByJoin` logical node. A custom strategy maps that to the physical execs from Phase 1. Wired via `SparkSessionExtensions`.
+
+Critical detail: Spark's `RewriteNearestByJoin` runs in the optimizer's first batch (`FinishAnalysis`). `injectOptimizerRule` would fire **after** the rewrite and miss the `NearestByJoin`. Only `injectPostHocResolutionRule` (which runs before the optimizer entirely) gives us access to the unrewritten operator.
+
+After Phase 2, users get the indexed path automatically from `APPROX NEAREST` SQL queries — no API change.
+
+### Phase 3 — Hardening
+
+Cost gate, recall tuning (`overfetch`, `refineFactor` re-rank pass), filter pushdown into the probe (`prefilter`), skew handling, full docs, cross-version build matrix.
+
+## Open questions / risks
+
+These are tracked as Phase 0 validation tasks:
+
+1. **Per-task Lance dataset open cost.** Each probe task opens a Lance dataset. If expensive, need a per-executor singleton cache.
+2. **`Query` per-call overhead.** `Query` is constructed per left row. If hidden setup cost exists, batch internally or push for batched Lance API.
+3. **`_rowid` filter performance.** Materialization relies on `WHERE _rowid IN (...)` resolving to point fetches inside Lance. Confirm this isn't a scan-with-post-filter. (Original Phase 0 question — switched from `_rowaddr` to `_rowid` for indexed-path compatibility; same point-fetch semantics.)
+4. **Concurrent fragment scans within a task.** Can we run multiple `LanceScanner` instances in parallel within one task, or is Lance single-threaded at the dataset level?
+5. **`Query.queryParallelism` semantics.** May handle some intra-query parallelism for free.
+6. **`Dataset.listIndexes()` availability and metadata.** Required for Phase 2 capability detection.
+7. **Snapshot pinning across stages.** All three execs must read at the same Lance version. lance-spark's read options carry this; needs explicit test.
+
+## What this module is NOT
+
+- Not a brute-force fallback — that's `RewriteNearestByJoin` in Spark, kept for `EXACT` queries and unindexed cases.
+- Not a re-implementation of Lance's index. We delegate every probe to Lance.
+- Not a vector-DB-style serving layer. This is for batch joins inside Spark pipelines.
+
+## Status
+
+| Phase | Status |
+|---|---|
+| 0 — Scala API | Done (`IndexedNearestJoin.apply`, oracle test, LanceProbe primitive) |
+| 1 — Staged RDD pipeline | Done (probe / merge / materialize stages, plan-shape test) |
+| 1.5 — Fragment-grouping | Done (`probeParallelism` parameter; LanceFragments enumeration; oracle equivalence + 2-shuffle plan-shape test) |
+| 2 — Catalyst integration | Done (`lance-spark-knn-4.2_2.13` module: rule, logical, physical, extension; 18 tests including SQL e2e against real Lance + Spark 4.2-SNAPSHOT, with prefilter pushdown coverage) |
+| 3 — Hardening | Partially done — see "What's left" below |
+| 3.x — Explicit physical operators (DataFrame API) | **Done.** Production path is now `LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec`. `df.explain()` shows all four nodes under `AdaptiveSparkPlanExec`; with AQE on, `AQEShuffleRead coalesced` appears on the merge shuffle. An early development iteration hit reproducible SIGSEGV / AssertionError on `count()`-style consumers — misdiagnosed as a JVM-aarch64 bug; the real cause was Catalyst's `ColumnPruning` rule inserting `Project(Nil)` wrappers between the custom nodes when downstream consumers referenced no columns, which codegen'd to 0-field `UnsafeRow`s and crashed `ProbedLeftCodec.Decoder` in interpreter mode (AssertionError) / C2 mode (SIGSEGV). Fix: `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan` override `lazy val references = child.outputSet`, which short-circuits `ColumnPruning`'s `!child.outputSet.subsetOf(references)` guard. 60 tests in lance-spark-knn_2.12. See "3-exec staged split — root cause and fix" below for details. |
+| 3.x — `df.kNearestJoin` extension | Done (`LanceKnnImplicits._`; works on Spark 3.5 / 4.0 / 4.1 / 4.2+; URI auto-extracted from right DataFrame's analyzed plan; non-Lance right side fails fast with `IllegalArgumentException`) |
+| Benchmarks | Done (608× DataFrame, 17.4× SQL — both oracle-validated) |
+
+See `PHASE_PROGRESS.md` for the resume-without-context overview, file inventory, and the
+substantive limitations carried forward from each phase.
+
+## What's left (Phase 3.x)
+
+Phase 3 done so far: `refineFactor` /
+`ef` recall knobs, row-count-aware fragment grouping for skew (`balanceFragmentsByRowCount`).
+
+Phase 3.x — outstanding work:
+
+| Item | Module | Notes |
+|---|---|---|
+| Cost gate replaces opt-in flag | `lance-spark-knn-4.2_2.13` | Heuristic deciding indexed vs. brute-force based on `\|R\|` cardinality and right-side selectivity. Until then `spark.lance.knn.indexedNearestByJoin.enabled` is the gate. |
+| ~~`prefilter` pushdown into Lance probe~~ | DONE | Rule detects `Filter(cond, lance)` (and `Project(, Filter(...))` for `SELECT *` shape), translates the predicate to a Lance SQL filter string, and threads it through `LanceProbeStage.Conf.prefilter` → `ScanOptions.filter()`. Translator handles bare `attr literal` for `=`, `!=`, `<`, `<=`, `>`, `>=`, plus `IN`, `IS [NOT] NULL`, and `AND`/`OR`/`NOT`. Anything else (UDFs, computed expressions, predicates touching the LEFT input) → rule REFUSES the rewrite. No partial pushdown. |
+| ~~Real recall test against IVF-PQ-indexed dataset~~ | DONE | `IndexedNearestJoinIvfPqRecallTest` builds an IVF-PQ index via Lance Java's `Dataset.createIndex` and measures recall@K. With 1024 rows × dim 32 × 4 IVF partitions: recall@10 = 0.73 at defaults, **1.00 with `refineFactor = 8`** (exact-distance re-rank recovers all true neighbors). Surfaced and fixed a real bug in the process — Lance's indexed scan materializes `_rowid` not `_rowaddr`, so the whole pipeline switched to `_rowid` (works on both paths). |
+| ~~`LanceProbe.vectorColumn` cleanup~~ | DONE | `vectorColumn` moved from constructor to per-call `probe()` arg. Materialize stage no longer constructs the probe with a placeholder. |
+| ~~Filter pushdown's interaction with `prefilter = true`~~ | RESOLVED | We always set `prefilter = true` and call `ScanOptions.filter(sql)` from `LanceProbe.probe`. Lance applies the predicate before the index lookup, so the top-K is computed only over matching rows — confirmed by the e2e WHERE-pushdown test against a brute-force-on-filtered oracle. |
+| Spark version matrix for the connector | build infra | The Phase 2 module pins to 4.2-SNAPSHOT. Once `NearestByJoin` lands in a release, re-pin and add 4.2_2.13 module path. Phase 1.5 / Phase 0/1 work on any Spark 3.4+ via the existing connector modules. |
+| ~~Real-backend e2e test for Phase 2~~ | DONE | `lance-spark-knn-4.2_2.13/src/test/.../IndexedNearestByJoinE2ETest.scala`. Recompile `lance-spark-4.1_2.13` against `4.2.0-SNAPSHOT` (its source compiles cleanly against 4.2; runtime API is compatible) and use it as the test-scope Lance reader. Three test cases: rule-on goes through the 3-exec staged chain and matches oracle; WHERE-pushdown round-trips the prefilter and matches the filtered oracle; rule-off falls through to Spark's `RewriteNearestByJoin` and still matches oracle. |
+| ~~Cross-version DataFrame API parity~~ | DONE (compile+test) / TODO (CI matrix) | Module compiles and tests pass against Spark **3.5 AND 4.0**. Single-source validated: flip `spark.version=${spark40.version}` + `arrow.version=${arrow18.version}` + swap test runtime `lance-spark-4.0_2.13`, 41/41 tests pass. One source fix was required — `LanceKnnDatasetBridge` used `org.apache.spark.sql.Dataset.ofRows` which moved to `org.apache.spark.sql.classic.Dataset.ofRows` in Spark 4.0. Replaced with a reflection-based lookup that tries both packages; one cache-miss per Spark session. CI matrix against 3.4 / 3.5 / 4.0 / 4.1 still TODO. End-to-end cluster validation done on OSS Spark 3.5.4. |
+| ~~Production-shape benchmark (real embeddings)~~ | DONE | `WikipediaKnnPerfBenchmark` uses CohereLabs `wikipedia-2023-11-embed-multilingual-v3` (dim=1024). On 8 × 4c/16g OSS Spark 3.5 cluster: indexed path is **100-200× faster than Spark crossJoin** at small scale (7-iter median 160×; noop sink + oracle-verified). Speedup grows with dim (128 → 1024) because Lance's native SIMD advantage widens vs Spark's JVM UDF overhead. `CohereWikiRecallBenchmark` complements with IVF-FLAT recall on the same corpus: 95% recall at nprobes=16, 99% at nprobes=64, 10-16 ms/query. Numbers in `BENCHMARK_RESULTS.md` § "Cluster benchmarks". |
+| ~~SIFT1M ANN-benchmark validation~~ | DONE | `SiftRecallBenchmark` against the canonical `ftp.irisa.fr/.../sift.tar.gz` corpus. OSS Spark 3.5 cluster, 1M × dim 128, IVF-FLAT 256 partitions: recall@10 = 0.98 at nprobes=16, 1.00 at nprobes=64. Within noise of published FAISS numbers. |
+| ~~Sustained concurrent load soak~~ | DONE (harness) / PARTIAL (data) | `IndexedNearestJoinSoakTest` runs N concurrent queries for M minutes while sampling driver heap + GC metrics. 10-min smoke on OSS Spark 3.5 cluster: 492 queries at 8 concurrency, 0 failures, heap stable 163–266 MB (no drift). Harness has a known bookkeeping bug (post-deadline queued queries are counted as failures after pool drain races `spark.stop()`); production qualification at 2–4 hours is deferred. |
+| ~~Real-world-embedding benchmark~~ | DONE | `ClusteredEmbeddings.generate` produces clustered Gaussian-mixture vectors on the unit sphere — the geometry of typical sentence-transformer / image-feature embeddings. Used in `IndexedNearestJoinIvfPqRecallTest.testClusteredEmbeddingsRecallSurvives`, which measures recall@K on production-shaped data and asserts it clears 0.5 at default IVF-PQ settings. Both uniform and clustered recall numbers are printed for comparison; we do NOT assert `clustered >= uniform` because Lance's IVF k-means init is non-deterministic across JVM sessions and the run-to-run noise on a 1024-row dataset routinely exceeds the structural advantage. A reliable comparative would need much larger N or seed-averaging. |
+| ~~AQE-aware partition sizing for the merge stage~~ | **DONE** | `LanceMergeExec` declares `requiredChildDistribution = ClusteredDistribution(leftIdAttr)`; `EnsureRequirements` auto-inserts a `ShuffleExchangeExec` between probe and merge. With AQE enabled, Catalyst wraps the stage under `AdaptiveSparkPlanExec`, applies `CoalesceShufflePartitions` / `OptimizeSkewJoin` / `OptimizeShuffleWithLocalRead`, and the final plan visibly shows `AQEShuffleRead coalesced` on the merge-side shuffle. Verified by `IndexedNearestJoinAqeVisibilityTest`. |
+| Per-task `LanceProbe` reuse / connection pooling | `lance-spark-knn_2.12` | Currently each Spark task opens its own dataset. For very small partitions this dominates cost. A per-executor singleton cache could amortize. |
+| Skew handling for left side too | `lance-spark-knn_2.12` | Phase 1.5 / Phase 3 balance the right side's fragment groups but the left RDD's natural partitioning can still be skewed. Repartition by `leftId` before probe is the obvious next move. |
+
+## 3-exec staged split — root cause and fix
+
+The three-operator staged split (`LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec`) had a debugging detour during development — reproducible `AssertionError: index (0) should < 0` / SIGSEGV in `UnsafeRow.getLong` on JVM-aarch64 was initially blamed on the JVM. That was wrong — this section preserves the investigation for anyone who hits similar "it looks like a JVM bug but it's Catalyst" symptoms in future work.
+
+**Initial JVM-aarch64 diagnosis was wrong.** The crash reproduces on JVM-aarch64 but it isn't a JVM bug. A multi-step isolation found it:
+
+1. `InterStageShuffleReproTest` — synthetic rows at the staged codec's schema, through `repartition(_leftId)` + 100-iteration JIT-stress. **Passes.** Rules out Spark's UnsafeRow shuffle + our schema.
+2. `InterStageShuffleWithLanceReproTest` — same but with rows sourced from a real Lance scan. **Passes.** Rules out the Lance→Spark boundary.
+3. `StagedExecDirectDriveReproTest` — directly drives the staged execs at tiny scale (4 left × 8 right) via `count()`. **Crashes deterministically on first invocation.** Not a JIT issue at all.
+
+Diagnostic instrumentation on `LanceMaterializeExec.doExecute` caught a 0-field `UnsafeRow` arriving from `child.execute()`. `LanceMergeExec` emits correctly-shaped 4-field rows; something between them truncates to 0 fields. Dumping `df.queryExecution.executedPlan` for the `count()` case showed:
+
+```
+*(2) HashAggregate(partial_count(1))
++- *(2) Project ← 0-column projection!
+ +- LanceMaterialize
+ +- *(1) Project ← another 0-column projection!
+ +- LanceMerge
+ +- Exchange hashpartitioning(_leftId)
+ +- LanceProbe
+```
+
+**Root cause:** Catalyst's `ColumnPruning` optimizer rule. Its `Aggregate(child)` guard is `!child.outputSet.subsetOf(a.references)`. For `count(*)`, `Aggregate.references` is empty. The custom logical plans (`LanceMergeLogicalPlan`, `LanceMaterializeLogicalPlan`) inherited `references` from `QueryPlan`'s default — empty for pass-through nodes. That made `child.outputSet.subsetOf(references)` false, and `prunedChild` inserted `Project(Nil)` wrappers. Spark's codegen'd `ProjectExec(Nil)` emits 0-field `UnsafeRow`s. `ProbedLeftCodec.Decoder.decode` reads `ir.getLong(0)` on those rows — in interpreter mode → `AssertionError`; in C2-compiled code → the assertion is elided and the read hits unmapped memory → SIGSEGV. Hence the misleading "JVM-aarch64 bug" blame on the revert.
+
+**Fix.** `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan` override `lazy val references = child.outputSet`. This makes `child.outputSet.subsetOf(references)` trivially true (equality ⇒ subset), short-circuiting `ColumnPruning`'s guard. No `Project(Nil)` ever gets inserted between the custom nodes. `StagedPlansReferencesTest` pins this invariant structurally.
+
+**Verification.**
+
+- Unit-level: `StagedPlansReferencesTest` (3 tests) pins the override on both logical plans and explicitly checks `ColumnPruning`'s subset guard.
+- Plan-shape: `IndexedNearestJoinAqeVisibilityTest` (5 tests) asserts `ShuffleExchangeExec hashpartitioning(_leftId)` is in the executed plan (AQE on and off), AQE wraps with `AdaptiveSparkPlanExec`, all three custom execs appear in the tree, no `!` missingInput prefix.
+- Consumer-shape: `IndexedNearestJoinConsumerShapeTest` (4 tests) — `count()`, `agg(count("*"))`, `select(lit(1))`, `collect()` all succeed. These are the exact shapes that crashed the reverted code.
+- Correctness: `IndexedNearestJoinCorrectnessTest` — recall = 1.0 against brute-force oracle at 1000 right × 100 left × dim 16 × K 10.
+- Durability: `IndexedNearestJoinJitStressTest` — crossJoin JIT warmup + 20 iterations of `collect()` and `count()` at 10K right × 100 left × dim 128 × K 10. No SIGSEGV.
+
+All 60 tests pass in `lance-spark-knn_2.12`.
+
+**AQE now engages on the merge shuffle** — `AQEShuffleRead coalesced` is visible on the post-merge plan when AQE is enabled. This was the whole point of the staged split, and it now works.
+
+**Caveat: plan-level integration, not exec-level.** `LanceMaterializeExec` doesn't declare `requiredChildDistribution`, so the Exchange still lives as a child of `LanceMergeExec`, not inside a larger Catalyst-optimized tree. The fragment-grouped probe path's internal `partitionBy` shuffle is still AQE-invisible.
diff --git a/lance-spark-knn_2.12/NEARESTBYJOIN_ANN_PROPOSAL.md b/lance-spark-knn_2.12/NEARESTBYJOIN_ANN_PROPOSAL.md
new file mode 100644
index 000000000..68c64bf9c
--- /dev/null
+++ b/lance-spark-knn_2.12/NEARESTBYJOIN_ANN_PROPOSAL.md
@@ -0,0 +1,329 @@
+# Indexed `NearestByJoin` via Lance — a PoC for SPARK-56395
+
+**Context:** [SPARK-56395](https://issues.apache.org/jira/browse/SPARK-56395) adds
+`NearestByJoin` as a first-class logical operator in Spark 4.2, currently lowered by the
+built-in `RewriteNearestByJoin` rule to a cross-product + `min_by_k` aggregate
+(`BroadcastNestedLoopJoin` + heap aggregate). The [design
+doc](https://docs.google.com/document/d/1opFVcQJgEWDWUVB7uVlFMlNomRwxqRu8iW0JmvCvxF0/)
+calls out that a true indexed path is out of scope for that ticket but is the natural
+follow-up — "any vector-index-backed data source should be able to intercept
+`NearestByJoin` before the cross-product rewrite and substitute an index-backed plan."
+
+This repo (`lance-spark-knn`) is **one concrete implementation** of that hook, for Lance
+datasets. The purpose of this doc is to share the shape of that implementation with Spark
+maintainers as a reference point, and to sketch how the same pattern could extend to
+non-Lance formats (parquet, delta) via the Lance **sidecar index** pattern — without
+Spark itself having to ship a vector-index backend.
+
+**Not a proposal to change apache/spark.** The PoC lives entirely in an ecosystem
+connector. It depends on only the public extension points SPARK-56395 provides (and one
+pre-existing one, `injectPostHocResolutionRule`). The discussion questions at the end
+are the places where a small Spark-side change *might* help; they are intentionally left
+open for maintainers to weigh in on.
+
+## What's in apache/spark, what's in this connector
+
+| Apache/Spark side | Ecosystem connector side |
+|---|---|
+| `NearestByJoin` logical plan (SPARK-56395) | `IndexedNearestByJoinRule` Catalyst rule (postHocResolutionRule) |
+| `RewriteNearestByJoin` optimizer rule (default: brute-force) | Three `LogicalPlan` + `SparkPlan` nodes that form a Catalyst-visible staged pipeline |
+| `VectorL2Distance` / `VectorCosineSimilarity` / `VectorInnerProduct` ranking expressions | `LanceKnnStagedStrategy` lowering the three logical plans to three execs |
+| Extension points: `injectPostHocResolutionRule`, `injectPlannerStrategy` | Registration via `spark.sql.extensions` — opt-in per-session |
+
+No changes to apache/spark are required for the Lance-specific implementation. The
+hookpoints are what SPARK-56395 and the older extensions API already provide.
+
+## What the connector does — shape summary
+
+Full details: [`DESIGN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/DESIGN.md).
+This section is the one-page summary for context.
+
+### 1. Catalyst rule
+
+```
+Rule: injectPostHocResolutionRule (NOT injectOptimizerRule)
+Pattern: NearestByJoin(approx=true, recognized-ranking, Lance-scan-on-right)
+Action: rewrite to a 3-logical-plan tree
+```
+
+**Why `postHocResolutionRule` and not `injectOptimizerRule`** — Spark's
+`RewriteNearestByJoin` runs in the optimizer's `FinishAnalysis` batch, which precedes the
+`operatorOptimizationBatch` that `injectOptimizerRule` adds rules to. An injected
+optimizer rule fires after `RewriteNearestByJoin` has already rewritten the operator to
+a cross-product; there's nothing left to pattern-match. `injectPostHocResolutionRule`
+runs after analysis but before any optimizer batch — the only injection point that sees
+the unrewritten `NearestByJoin`.
+
+This is the **single load-bearing constraint** any future engine needs to respect to
+substitute an alternative physical strategy for `NearestByJoin`. It would be worth
+calling out in the `NearestByJoin` scaladoc.
+
+### 2. Three-stage plan
+
+```
+left.logicalPlan
+ LanceProbeLogicalPlan (per-task probe: left vectors → top-K (rowId, score))
+ LanceMergeLogicalPlan (co-locate by _leftId, bounded TopKHeap.merge down to final K)
+ LanceMaterializeLogicalPlan (point-fetch right rows by _rowid, assemble join rows)
+ [Project to drop the trailing score attr so output matches NearestByJoin.output]
+```
+
+Lowered via the shared strategy to:
+
+```
+LanceProbeExec
+ ↓
+ShuffleExchangeExec hashpartitioning(_leftId) ← Catalyst inserts this via
+ ↓ EnsureRequirements (driven by
+LanceMergeExec LanceMergeExec.requiredChildDistribution
+ ↓ = ClusteredDistribution(_leftId))
+LanceMaterializeExec
+```
+
+Wrapped by `AdaptiveSparkPlanExec` when AQE is on. The Exchange is AQE-visible, so
+`CoalesceShufflePartitions` / `OptimizeSkewJoin` / `OptimizeShuffleWithLocalRead` all
+engage on the merge shuffle.
+
+### 3. Prefilter pushdown
+
+`SELECT * FROM lance WHERE p APPROX NEAREST K BY ...` — the rule detects the
+`Filter(p, lance)` shape, translates the predicate to a Lance SQL filter string
+(conservative: bare `attr literal`, `IN`, `IS [NOT] NULL`, `AND`/`OR`/`NOT`), and
+threads it into the probe. Lance applies the filter **before** the index lookup, so
+top-K is computed over matching rows only — avoiding the
+"index-returns-K-rows-all-filtered-out-post-join" recall bug.
+
+Untranslatable predicates (UDFs, computed expressions, left-side references) cause the
+rule to **refuse** the rewrite and fall through to Spark's brute-force cross-product.
+Refusal — not partial pushdown — because dropping a residual would silently change query
+semantics.
+
+### 4. Correctness and scale validation
+
+- `recall=1.0` vs brute-force oracle on 1K × 100 × dim=16 (unindexed Lance) and on
+ IVF-PQ-indexed datasets with `refineFactor=8`.
+- 100–200× faster than Spark `RewriteNearestByJoin` cross-product on real Cohere Wikipedia
+ embeddings (dim=1024, 1K × 50), on an 8 × 4-core / 16-GB OSS Spark 3.5 cluster.
+- SIFT1M IVF-FLAT recall@10 = 0.98 at nprobes=16, 1.00 at nprobes=64 — within noise of
+ published FAISS numbers.
+- Local M5 Max (100K × 100 × dim=128): 17× vs Spark's brute-force. Smaller gap than the
+ headline because Spark's built-in is already `min_by_k` over
+ `BroadcastNestedLoopJoin` — not a full `|L|×|R|` materialization. The remaining 17×
+ comes from Lance's native-SIMD distance kernels beating Catalyst's JVM expression
+ evaluation per pair.
+
+Full numbers: [`BENCHMARK_RESULTS.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/BENCHMARK_RESULTS.md).
+
+## Extending the same shape to parquet / delta — Lance sidecar pattern
+
+The interesting part, and the reason to share this doc.
+
+### The observation
+
+Only the **probe** and **materialize** stages are format-specific. They open a Lance
+dataset and call Lance's native nearest-search / row-id point-fetch APIs. The rest of
+the pipeline — the Catalyst rule, the three logical/physical plans, the inter-stage
+schema, the bounded-merge heap, the Exchange insertion — is format-agnostic Catalyst
+plumbing.
+
+So a user who has their **primary data** in parquet or delta can still get the indexed
+path by building a Lance **sidecar** keyed by a shared row identifier:
+
+```
+Primary data: a parquet or delta table with (user_id, name, ..., embedding: array)
+
+Sidecar: a Lance dataset with just (user_id, embedding) — built once, maintained
+ incrementally as the primary table grows. Vector index (IVF-PQ or HNSW)
+ is built on the Lance sidecar.
+
+Query time: APPROX NEAREST K BY DISTANCE vector_l2_distance(q.vec, d.embedding)
+ with the rule configured to probe the sidecar's Lance URI, then
+ materialize by a foreign-key join to the primary parquet/delta table:
+
+ SELECT q.qid, p.name
+ FROM queries q INNER JOIN parquet_primary p
+ APPROX NEAREST K BY DISTANCE vector_l2_distance(q.lvec, p.embedding)
+```
+
+The rule's right-side detection (currently "is the scan a Lance DSv2 relation?") would
+need one small extension: "is there a registered sidecar index for this table?" The
+registration mechanism is a small catalog-side detail — could be a table property
+(`lance.sidecar.uri`), a config map, or a Spark session extension. The probe then runs
+against the sidecar; materialize runs against the **primary** table, joining on the
+shared key.
+
+### What changes vs. the Lance-native implementation
+
+The rule's detection predicate expands from "right side is Lance" to "right side is Lance
+*or* has a registered Lance sidecar":
+
+```scala
+case rel: DataSourceV2Relation if isLanceTable(rel.table) =>
+ LanceScanInfo(uri = rel.options.get("path"), ...)
+case rel: DataSourceV2Relation if hasRegisteredSidecar(rel.table) =>
+ val sidecar = lookupSidecar(rel.table)
+ // probe against sidecar.uri, materialize against rel's primary URI
+ SidecarScanInfo(sidecarUri = sidecar.uri, primaryRel = rel, joinKey = sidecar.joinKey, ...)
+```
+
+Materialize changes from "point-fetch from Lance by `_rowid`" to "recover the primary
+table's payload for the surviving top-K row keys." The right mechanism depends on the
+format:
+
+#### Why parquet/delta can't do cheap point-fetch
+
+Lance's `_rowid IN (...)` materialize path works because Lance has a row-id index that
+translates a rowId to a `(fragment, offset)` address and the columnar reader supports
+random-access within a fragment — it's comparable in cost to a secondary-index lookup in
+an OLTP engine.
+
+Parquet and delta don't have this. A `WHERE user_id IN (k1, ..., kN)` over parquet
+pushes down to row-group-level min/max filtering (column stats), then still reads and
+decodes every row group that *might* contain a match. Delta adds optional file skipping
+via its own min/max stats and optional `DataSkippingNumIndexedCols` / Z-order /
+bloom-filter-on-column indexes, but the fundamental primitive is still "read whole row
+groups, filter predicates, emit matches" — not true random access. With an arbitrary
+set of `|L| × K` keys drawn from the full key space (which is typically the case for
+vector similarity — nearest neighbors are distributed, not clustered), the scan
+degenerates to "read most of the table."
+
+So the naive "just equi-join the merged top-K back to the primary table" is **not**
+free on parquet/delta. The top-K is small (`|L| × K` ≤ 10⁴ rows for typical queries)
+but the primary table is large (10⁶–10⁹ rows). A broadcast-join the small side onto a
+full parquet scan costs one full table scan per query. At that point, the brute-force
+cross-product isn't obviously worse — both are `O(|right|)` reads.
+
+#### Three materialize strategies, in order of increasing sophistication
+
+1. **Carry the needed columns in the sidecar (simplest, works today).** If the user
+ knows which columns are projected in the APPROX NEAREST query at sidecar-build time,
+ store those columns in the Lance sidecar alongside the embedding. Materialize runs
+ entirely against Lance (the current Lance-native path). Cost: sidecar duplicates
+ columns. Benefit: works without any change to the probe/materialize exec split.
+
+ Best fit: queries always project a stable small set of columns
+ (`SELECT q.qid, d.title, d.url FROM ...`). Store `(user_id, embedding, title, url)`
+ in the sidecar.
+
+2. **Equi-join broadcast the top-K list against primary (works for small-enough tables).**
+ Materialize emits `(leftId, rightKey)` pairs; a subsequent equi-join materializes
+ payload. Cost: one full scan of the primary table per query.
+
+ Best fit: tables small enough that a full parquet scan is already acceptable
+ (< ~100M rows at typical parquet compression), or queries already touching the
+ primary table for other reasons (the scan amortizes). Delta's data-skipping stats
+ help here if the primary table is partitioned and the top-K keys cluster on a
+ partition column — but vector-nearest rarely clusters on anything the user
+ partitioned by.
+
+3. **Format-native point-fetch, when the format supports it (parquet with a row-index
+ extension; delta with the row-id preview; iceberg with its row lineage feature).**
+ These formats have been adding optional row-id / row-position metadata exactly to
+ support this kind of random access. A future materialize stage could consult that
+ metadata and do O(K) I/O per query instead of O(|right|). Requires the format to
+ persist the sidecar-key → file-offset mapping, which none of parquet/delta does by
+ default today — but the trajectory in all three communities is toward supporting it.
+
+The PoC today implements option 1 by construction (Lance is both sidecar and primary).
+Option 2 is a ~100-LoC extension — replace `LanceMaterializeExec` with a
+`ForeignKeyJoinMaterializeExec` that equi-joins against the primary relation. Option 3
+depends on format work outside Spark's control.
+
+**Honest assessment:** Option 2 only beats the SPARK-56395 brute-force cross-product on
+perf if the primary table is under a certain size threshold (workload-dependent,
+roughly hundreds of millions of rows for typical column counts). For larger primary
+tables, the user is better off either (a) using option 1 with the columns carried in
+the sidecar, or (b) waiting for option 3 as parquet/delta row-indexing matures. This is
+worth stating plainly — the sidecar pattern isn't a universal win, and reviewers should
+know where it breaks down.
+
+### Format-agnostic extraction, if the appetite exists
+
+If multiple ecosystems (Lance, iceberg-spark, delta, hudi, native parquet with a vector-
+index extension) converge on this shape, a cleaner long-term split is:
+
+- **`NearestByJoinIndexProvider` trait** in `apache/spark` (or an ecosystem-shared location):
+ ```scala
+ trait NearestByJoinIndexProvider {
+ def probe(left: DataFrame, metric: Metric, k: Int, prefilter: Option[Expression])
+ : RDD[(Long, Seq[ScoredRowRef])]
+ def materialize(rowIds: RDD[(Long, Seq[Long])]): RDD[Row]
+ }
+ ```
+- **`IndexedNearestByJoinRule`** in Spark (or ecosystem-shared) — format-agnostic
+ pattern match, dispatches to whichever provider is registered for the right-side
+ relation.
+- **Per-backend module** — Lance, FAISS-on-parquet, delta-with-vector-index — each ships
+ a `NearestByJoinIndexProvider` implementation.
+
+That's explicitly **not** something this PoC proposes to do in apache/spark today.
+Refactoring to the trait shape costs a nontrivial amount of complexity for a trait
+that, right now, has one implementation. But the shape of this PoC is intentionally
+close to what such a generalized interface would look like — the three stages, the
+inter-stage schema, the Exchange-on-`_leftId`, and the prefilter contract are all
+format-agnostic.
+
+## Open questions for Spark maintainers
+
+The following are points where maintainer input would materially change the upstream
+story. They're phrased as questions because the PoC works without resolving them; but
+each is a place where a small apache/spark change would help.
+
+1. **Scaladoc on `NearestByJoin` about the `injectPostHocResolutionRule` constraint.**
+ Any future engine wanting to substitute a different physical plan must know that
+ `injectOptimizerRule` is too late. A one-line note on `NearestByJoin`'s class
+ scaladoc (or on `RewriteNearestByJoin`) would save the next implementer the hour of
+ debugging we spent learning this.
+
+2. **Extending the ranking-expression allowlist.** The rule pattern-matches on
+ `VectorL2Distance` / `VectorCosineSimilarity` / `VectorInnerProduct` — the three
+ expressions SPARK-56395 recognizes. If a future PR adds, e.g., Hamming distance for
+ binary vectors, downstream indexed-path implementations would have to be updated in
+ lock-step. A stable "is this a recognized vector-distance expression" predicate (or
+ a trait on the expression class) would let ecosystem rules match forward-
+ compatibly.
+
+3. **`NearestByJoin` attribute stability across rewrites.** The PoC's rule preserves
+ `NearestByJoin.output` attribute-for-attribute (same `ExprId`s) to avoid unresolving
+ references from parent operators — the same contract `RewriteNearestByJoin`
+ honors. That contract isn't explicitly documented on the class today; documenting it
+ would make future rewrite implementations safer.
+
+4. **Is there interest in an `@DeveloperApi` hook on `NearestByJoin` to register
+ alternative physical strategies declaratively?** Instead of every ecosystem needing
+ to write a `postHocResolutionRule` + a `SparkStrategy`, Spark could provide a
+ single registration point (e.g.,
+ `NearestByJoinStrategyRegistry.register(predicate, strategy)`). If the broader
+ community is interested, this would be a small, focused apache/spark PR. If it's
+ not, the current extension points are sufficient and the PoC lives entirely
+ downstream.
+
+5. **Row-identity metadata as a Spark-level contract for sidecar materialize.** The
+ "extending to parquet/delta via sidecar" story (above) depends on the primary
+ table exposing a stable row identifier that the sidecar can key against.
+ Parquet/delta/iceberg each have in-flight work to surface this (parquet row-index
+ extension, delta's row-id preview, iceberg row lineage) but the APIs differ across
+ formats. A Spark-level abstraction — e.g., a `SupportsRowIdentifier` mixin on DSv2
+ tables that returns a row-id column the optimizer can treat as unique + stable —
+ would let any format plug into a sidecar materialize path without the sidecar
+ needing format-specific knowledge. This is a bigger design discussion than the
+ other items here; flagging it because sidecar materialize is the primary
+ blocker to making the SPARK-56395 indexed path useful outside Lance-native
+ storage.
+
+## References
+
+- **Implementation:** [`lance-spark-knn-4.2_2.13` module](https://github.com/sezruby/lance-spark/tree/knn-phase0/lance-spark-knn-4.2_2.13) (Catalyst rule + session extension), [`lance-spark-knn_2.12` module](https://github.com/sezruby/lance-spark/tree/knn-phase0/lance-spark-knn_2.12) (shared logical/physical plans + RDD stages).
+- **Design overview:** [`DESIGN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/DESIGN.md).
+- **Reviewer reading order:** [`REVIEWER_GUIDE.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/REVIEWER_GUIDE.md).
+- **Benchmark numbers:** [`BENCHMARK_RESULTS.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/BENCHMARK_RESULTS.md).
+- **SPARK-56395 Spark JIRA:** https://issues.apache.org/jira/browse/SPARK-56395
+- **SPARK-56395 design doc:** https://docs.google.com/document/d/1opFVcQJgEWDWUVB7uVlFMlNomRwxqRu8iW0JmvCvxF0/
+
+**Status:** PoC on `sezruby/lance-spark:knn-phase0`. Recall=1.0 on unindexed + indexed
+paths, 100-200× speedup vs cross-product on real embeddings, 60+17 tests passing across
+Scala 2.12/2.13 and Spark 3.5/4.0/4.2-SNAPSHOT.
+
+**What's NOT claimed:** this isn't an RFC for apache/spark; it's a reference point for
+one concrete way to implement the indexed path the SPARK-56395 design doc mentions as
+future work.
diff --git a/lance-spark-knn_2.12/PHASE_PROGRESS.md b/lance-spark-knn_2.12/PHASE_PROGRESS.md
new file mode 100644
index 000000000..d8475e209
--- /dev/null
+++ b/lance-spark-knn_2.12/PHASE_PROGRESS.md
@@ -0,0 +1,415 @@
+# lance-spark-knn — phase progress & resume notes
+
+This document is the single source of truth for picking up the indexed nearest-by join work
+without the original chat context. Read it top-to-bottom before changing anything.
+
+If you are continuing this work, **do not skip the design rationale**: several of the design
+choices look arbitrary in isolation but lock with each other. The annotations are deliberate —
+the *why* is recorded inline in the source comments.
+
+## Where this lives
+
+- Repo: `https://github.com/sezruby/lance-spark` (fork of `eto-ai/lance-spark` — yes the upstream
+ org is Eto.ai).
+- Branch: `knn-phase0`
+- Open PR: `https://github.com/sezruby/lance-spark/pull/1` (draft; benchmarking PoC, not for
+ immediate merge upstream).
+
+## Goal
+
+Indexed approximate-nearest-neighbor join for Spark over a Lance dataset, exposed as a Scala
+function (and eventually a SQL `APPROX NEAREST 10 BY ...` clause via Catalyst). Backed by
+Lance's fragment-local IVF-PQ vector indexes, executed via per-task Lance probes plus a
+Spark-side merge — instead of an `O(|L| × |R|)` cross-product like the built-in
+`RewriteNearestByJoin` rule.
+
+The upstream Spark `NearestByJoin` operator (Spark 4.x) only has the cross-product rewrite. The
+work here is a layer-on solution that does **not** need to ship via Spark's SPIP process —
+Phase 2's Catalyst integration uses `injectPostHocResolutionRule`, not `injectOptimizerRule`,
+because `RewriteNearestByJoin` runs in `FinishAnalysis` (the optimizer's first batch) and would
+beat us to the punch otherwise.
+
+## Module layout
+
+```
+lance-spark-knn_2.12/ ← canonical sources (Phase 0/1)
+ pom.xml ← depends on lance-spark-base_2.12
+ + lance-spark-3.5_2.12 as test-scope
+ IMPL_PLAN.md ← architecture / phasing
+ DESIGN.md ← review-friendly overall feature design
+ PHASE_PROGRESS.md ← this file
+ src/main/scala/org/lance/spark/knn/
+ IndexedNearestJoin.scala ← public DataFrame API
+ internal/
+ Metric.scala ← L2 / Cosine / Dot enum, smallerIsBetter flag
+ ScoredRowRef.scala ← (rowId, score) tuple shipped through shuffle (field
+ named `rowAddr` for source-compat; semantically a row ID)
+ ProbedLeft.scala ← (leftRow, refs[]) tuple, the shuffle value
+ TopKHeap.scala ← bounded heap for map-side combine
+ LanceProbe.scala ← per-task probe primitive (open-once dataset)
+ LanceProbeStage.scala ← Phase 1 stage 1: probe RDD transformer
+ LanceMergeStage.scala ← Phase 1 stage 2 config (aggregation lives in LanceMergeExec)
+ LanceMaterializeStage.scala ← Phase 1 stage 3: point-fetch right rows
+ src/test/scala/org/lance/spark/knn/
+ IndexedNearestJoinTest.scala ← oracle equivalence (recall=1.0 vs. brute force)
+ IndexedNearestJoinPlanShapeTest.scala ← Phase 1 plan-shape assertions
+ internal/
+ TopKHeapTest.scala ← TopKHeap unit tests
+ LanceProbeValidationTest.scala ← LanceProbe primitive validation
+
+lance-spark-knn_2.13/ ← cross-build pom only
+ pom.xml ← sources point at ../lance-spark-knn_2.12/...
+
+lance-spark-knn-4.2_2.13/ ← Phase 2 Catalyst integration; Spark 4.2-only
+ pom.xml ← Arrow 19, spark-sql 4.2.0-SNAPSHOT (provided)
+ src/main/scala/org/lance/spark/knn/
+ catalyst/
+ IndexedNearestByJoinRule.scala ← postHocResolutionRule pattern match; emits
+ the 3-plan tree shared with the DataFrame path
+ extensions/
+ LanceKnnSparkSessionExtensions.scala ← user-facing entry point; registers rule + shared strategy
+ src/test/scala/org/lance/spark/knn/catalyst/
+ IndexedNearestByJoinRuleTest.scala ← pattern-match tests (positive + negative)
+ IndexedNearestByJoinE2ETest.scala ← SQL e2e on real Lance (Spark 4.2-SNAPSHOT)
+```
+
+The `_2.12` directory holds the canonical sources; `_2.13` is a thin pom that re-uses them. This
+matches `lance-spark-base_2.12` / `lance-spark-base_2.13` and `lance-spark-3.5_*` — same
+convention as the rest of the project.
+
+## What's done
+
+> A standalone, review-friendly design overview of the whole feature is in
+> `DESIGN.md` (next to this file). Read that for the "what / why / shape", and read
+> this file for "where things live, how to resume, gotchas".
+
+### Phase 0 — pure DataFrame API
+
+`IndexedNearestJoin.apply(left, rightLanceUri, leftVecCol, rightVecCol, k, ...)` — opens the
+right Lance dataset per task, probes Lance's nearest search per left row, materializes top-K
+right rows, emits join rows. All inside one `mapPartitions` block. No shuffle, no Catalyst.
+
+This phase exists to validate the per-task Lance access pattern works at all and to baseline
+recall/correctness against a brute-force oracle.
+
+### Phase 2 — Catalyst integration (new module `lance-spark-knn-4.2_2.13`)
+
+Spark 4.2-SNAPSHOT only. Adds a `postHocResolutionRule` that pattern-matches
+`NearestByJoin(approx = true, ...)` over a Lance scan and rewrites it to the 3-plan
+staged tree (`LanceProbeLogicalPlan → LanceMergeLogicalPlan → LanceMaterializeLogicalPlan`)
+wrapped in a `Project` that restores `NearestByJoin.output` exactly. The shared
+`LanceKnnStagedStrategy` lowers that tree to the matching physical execs — the SQL path
+and the DataFrame API path converge on the same physical shape.
+
+**Public surface unchanged.** Wiring is one extension class registration:
+
+```scala
+SparkSession.builder()
+ .config("spark.sql.extensions",
+ "org.lance.spark.knn.extensions.LanceKnnSparkSessionExtensions")
+ .config("spark.lance.knn.indexedNearestByJoin.enabled", "true")
+```
+
+After registration, any `APPROX NEAREST k BY {DISTANCE | SIMILARITY} f(l.vec, r.vec)` SQL
+query against a Lance table rewrites automatically. EXACT queries and unrecognized shapes
+flow through to Spark's existing brute-force rewrite — no regression.
+
+The rule is gated by `spark.lance.knn.indexedNearestByJoin.enabled` (default `false`) to
+keep it opt-in until the Phase 3 cost gate lands.
+
+**The single most important detail** — `injectPostHocResolutionRule`, NOT
+`injectOptimizerRule`. Spark's `RewriteNearestByJoin` runs in the optimizer's
+`FinishAnalysis` batch (the *first* batch). `injectOptimizerRule` adds rules to
+`operatorOptimizationBatch` which runs *after*; by then `NearestByJoin` is already gone.
+`injectPostHocResolutionRule` runs after analysis but before any optimizer batch — it's
+the only injection point that sees the unrewritten operator.
+
+Tests cover rule pattern-match (positive + negative) plus real-backend SQL e2e. The trick
+for the e2e: lance-spark-4.1's source compiles cleanly against 4.2-SNAPSHOT
+(`-Dspark41.version=4.2.0-SNAPSHOT -Darrow183.version=19.0.0`) and the resulting jar runs
+on 4.2's DSv2 API, so we recompile it locally and use it as the runtime Lance reader. Three
+e2e cases: rule-on routes through the 3-exec staged chain and matches oracle; WHERE-pushdown
+round-trips the prefilter and matches the filtered oracle; rule-off falls through to Spark's
+`RewriteNearestByJoin` and still matches oracle.
+
+### Phase 3 — hardening (partial)
+
+Two substantive items shipped:
+
+ 1. **`refineFactor` / `ef`** parameters on `IndexedNearestJoin.apply`. Plumbed through
+ `LanceProbeStage.Conf` to `LanceProbe.probe`, which calls `Query.Builder.setRefineFactor`
+ / `setEf`. IVF-PQ recall tuning + HNSW search depth respectively. Defaults preserve
+ current behavior.
+ 2. **`balanceFragmentsByRowCount`** flag. When true, `LanceFragments` enumerates fragments
+ with their row counts and runs LPT (Longest Processing Time) greedy bin-packing into N
+ groups. 4/3-approximation of optimal makespan. Default false = round-robin (Phase 1.5
+ behavior, fine for evenly-sized fragments).
+
+A `SupportsApproxNearestNeighborSearch` marker trait was prototyped during
+development and then dropped. The indexed-path executor calls Lance's Java API
+directly, so it's Lance-specific by construction; a general-purpose extension trait
+implies portability the rule can't actually deliver. Class-name detection + standard
+DSv2 options give us everything we need.
+
+`LanceProbe.vectorColumn` was moved from a constructor field to a per-call argument on
+`probe()` so the materialize stage no longer constructs the probe with a `vectorColumn = ""`
+placeholder — a code smell flagged in Phase 0.
+
+`prefilter` pushdown landed: when the right side is `Filter(cond, lance)` (or
+`Project(, Filter(cond, lance))`, the SQL `WHERE` shape), the rule translates
+the predicate to Lance SQL and threads it through `LanceProbeStage.Conf.prefilter` →
+`ScanOptions.filter()`. Translation handles binary comparisons, `IN`, `IS [NOT] NULL`,
+`AND`/`OR`/`NOT` over right-side attrs vs. literals. Anything else (UDFs, computed
+expressions, predicates referencing the LEFT input) → rule REFUSES the rewrite (returns
+the original `NearestByJoin`, falls through to brute force). Refusal — not partial
+pushdown — because dropping a residual conjunct would silently change semantics. Slow
+but correct.
+
+**3-stage explicit physical operators (DataFrame API path) — DONE.** The
+current shape ships the operator split + AQE-visible merge shuffle (via
+`ClusteredDistribution`) + single-pass inter-stage codec as one clean
+commit. An early development iteration hit reproducible `AssertionError` /
+SIGSEGV in `UnsafeRow.getLong` on `count()`-style consumers, initially
+misdiagnosed but narrowed to the real root cause during investigation.
+
+The initial diagnosis blamed a JVM-aarch64 + JIT C2 interaction — that was
+wrong. The crash reproduces on aarch64 but isn't a JVM bug. Step-wise
+isolation (`InterStageShuffleReproTest` → `InterStageShuffleWithLanceReproTest` →
+`StagedExecDirectDriveReproTest`) narrowed it to the staged execs specifically, not
+Spark's UnsafeRow shuffle or the Lance→Spark boundary. Diagnostic instrumentation caught
+0-field `UnsafeRow`s arriving at `LanceMaterializeExec.doExecute`. Dumping the executed
+plan for `count()` showed Catalyst's `ColumnPruning` rule inserting `Project(Nil)`
+wrappers between the custom nodes — empty projections that codegen to 0-field
+UnsafeRows. The decoder then crashed with `AssertionError` in interpreter mode, SIGSEGV
+in C2 (assertion elided → unmapped-memory read).
+
+Fix: `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan` override
+`lazy val references = child.outputSet`. That makes `child.outputSet.subsetOf(references)`
+trivially true and short-circuits `ColumnPruning`'s guard, so no `Project(Nil)` ever
+gets inserted. `StagedPlansReferencesTest` pins the invariant structurally.
+
+Production today: `LanceProbeExec → ShuffleExchangeExec → LanceMergeExec →
+LanceMaterializeExec` wrapped by `AdaptiveSparkPlanExec`. `df.explain()` shows all four
+nodes; with AQE enabled, `AQEShuffleRead coalesced` appears on the merge-side shuffle.
+60 tests pass in `lance-spark-knn_2.12`. See `IMPL_PLAN.md` "3-exec staged split — root
+cause and fix" for the full post-mortem.
+
+**`df.kNearestJoin` extension.** Idiomatic DataFrame API mirroring
+`df.join(other, ...)` — works on Spark 3.5 / 4.0 / 4.1 / 4.2+ since it goes straight to
+`IndexedNearestJoin.apply` without touching the Phase 2 SQL parser. Extracts the Lance
+URI from the right DataFrame's analyzed plan; non-Lance right sides (parquet, in-memory
+DataFrames, alias-wrapped non-Lance) fail fast with `IllegalArgumentException`.
+
+What's left (see `IMPL_PLAN.md` for the full table): cost gate, Spark version matrix,
+AQE-visible shuffle for the fragment-grouped probe path (`runWithFragmentGroups`'s
+internal `partitionBy` is still RDD-level).
+
+### Benchmarks + SQL e2e
+
+Two oracle-validated benchmarks on M5 Max:
+
+ - **DataFrame** (`IndexedNearestJoinBenchmark` in `lance-spark-knn_2.12`): indexed staged
+ pipeline vs. naive Spark `crossJoin + array_distance UDF + row_number window`.
+ Headline: **608×** at 100K × 100 (109,373 ms → 180 ms).
+ - **SQL** (`IndexedNearestByJoinSqlBenchmark` in `lance-spark-knn-4.2_2.13`): same
+ `APPROX NEAREST` SQL with the Phase 2 rule ON vs. OFF (= Spark's `RewriteNearestByJoin`
+ cross-product + `min_by_k`). Headline: **17.4×** at 100K × 100 (3,728 ms → 214 ms).
+
+Both run a pre-timing oracle equivalence check on a 16-row left subset comparing every
+config (including the slow baseline / rule-OFF) against an in-memory brute-force ground
+truth. The benchmark `sys.error`'s if any disagrees, so quoted speedups are on validated
+results.
+
+Real-backend SQL e2e against Spark 4.2-SNAPSHOT works because lance-spark-4.1's source
+compiles cleanly against 4.2 and the resulting jar runs on 4.2's DSv2 API. Surfaced two
+real bugs in Phase 2 during development — View wrapper not unwrapped, `producedAttributes`
+missing — both fixed.
+
+Honest finding: Phase 1.5 fragment-grouping is *slower* than Phase 0/1 at every measured
+scale on this single laptop. Lance's internal cross-fragment merge already parallelizes via
+vectorized native kernels; Spark task-boundary parallelism doesn't help in shared-memory
+local mode. Plumbing is correct (oracle test passes); the win lands on a true distributed
+cluster, object-store-backed Lance, or a right side too large for one machine. Documented
+in `BENCHMARK_RESULTS.md`.
+
+### Phase 1.5 — fragment-grouped probing
+
+Same module as Phase 0/1 (`lance-spark-knn_2.12`). Adds an opt-in
+`probeParallelism: Int = 1` parameter on `IndexedNearestJoin.apply`. When > 1:
+
+ 1. Driver enumerates Lance fragment IDs via `Dataset.getFragments()` (helper
+ `internal/LanceFragments.scala`).
+ 2. Round-robin into N groups; broadcast.
+ 3. Replicate each left row across the N groups via `flatMap`, partition by
+ `groupIdx` so each task handles a single group.
+ 4. Each task opens `LanceProbe` with its group's `fragmentIds` and probes only those.
+ 5. Output keyed by `leftId` produces N contributions per leftId; downstream
+ `LanceMergeExec` (with `ClusteredDistribution(leftId)`) aggregates contributions
+ per-partition via `TopKHeap.merge` after the Catalyst-inserted exchange co-locates
+ them.
+
+The flatMap + partitionBy is one shuffle; the Catalyst-inserted Exchange above
+`LanceMergeExec` is the second. Two shuffles total — what the IMPL_PLAN's three-stage
+diagram has always shown.
+
+This is where the bandwidth win the IMPL_PLAN promises ("refs only ~24B") actually lands.
+Phase 1 had the staging but degenerate (single contributor per leftId). Phase 1.5 makes
+the merge stage do real work.
+
+**Edge case**: when `probeParallelism > numFragments`, only one group has fragments and
+the rule degenerates back to the Phase 1 single-task path — avoiding a replicate shuffle
+for nothing.
+
+**3 new tests**:
+ - Oracle equivalence with `probeParallelism = 4` and a 4-fragment right dataset. With
+ no index, every probe is exact, so the merge result must match brute force exactly.
+ - Plan-shape: `probeParallelism > 1` adds a second `ShuffledRDD` to the lineage
+ (verified by counting `ShuffledRDD` occurrences in `toDebugString`).
+ - `probeParallelism > numFragments` (e.g. 8 over a single-fragment dataset) still
+ produces correct results.
+
+Total knn module tests: 23 (was 16).
+
+### Phase 1 — staged RDD pipeline + Phase 3.x — explicit physical operators
+
+The Phase 0 inline `mapPartitions` was first split into three stage objects connected via
+`reduceByKey`. Phase 3.x then promoted the DataFrame path to three explicit `SparkPlan`
+operators with a Catalyst-inserted `ShuffleExchangeExec` between probe and merge. Current
+shape:
+
+```
+left.analyzed
+ -- LanceProbeLogicalPlan → LanceProbeExec --> (per-task) nearest-search
+ -- ShuffleExchangeExec hashpartitioning(_leftId) --> AQE wraps this when enabled
+ -- LanceMergeLogicalPlan → LanceMergeExec --> per-partition TopKHeap merge
+ -- LanceMaterializeLogicalPlan → LanceMaterializeExec --> _rowid point-fetch, assemble
+```
+
+**Public API unchanged.** `IndexedNearestJoin.apply(...)` callers see the same signature
+and the same output schema. Phase 0's oracle test passes unmodified — proves the
+refactor preserves correctness.
+
+**Plan-shape assertions**:
+- `IndexedNearestJoinPlanShapeTest`: executed plan contains `LanceProbe`, `LanceMerge`,
+ `LanceMaterialize`, and `Exchange`.
+- `IndexedNearestJoinAqeVisibilityTest`: `ShuffleExchangeExec hashpartitioning(_leftId)`
+ is in the tree (AQE on and AQE off), `AdaptiveSparkPlanExec` wraps it when AQE is on,
+ no `!` missingInput prefix.
+- `df.rdd.toDebugString` contains `ShuffledRowRDD` (the Catalyst shuffle reader — not
+ the pre-Phase-3.x `ShuffledRDD` produced by RDD-level `reduceByKey`).
+
+### What's not yet built
+
+#### Limitations remaining after Phase 0/1/1.5/2/3
+
+1. **Single-task probing when `probeParallelism = 1` (default)**: each `leftId` has exactly one
+ probe contributor and the merge function never fires — the shuffle is structurally present
+ but degenerate. Pass `probeParallelism > 1` to engage Phase 1.5's fragment-grouped path
+ where the merge stage actually aggregates contributions.
+2. **Left payload in shuffle**: `ProbedLeft` carries the full `leftRow` through the shuffle.
+ Cost is `~payload + 24B × K` per leftId-group instead of `~24B × K`. Fixing requires
+ repartitioning the left RDD by `leftId` up front and joining back at materialize via
+ `cogroup`. Deferred to Phase 3.x.
+3. **Synthetic leftId from `zipWithUniqueId`**, not a user-supplied join key. Means we can't
+ yet co-partition the left payload alongside `(leftId, refs)`.
+4. **Filter pushdown** (`prefilter`) — DONE. The Catalyst rule detects `Filter(cond, lance)`
+ on the right side, translates the predicate to a Lance SQL filter string, and threads it
+ through `LanceProbeStage.Conf.prefilter` → `ScanOptions.filter()`. Refuses rewrite if any
+ conjunct doesn't translate (no partial pushdown).
+5. **Vector column on materialize stage** — DONE. `LanceProbe.vectorColumn` is now a per-call
+ `probe()` argument, not a constructor field; the materialize stage opens `LanceProbe`
+ without any vector-column placeholder.
+
+The full Phase 3.x backlog (cost gate, real-recall test, etc.) lives in `IMPL_PLAN.md`'s
+"What's left" table.
+
+## Lance Java API surface used
+
+These were validated during Phase 0 from the upstream Lance Java sources before being used.
+If a future Lance version breaks any of them, that is the first place to look:
+
+| Class / method | Purpose |
+|---|---|
+| `Dataset.open()` (builder) | Open Lance dataset; takes `.uri()` `.allocator()` `.readOptions()` |
+| `ReadOptions.Builder().setVersion(v)` | Pin to a specific Lance version |
+| `LanceScanner.create(dataset, options, allocator)` | Create a scanner |
+| `ScanOptions.Builder.nearest(query)` | Configure vector nearest search |
+| `ScanOptions.Builder.prefilter(true)` | Required when `fragmentIds` is non-empty |
+| `ScanOptions.Builder.withRowId(true)` | Surface `_rowid` in result (we use this; the indexed nearest-search path doesn't materialize `_rowaddr`) |
+| `ScanOptions.Builder.fragmentIds(list)` | Restrict to specific fragments |
+| `ScanOptions.Builder.columns(emptyList)` | Project nothing — refs only |
+| `Query.Builder` (column, key, k, distanceType, nprobes) | Build the nearest-search query |
+| `org.lance.index.DistanceType` | L2 / Cosine / Dot enums |
+
+## Build & test
+
+The whole module builds via Maven (no SBT here — lance-spark project uses Maven):
+
+ cd /Users/esong/repos/lance-spark
+ ./mvnw -pl lance-spark-knn_2.12 test-compile
+ ./mvnw -pl lance-spark-knn_2.12 test
+ ./mvnw -pl lance-spark-knn_2.12 test -Dtest='IndexedNearestJoinTest'
+ ./mvnw -pl lance-spark-knn_2.12 test -Dtest='*PlanShape*'
+
+Cross-build (2.13):
+
+ ./mvnw -pl lance-spark-knn_2.13 test # uses _2.12 sources via pom
+
+Phase 2 module (Spark 4.2-SNAPSHOT, Scala 2.13 only) — first install Spark master locally:
+
+ cd /path/to/spark/master
+ ./build/mvn install -DskipTests -DskipChecks -Drat.skip=true -Dscalastyle.skip=true \
+ -pl sql/core -am
+
+Then in lance-spark:
+
+ ./mvnw install -DskipTests -pl lance-spark-knn_2.13 -am
+ ./mvnw -pl lance-spark-knn-4.2_2.13 test
+
+**Don't pass `-am` to surefire-filtered runs** — surefire's test pattern then runs against
+the base module too, which has zero matching tests and fails. Just run `-pl ` alone.
+
+## Gotchas observed during Phase 0/1 — keep these in mind
+
+1. **Scala 2.13 `Seq` doesn't match `mutable.ArraySeq`.** Spark `Row.get` on `ArrayType` returns
+ `mutable.ArraySeq`. `case s: Seq[_]` in 2.13 only matches `immutable.Seq`. Use the root
+ trait: `case s: scala.collection.Seq[_]`. See `LanceProbeStage.extractVector`.
+2. **Arrow `FieldVector.getObject` returns `JsonStringArrayList`**, not Scala `Seq`. Spark's
+ `RowEncoder` won't accept it for ArrayType slots. `LanceProbe.toSparkValue` recursively
+ converts `java.util.List → Seq`, `Map → Map`, `Text → String`.
+3. **Spark driver bind error in tests**: every `SparkSession.builder()` in tests must set
+ `spark.driver.bindAddress=127.0.0.1` and `spark.driver.host=127.0.0.1` or it fails to bind
+ on restricted networks (CI sandboxes, dev containers).
+4. **Lance's `format("lance")` registration** comes from `lance-spark-3.5_2.12`'s
+ `META-INF/services/org.apache.spark.sql.sources.DataSourceRegister`. The knn module's
+ sources don't depend on that, but its **tests** do — `lance-spark-3.5_2.12` is a test-scope
+ dependency in the pom for that reason.
+5. **`vectorColumn = ""` on the materialize-only LanceProbe**: not a bug, the param is just
+ unused on that path. Documented inline.
+
+## Validation checklist for new changes
+
+Before declaring a phase done:
+
+- [ ] `./mvnw -pl lance-spark-knn_2.12 test` — all tests pass
+- [ ] Phase 0 oracle test (`testInnerJoinMatchesBruteForceOracle`) passes — recall = 1.0 vs.
+ plain-Scala brute force is a load-bearing correctness check
+- [ ] `df.rdd.toDebugString` for the output of `IndexedNearestJoin.apply` matches the phase's
+ expected staged form (current: contains `ShuffledRowRDD` — the Catalyst shuffle
+ reader, produced by the 3-exec staged plan)
+- [ ] Spotless / scalastyle / checkstyle all clean (`./mvnw spotless:check`)
+- [ ] No new non-ASCII chars in source (especially smart quotes / em-dashes from copy-paste)
+- [ ] Update IMPL_PLAN.md status table if a phase moved
+- [ ] Update this file's "What's done" section
+
+## Quick map: where to look for X
+
+| Question | File |
+|---|---|
+| How does the public API work? | `IndexedNearestJoin.scala` — start at `apply` |
+| Why these three stages? | `IMPL_PLAN.md` "Three-phase distributed design" |
+| How is the shuffle bandwidth bounded? | `ScoredRowRef.scala` doc comment |
+| Why `injectPostHocResolutionRule` for Phase 2? | `IMPL_PLAN.md` "Phase 2 — Catalyst integration" |
+| What does `_rowid IN (...)` lower to in Lance? | `LanceProbe.materialize` — pushdown into row-id lookup, point-fetch path. Switched from `_rowaddr` for indexed-path compatibility; see DESIGN.md "Why `_rowid` not `_rowaddr`". |
+| How does fragment-grouping (Phase 1.5) work? | Pass `probeParallelism > 1` to `IndexedNearestJoin.apply`. `LanceFragments.enumerateGroups` round-robins fragment IDs into N groups, `LanceProbeStage.runWithFragmentGroups` replicates rows × groups, downstream merge aggregates. |
+| Why does `Metric` carry `smallerIsBetter`? | `Metric.scala` — drives heap eviction direction |
diff --git a/lance-spark-knn_2.12/REVIEWER_GUIDE.md b/lance-spark-knn_2.12/REVIEWER_GUIDE.md
new file mode 100644
index 000000000..ff6438440
--- /dev/null
+++ b/lance-spark-knn_2.12/REVIEWER_GUIDE.md
@@ -0,0 +1,209 @@
+# Reviewer guide — `lance-spark-knn`
+
+This PR is ~13 K LoC across ~60 files. It lands as a single branch but the
+**intended upstream delivery is 7 smaller PRs** (see
+[`UPSTREAM_DELIVERY_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md)). This guide helps
+you navigate the current monolithic branch in **review-meaningful reading
+order** so you can form an opinion on the design without slogging through
+files in alphabetical order.
+
+## Start here (10 minutes)
+
+Read these 3 files first. After them, you know what the feature is and can
+decide whether you want to dig deeper:
+
+1. **[`DESIGN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/DESIGN.md)** — overall architecture + the "why no-index
+ Lance beats Spark cross-product" SIMD/columnar breakdown.
+2. **[`IndexedNearestJoin.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/IndexedNearestJoin.scala)** — the public entry point. 232 lines. One function,
+ `IndexedNearestJoin.apply(...)`, builds the 3-logical-plan tree and
+ hands it to Catalyst. The scaladoc on `apply` documents every parameter.
+3. **[`LanceKnnImplicits.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala)** — the `df.kNearestJoin(rightDf, ...)` extension method.
+ ~160 lines. Syntactic sugar over `IndexedNearestJoin.apply` with URI
+ auto-extraction.
+
+After those 3 you know the shape of the feature from a user perspective.
+
+## Next: read the engine (30 minutes)
+
+Four files. These are the 3-exec staged Catalyst operators on top of
+the RDD primitives. This is the heart of the feature and the subtlest
+part to review:
+
+1. **[`StagedPlans.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedPlans.scala)** — the three logical plan nodes
+ (`LanceProbeLogicalPlan`, `LanceMergeLogicalPlan`,
+ `LanceMaterializeLogicalPlan`).
+ **Critical invariant:** Merge and Materialize override
+ `lazy val references = child.outputSet`. Removing that line reintroduces
+ the ColumnPruning → `Project(Nil)` → 0-field UnsafeRow → SIGSEGV crash
+ the scaladoc describes. `IMPL_PLAN.md` "3-exec staged split — root
+ cause and fix" has the full post-mortem.
+2. **[`StagedExecs.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedExecs.scala)** — the three physical operators
+ (`LanceProbeExec`, `LanceMergeExec`, `LanceMaterializeExec`).
+ **Key decision:** `LanceMergeExec.requiredChildDistribution =
+ ClusteredDistribution(leftId)` — that's what makes Catalyst's
+ `EnsureRequirements` insert the `ShuffleExchangeExec` between probe
+ and merge, which is what AQE wraps.
+3. **[`ProbedLeftCodec.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/ProbedLeftCodec.scala)** — encode/decode for the inter-stage
+ `(leftId, leftRow, refs)` rows that cross the `ShuffleExchangeExec`
+ boundary. Flat schema (not nested struct) because nested struct tripped
+ a Spark serializer bug on arm64 at benchmark scale.
+4. **[`LanceKnnStagedStrategy.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/LanceKnnStagedStrategy.scala)** — the `SparkStrategy` that lowers the
+ three logical plans to the three physical execs.
+
+## Next: read the RDD primitives (30 minutes)
+
+These are called from the `Exec` nodes' `doExecute`. They're tested
+directly by `LanceProbeValidationTest` / `TopKHeapTest` and were the Phase
+0/1 foundation before the Catalyst operators were added.
+
+- **[`LanceProbe.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbe.scala)** — per-task Lance dataset handle +
+ `probe()` + `materialize()`. Closes the dataset on `close()`.
+- **[`LanceProbeStage.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbeStage.scala)** — two `run` methods: one for
+ `probeParallelism = 1` (default, `mapPartitions`), one for Phase 1.5
+ fragment-grouped via `flatMap` + `partitionBy`.
+- **[`LanceMergeStage.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMergeStage.scala)** — merge-stage `Conf` (the per-partition
+ `TopKHeap.merge` aggregation itself lives in `LanceMergeExec.doExecute`).
+- **[`LanceMaterializeStage.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMaterializeStage.scala)** — point-fetch right rows by
+ `_rowid`, assemble the output join rows.
+- **[`TopKHeap.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/TopKHeap.scala)** — min/max heap for bounded top-K.
+- **[`LanceFragments.scala`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceFragments.scala)** — driver-side fragment enumeration +
+ round-robin + LPT bin-packing for skew balancing.
+
+## Optional deeper reads
+
+**The ColumnPruning investigation** (post-mortem): [`IMPL_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/IMPL_PLAN.md)
+§ "3-exec staged split — root cause and fix". Required reading before
+touching `StagedPlans.scala`'s `references` override.
+
+**SQL integration (Spark 4.2-only)**: the [`lance-spark-knn-4.2_2.13/`](https://github.com/sezruby/lance-spark/tree/knn-phase0/lance-spark-knn-4.2_2.13) module intercepts
+Spark 4.2's `NearestByJoin` operator. Uses the same `staged/` pipeline
+under the hood. Note: Spark 4.2 is SNAPSHOT at time of this PR — this
+module depends on an unreleased Spark version. See
+[`UPSTREAM_DELIVERY_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md) § "Out-of-scope"
+for why it's deferred from upstream delivery.
+
+**Phase 3 hardening** (refineFactor / ef / prefilter pushdown / index-name
+handling): read [`IMPL_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/IMPL_PLAN.md) § "What's left (Phase 3.x)"
+— each row in that table links the feature to its test.
+
+**Benchmark results + infrastructure**: [`BENCHMARK_RESULTS.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/BENCHMARK_RESULTS.md)
+has both M5 Max local numbers and OSS Spark 3.5 cluster numbers.
+The cluster section is where the Phase 1.5 design claim was validated:
+**100–200× faster than Spark crossJoin on Cohere Wikipedia embeddings at
+dim=1024** (7-iter median 160×; noop sink + 16-row brute-force oracle check).
+
+## Test map — what to run
+
+```sh
+# Phase 0/1 core — probe primitive + staged RDD pipeline
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='LanceProbeValidationTest,TopKHeapTest,IndexedNearestJoinTest,IndexedNearestJoinCorrectnessTest'
+
+# Phase 1.5 fragment grouping
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='LanceFragmentsTest,IndexedNearestJoinFragmentGroupingTest'
+
+# Phase 2 — 3-exec Catalyst operators + AQE + consumer-shape regression
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='StagedPlansReferencesTest,IndexedNearestJoinAqeVisibilityTest,IndexedNearestJoinPlanShapeTest,IndexedNearestJoinConsumerShapeTest,IndexedNearestJoinJitStressTest'
+
+# ColumnPruning isolation tests (regression coverage for the reverted-then-restored crash)
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='InterStageShuffleReproTest,InterStageShuffleWithLanceReproTest'
+
+# Phase 3 — IVF-PQ recall
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='IndexedNearestJoinIvfPqRecallTest'
+
+# DataFrame extension
+./mvnw -pl lance-spark-knn_2.12 test -Dtest='LanceKnnImplicitsTest'
+
+# All of the above at once
+./mvnw -pl lance-spark-knn_2.12 test
+```
+
+Or on Scala 2.13:
+
+```sh
+./mvnw -pl lance-spark-knn_2.13 test
+```
+
+60 tests across the 2.12/2.13 module. All pass.
+
+## Trust-but-verify checklist
+
+If you're reviewing a specific aspect, here's where to look:
+
+| Concern | Check |
+|---|---|
+| **Correctness.** Does the indexed path produce the same top-K as brute force? | `IndexedNearestJoinCorrectnessTest` — brute-force oracle at 1K×100×dim=16, recall=1.0. `IndexedNearestJoinTest` — 4 end-to-end tests. `IndexedNearestJoinBenchmark` pre-timing validates ALL configs against oracle on 16-row subset. |
+| **AQE actually engages.** Can the merge shuffle be coalesced / rebalanced? | `IndexedNearestJoinAqeVisibilityTest` (5) — checks `ShuffleExchangeExec hashpartitioning(_leftId)` is in the tree, AQE wraps with `AdaptiveSparkPlanExec`, and `AQEShuffleRead coalesced` appears on the merge shuffle. |
+| **No regression on consumer shapes.** `count()` / `agg(count("*"))` / `select(lit(1))` don't crash. | `IndexedNearestJoinConsumerShapeTest` (4). These are the exact shapes that crashed the reverted code. |
+| **No JIT-level crash at scale.** | `IndexedNearestJoinJitStressTest` — 20-iter × 10K right × 100 left × dim=128. Passes clean. |
+| **The `references = child.outputSet` fix is structurally pinned.** | `StagedPlansReferencesTest` (3) — asserts the override exists on both logical plans and that ColumnPruning's subset guard short-circuits. |
+| **Lance's per-stage dataset open is efficient.** | `LanceProbeValidationTest` exercises probe across batches; `LanceProbe.close()` is called in the `try`/`finally` of every stage's `doExecute`. |
+| **Fragment grouping doesn't produce duplicates.** | `IndexedNearestJoinFragmentGroupingTest` — oracle equivalence with G=4 and G=8, + skew-balanced variant. |
+| **IVF-PQ integration produces real recall.** | `IndexedNearestJoinIvfPqRecallTest` (3) — builds a real IVF-PQ index on a Lance dataset, recall@10 = 0.73 at defaults, **1.00 with `refineFactor=8`**. |
+
+## Files that LOOK large but are mechanical
+
+- `InterStageShuffleReproTest.scala` + `InterStageShuffleWithLanceReproTest.scala` — ~700 lines
+ combined. They're the isolation tests from the ColumnPruning investigation.
+ Each has ~5-10 lines of actual assertion logic; the rest is test data
+ setup (synthetic row generators, Lance write helpers, JIT-stress loops).
+ Kept as regression coverage.
+- `IndexedNearestJoinBenchmark.scala` — the performance benchmark. Long
+ scaladoc comments explain design trade-offs; actual code is the timing
+ harness + 5 config runs. Not a regression test. See
+ [`UPSTREAM_DELIVERY_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md) § "Out of scope"
+ for why benchmarks aren't shipped upstream.
+
+## Out of scope for upstream review
+
+Per [`UPSTREAM_DELIVERY_PLAN.md`](https://github.com/sezruby/lance-spark/blob/knn-phase0/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md), these files
+will NOT ship to `lance-format/lance-spark` (they're kept on the fork):
+
+- All of `benchmark/` (6 files)
+- The `lance-spark-knn-4.2_2.13/` module (deferred until Spark 4.2 releases)
+- `lance-spark-knn_2.12/pom.xml`'s Linux-x86_64-only shade filter
+ (deployment-specific to some managed-Spark distributions' ingress timeout on
+ volume uploads)
+- `BENCHMARK_RESULTS.md` § "Cluster benchmarks — OSS Spark 3.5 on Kubernetes"
+
+If you're reviewing with an eye toward upstream merge, skip those.
+
+## Commit-by-commit alternative
+
+The branch is organized as 9 feature-boundary commits matching the
+upstream delivery plan — if you prefer commit-based review, walk the
+`git log` in order and read each commit's body for the "why":
+
+1. `feat(knn): Phase 0 foundation — LanceProbe primitive + metric types`
+2. `feat(knn): staged RDD pipeline + IndexedNearestJoin.apply + bounded TopKHeap`
+3. `feat(knn): Phase 1.5 — fragment-grouped probing for multi-task parallelism`
+4. `feat(knn): 3-exec Catalyst-visible staged plan with AQE-visible merge shuffle`
+ — **the heaviest commit; read this one closely.**
+5. `feat(knn): df.kNearestJoin DataFrame extension method`
+6. `feat(knn): Phase 3 hardening — refineFactor, prefilter pushdown, IVF-PQ recall`
+7. `feat(knn): Spark 4.2 SQL integration — IndexedNearestByJoinRule`
+8. `test(knn-bench): benchmark suite — synthetic, Wikipedia perf, SIFT/Cohere recall, SQL`
+9. `docs(knn): design, impl plan, reviewer guide, ANN proposal, benchmark results`
+
+Commit #4 is the subtle one — it introduces the 3-exec staged split
+with the `references = child.outputSet` override that prevents
+ColumnPruning from inserting `Project(Nil)` wrappers. The commit body
+has the full root-cause narrative; also see `IMPL_PLAN.md` "3-exec
+staged split — root cause and fix".
+
+## Questions, sharp edges, and known limitations
+
+- **`probeParallelism > 1` is a tradeoff**, not a pure win. Phase 1.5
+ fragment grouping pays a 2-shuffle cost for `|R|/G` per-task work. On
+ local-laptop runs it's net negative. On true distributed clusters with
+ `probeParallelism == numFragments` it's a ~1.7× win. Documented in
+ `BENCHMARK_RESULTS.md` § "Surprise: Phase 1.5 doesn't help at
+ local-laptop scale" and § "Cluster benchmarks".
+- **`probeParallelism > 1` uses an RDD-level shuffle** (`partitionBy`)
+ inside `runWithFragmentGroups` that is NOT AQE-visible. Fixing would
+ require a different shape that fits Catalyst's
+ `requiredChildDistribution` model. Documented in `IMPL_PLAN.md` as
+ follow-up.
+- **Cost gate** — the Phase 2 SQL rule is opt-in via
+ `spark.lance.knn.indexedNearestByJoin.enabled = true`. Production-grade
+ delivery needs a cost-based heuristic (on `|R|`, selectivity) to decide
+ automatically. Documented as TODO.
diff --git a/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md b/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md
new file mode 100644
index 000000000..c1d861f81
--- /dev/null
+++ b/lance-spark-knn_2.12/UPSTREAM_DELIVERY_PLAN.md
@@ -0,0 +1,310 @@
+# Upstream delivery plan — `knn-phase0` → `lance-format/lance-spark:main`
+
+The `knn-phase0` branch is ~13,300 LoC ahead of upstream `main` across ~60
+files, organized as 9 feature-boundary commits. That's too large to land
+as one PR. This document lays out a split into independent, reviewable
+PRs, broadly aligned with the 9-commit structure.
+
+## Scope limits
+
+- **Benchmarks not shipped.** The `benchmark/` directory (both `src/main/`
+ and `src/test/`) is internal tooling: it justifies design decisions and
+ produced the headline speedup numbers, but isn't part of the public API
+ and pulls in HTTP-fetch / local-FS logic that doesn't belong in the
+ connector. Keep benchmarks on the fork / document them in an external
+ `BENCHMARK_RESULTS.md` post.
+- **Spark 4.2 module not shipped yet.** The `lance-spark-knn-4.2_2.13`
+ module depends on Spark 4.2-SNAPSHOT's `NearestByJoin` logical plan
+ (SPARK-56395), which isn't in a released Spark yet. Defer this PR until
+ Spark 4.2.0 publishes to Maven Central.
+
+## Redundancy audit
+
+Scanned the branch for dead / duplicate / unused files. Two items found + resolved; three
+items checked and kept (with rationale).
+
+### Removed during development (not present on the branch today)
+
+1. **An earlier duplicate `IndexedNearestJoinBenchmark.scala`** sat in
+ `src/test/` after the benchmark was moved to `src/main/` to ship in
+ the shaded fat JAR. The duplicate is gone in the current branch.
+
+2. **Phase 2 single-exec SQL path + `LanceMergeStage.run` (with `reduceByKey`).**
+ An earlier iteration shipped `IndexedNearestByJoinPlan` +
+ `IndexedNearestByJoinExec` + `IndexedNearestByJoinStrategy` — a single
+ UnaryExec that called the old `LanceProbeStage.run` /
+ `LanceMergeStage.run` / `LanceMaterializeStage.run` helpers (where
+ `LanceMergeStage.run` did the shuffle via RDD `reduceByKey`). Once the
+ DataFrame path moved to the 3-exec Catalyst-visible design, the
+ SQL-specific single-exec became redundant. The current branch emits
+ the 3-plan logical tree from `IndexedNearestByJoinRule` and shares
+ `LanceKnnStagedStrategy` between both paths; `LanceMergeStage.run`
+ is gone (merge is now a per-partition `mapPartitions` inside
+ `LanceMergeExec`, fed by a Catalyst-inserted `ShuffleExchangeExec`).
+
+### Kept — not redundant
+
+1. **`InterStagePayloadOverheadBench.scala`** — microbenchmark that backs the
+ Catalyst-struct vs Kryo-blob choice documented in `ProbedLeftCodec`'s scaladoc.
+ Test-scope only, not shipped to users. Kept as re-runnable evidence for the design
+ claim in the comment (cited by name from `StagedExecs.scala` and `ProbedLeftCodec.scala`).
+
+2. **`IndexedNearestJoinJitStressTest` + `InterStageShuffleReproTest` +
+ `InterStageShuffleWithLanceReproTest`** — diagnostic tests from the
+ ColumnPruning investigation that led to the `references = child.outputSet`
+ fix. They rule out what *wasn't* the cause (JVM-aarch64, Spark UnsafeRow
+ shuffle, Lance→Spark boundary). Kept as regression coverage — they would
+ catch the same class of crash if it returned in a different place.
+
+3. **Legacy RDD path in `IndexedNearestJoin.apply`** — the public entry
+ point now builds the 3-exec staged logical plan tree directly. The
+ RDD-level `LanceProbeStage` / `LanceMergeStage` / `LanceMaterializeStage`
+ helpers are still called from the `Exec` nodes' `doExecute`, so they're
+ not dead.
+
+4. **`LanceVectorIndexBuilder.scala`** in test/ — helper used only by
+ `IndexedNearestJoinIvfPqRecallTest`. Not a duplicate; test-scoped utility.
+
+## PR split strategy
+
+The split axis is "minimum reviewable unit" — each PR introduces a feature
+that stands on its own, with tests that exercise only that feature. Phase
+ordering is preserved so reviewers can read commits chronologically.
+
+### PR 1: Phase 0 foundation — `LanceProbe` primitive
+
+**Goal:** Ship the per-task Lance nearest-search primitive + oracle tests.
+This is the smallest self-contained unit of the feature. Nothing here is
+user-facing; the primitive is private and exercised via unit tests.
+
+Files (~700 lines):
+
+- `lance-spark-knn_2.12/pom.xml` (new module; minimal deps)
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbe.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/Metric.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ScoredRowRef.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceProbeValidationTest.scala`
+- `lance-spark-knn_2.13/pom.xml` (shared-source 2.13 build)
+- `pom.xml` (reactor registration of the two modules)
+
+**Tests:** `LanceProbeValidationTest` (4) — probe returns K row refs,
+ordered by distance, with correct score column.
+
+**Why it's reviewable standalone:** no connection to Catalyst, no
+DataFrame API, no extension points. Just "here's a primitive that opens
+Lance, runs `nearest` search, returns `(rowId, score)` pairs." A reviewer
+can decide on the API shape without worrying about Spark integration.
+
+**Estimated review time:** 2–3 hours.
+
+### PR 2: Phase 0/1 — `IndexedNearestJoin.apply` + staged RDD pipeline
+
+**Goal:** Add the three-stage RDD pipeline (probe → shuffle → merge →
+materialize) and the bounded-merge heap. This is the functional core.
+
+Files (~1,500 lines):
+
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/IndexedNearestJoin.scala` (entry point — RDD-only shape for this PR; PR 4 rewires it to build the 3-plan logical tree)
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbeStage.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMergeStage.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMaterializeStage.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ProbedLeft.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/TopKHeap.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/TopKHeapTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinCorrectnessTest.scala`
+
+**Tests:** oracle equivalence, plan-shape (`ShuffledRDD` in lineage),
+outer-join / custom scoreCol / projection.
+
+**Why split from PR 1:** even a reviewer happy with `LanceProbe`'s API may
+want to debate the staged-pipeline design separately. The pipeline shape
+(`zipWithUniqueId` for leftId, Catalyst-inserted hash-partitioned exchange
+above `LanceMergeExec` for the merge shuffle, point-fetch materialize) is
+the "claim" this PR makes.
+
+**Estimated review time:** 4–6 hours.
+
+**Caveat:** this temporarily regresses behavior vs the current
+`knn-phase0` branch (no AQE on merge shuffle, no Catalyst operators in
+`df.explain()`). PR 4 restores that.
+
+### PR 3: Phase 1.5 — fragment-grouped probe (`probeParallelism > 1`)
+
+**Goal:** Add the optional fragment-grouping that enables parallel
+probe tasks across a distributed cluster.
+
+Files (~600 lines):
+
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceFragments.scala`
+- Modifications to `LanceProbeStage.scala`: add `runWithFragmentGroups`
+- Modifications to `IndexedNearestJoin.scala`: `probeParallelism` +
+ `balanceFragmentsByRowCount` parameters
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceFragmentsTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinFragmentGroupingTest.scala`
+
+**Tests:** LPT bin-packing math, oracle equivalence with G=4 and G=8
+groups, skew-balanced variant.
+
+**Why split from PR 2:** fragment grouping is opt-in (`probeParallelism = 1`
+is the default and doesn't use it). A reviewer can accept the staged
+pipeline first, then debate whether / how to expose the
+`probeParallelism` knob. Cluster evidence shows fragment grouping pays
+off only when `probeParallelism == numFragments` — the default of 1
+is correct for single-machine / single-executor.
+
+**Estimated review time:** 2–3 hours.
+
+### PR 4: Phase 2 — 3-exec staged Catalyst operators + AQE visibility
+
+**Goal:** Replace the RDD-only execution path with
+`LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec`
+so `df.explain()` sees the pipeline and AQE can engage on the merge
+shuffle.
+
+Files (~1,300 lines):
+
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/ProbedLeftCodec.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedPlans.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedExecs.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/LanceKnnStagedStrategy.scala`
+- `lance-spark-knn_2.12/src/main/scala/org/apache/spark/sql/LanceKnnDatasetBridge.scala`
+- Modifications to `IndexedNearestJoin.apply`: route through logical plans
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinAqeVisibilityTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinPlanShapeTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinConsumerShapeTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinJitStressTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/staged/StagedPlansReferencesTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleReproTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleWithLanceReproTest.scala`
+
+**PR description** should include the full "3-exec staged split — root
+cause and fix" post-mortem from the current `IMPL_PLAN.md`: the
+ColumnPruning guard, the `Project(Nil)` insertion, the 0-field UnsafeRow
+crash, the `references = child.outputSet` fix. This is the most subtle
+part of the feature and reviewers need the history.
+
+**Tests:**
+
+- `StagedPlansReferencesTest` (3) — pins the `references` override
+- `IndexedNearestJoinAqeVisibilityTest` (5) — AQE engages, `AdaptiveSparkPlanExec` wraps, all 3 execs in tree
+- `IndexedNearestJoinConsumerShapeTest` (4) — `count()`, `agg(count("*"))`, `select(lit(1))`, `collect()` — the shapes that crashed the reverted code
+- `IndexedNearestJoinJitStressTest` (2) — 20-iteration JIT stress
+- `InterStageShuffleReproTest` (6) + `InterStageShuffleWithLanceReproTest` (4) — the isolation tests from the investigation, kept as regression coverage
+
+**Estimated review time:** 6–10 hours. This is the heaviest PR.
+
+### PR 5: `df.kNearestJoin` DataFrame extension
+
+**Goal:** User-facing extension method on `DataFrame` that mirrors
+`df.join(other, ...)`. Wraps `IndexedNearestJoin.apply` with URI extraction
+from the right DataFrame's analyzed plan.
+
+Files (~250 lines):
+
+- `lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala`
+
+**Tests:** extension method end-to-end, Filter-on-right unwrap, Lance-only
+format guard (rejects non-Lance right).
+
+**Why a separate PR:** pure syntactic sugar over PR 2-4. A reviewer can
+debate the API name + ergonomics without touching the engine.
+
+**Estimated review time:** 1 hour.
+
+### PR 6: Phase 3 hardening — `refineFactor`, `ef`, IVF-PQ recall test
+
+**Goal:** Add the IVF-PQ recall knobs (`refineFactor`, `ef`) + a real
+recall test built against an actual Lance vector index.
+
+Files (~600 lines):
+
+- Modifications to `IndexedNearestJoin.apply`: new `refineFactor` / `ef` params
+- Modifications to `LanceProbeStage.Conf` + `LanceProbe.probe`: plumb the knobs
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinIvfPqRecallTest.scala`
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceVectorIndexBuilder.scala` (test util)
+- `lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/testutil/ClusteredEmbeddings.scala` (synthetic clustered data for recall tests)
+
+**Tests:** recall@10 at default IVF-PQ = 0.73, with `refineFactor=8`
+reaches 1.00. Clustered-embeddings recall is printed but not asserted
+(run-to-run variance too large at test scales).
+
+**Why split:** the recall knobs compose cleanly with the existing probe
+API; this PR just adds parameters.
+
+**Estimated review time:** 2 hours.
+
+### PR 7: Spark 4.0 compatibility (reflection bridge)
+
+**Goal:** Make the module work on Spark 4.0+ where
+`org.apache.spark.sql.Dataset` moved to
+`org.apache.spark.sql.classic.Dataset`.
+
+Files (~60 lines):
+
+- Modifications to `LanceKnnDatasetBridge.scala` — reflection-based
+ `ofRows` lookup
+
+**Tests:** 41/41 tests pass on Spark 4.0 + Scala 2.13 with this bridge
+(validated on `knn-phase0` with temporary pom overrides).
+
+**Why separate:** trivially small, easy to review, makes the
+`lance-spark-knn_2.13` module able to compile against `spark40.version`.
+Doesn't need to block anything else.
+
+**Estimated review time:** 30 minutes.
+
+## Suggested review order for upstream
+
+```
+PR 1 (Phase 0: LanceProbe primitive)
+ └─ PR 2 (Phase 0/1: IndexedNearestJoin.apply, staged RDD pipeline)
+ ├─ PR 3 (Phase 1.5: fragment grouping)
+ ├─ PR 4 (Phase 2: 3-exec Catalyst operators, AQE visibility)
+ │ └─ PR 5 (df.kNearestJoin extension)
+ ├─ PR 6 (Phase 3: refineFactor / ef / IVF-PQ recall)
+ └─ PR 7 (Spark 4.0 compat) — parallel-reviewable, small
+```
+
+PR 1 → 2 → 4 is the critical path for the "real" feature to work. PR 3, 5,
+6, 7 can land in parallel once PR 4 is in.
+
+## Out-of-scope (keep on fork indefinitely)
+
+- **All benchmarks** (`benchmark/` in both `main/` and `test/` trees):
+ `IndexedNearestJoinBenchmark`, `SiftRecallBenchmark`,
+ `CohereWikiRecallBenchmark`, `WikipediaKnnPerfBenchmark`,
+ `IndexedNearestJoinSoakTest`, `InterStagePayloadOverheadBench`.
+- **Deployment-specific build config**: the Linux-x86_64 shade filter in
+ `lance-spark-knn_2.12/pom.xml` (only needed for managed-Spark distributions
+ with volume-upload ingress timeouts).
+- **Cluster results section** in `BENCHMARK_RESULTS.md` (specific cluster
+ instance; the benchmark tooling is generic).
+- **`lance-spark-knn-4.2_2.13`** module (Spark 4.2 still SNAPSHOT; revisit
+ after Spark 4.2.0 releases to Maven Central).
+
+## Mechanics — how to actually create the PRs
+
+Each PR is a separate branch cut from `origin/main`, cherry-picking only
+its commits:
+
+```sh
+git checkout -b upstream/pr1-lance-probe origin/main
+# cherry-pick the phase-0 commits that touch only LanceProbe.scala / Metric.scala / ScoredRowRef.scala
+# resolve conflicts against upstream main
+# run tests, push, open PR against lance-format/lance-spark:main
+```
+
+Commits on `knn-phase0` don't map 1:1 to PRs — the branch's history
+includes the reverted-and-restored 3-exec saga and benchmark additions
+that should be squashed. Each PR should present as 1–3 clean commits with
+thorough messages, not the full investigation timeline.
+
+## Known gaps to fix before submitting
+
+- [ ] Each PR branch should rebase onto current `origin/main` and run its
+ test subset.
+- [ ] Open a JIRA ticket or GitHub issue on the upstream repo describing
+ the feature + PR sequence before opening the first PR, so
+ reviewers have context.
diff --git a/lance-spark-knn_2.12/pom.xml b/lance-spark-knn_2.12/pom.xml
new file mode 100644
index 000000000..f2a8e7418
--- /dev/null
+++ b/lance-spark-knn_2.12/pom.xml
@@ -0,0 +1,218 @@
+
+
+ 4.0.0
+
+
+ org.lance
+ lance-spark-root
+ 0.4.0-beta.4
+ ../pom.xml
+
+
+ lance-spark-knn_2.12
+ ${project.artifactId}
+ Indexed nearest-neighbor join for Lance datasets in Spark
+ jar
+
+
+
+ org.lance
+ lance-spark-base_2.12
+ ${project.version}
+
+
+ org.apache.spark
+ spark-sql_${scala.compat.version}
+ provided
+
+
+ org.lance
+ lance-spark-base_2.12
+ ${project.version}
+ test-jar
+ test
+
+
+
+ org.lance
+ lance-spark-3.5_2.12
+ ${project.version}
+ test
+
+
+ org.junit.jupiter
+ junit-jupiter
+ ${junit.version}
+ test
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ ${scala-maven-plugin.version}
+
+
+ scala-compile-first
+ process-resources
+
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ -feature
+
+
+
+
+
+ org.codehaus.mojo
+ exec-maven-plugin
+ 3.1.0
+
+
+ compile
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+
+ test-jar
+
+
+
+
+
+
+
+
+
+
+ benchmark
+
+
+
+ org.lance
+ lance-spark-3.5_2.12
+ ${project.version}
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ ${maven-shade-plugin.version}
+
+
+ benchmark-fat-jar
+
+ shade
+
+ package
+
+ true
+ benchmark
+
+
+
+ org.apache.spark:*
+
+ org.scala-lang:*
+
+ io.netty:*
+
+
+
+
+
+
+
+
+ org.lance.spark.knn.benchmark.IndexedNearestJoinBenchmark
+
+
+
+
+ *:*
+
+ LICENSE
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+ nativelib/darwin-aarch64/**
+ nativelib/linux-aarch64/**
+ aarch_64/**
+ x86_64/*.dylib
+ x86_64/*.dll
+
+
+
+
+
+
+
+
+
+ org.codehaus.mojo
+ exec-maven-plugin
+ 3.1.0
+
+ compile
+
+
+
+
+
+
+
diff --git a/lance-spark-knn_2.12/src/main/scala/org/apache/spark/sql/LanceKnnDatasetBridge.scala b/lance-spark-knn_2.12/src/main/scala/org/apache/spark/sql/LanceKnnDatasetBridge.scala
new file mode 100644
index 000000000..e2a260a2e
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/apache/spark/sql/LanceKnnDatasetBridge.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * Lives in the `org.apache.spark.sql` package solely to access `Dataset.ofRows`, which is
+ * `private[sql]` and not reachable from `org.lance.spark.knn`. The `IndexedNearestJoin`
+ * DataFrame API path needs to wrap a custom `LogicalPlan` (the staged probe / merge /
+ * materialize tree) as a `DataFrame`; the only public entry to do that goes through
+ * `SparkSession#sql()` (which expects a SQL string) or constructing `Dataset` directly,
+ * both of which require one-liner trampolines like this.
+ *
+ * Same trick most JVM Spark connectors use when they need to bridge between user-facing
+ * code and Catalyst internals. Kept minimal — exposing exactly one method, nothing else.
+ *
+ * == Spark 3.x vs 4.x: the `ofRows` location ==
+ *
+ * Spark 3.x keeps `Dataset` in `org.apache.spark.sql.Dataset`. Spark 4.0 moved the
+ * concrete DataFrame implementation to `org.apache.spark.sql.classic.Dataset` (the
+ * "classic" Dataset, vs Connect's remote DataFrame). `ofRows` is a `private[sql]` method
+ * in BOTH packages with compatible signatures, so reflection picks either at runtime.
+ *
+ * Single-source policy: the knn module is Spark-version-agnostic at compile time
+ * (binds only to `lance-spark-base` + `spark-sql` `provided`). If we compile-linked
+ * against `org.apache.spark.sql.Dataset.ofRows`, Spark 4.x runtime would
+ * `NoSuchMethodError`. Reflection avoids that — one lookup, cached at first call.
+ */
+object LanceKnnDatasetBridge {
+
+ // Look up once on first call; cache for subsequent invocations.
+ // Lazy initialization dodges a startup-time reflection cost when this object is loaded
+ // but kNearestJoin is never called.
+ private lazy val ofRowsInvoker: (SparkSession, LogicalPlan) => DataFrame = {
+ val candidates = Seq(
+ "org.apache.spark.sql.Dataset", // Spark 3.4 / 3.5
+ "org.apache.spark.sql.classic.Dataset" // Spark 4.0+
+ )
+ var found: Option[(SparkSession, LogicalPlan) => DataFrame] = None
+ val errors = scala.collection.mutable.ArrayBuffer.empty[String]
+ for (cls <- candidates if found.isEmpty) {
+ try {
+ val companion = Class.forName(cls + "$").getField("MODULE$").get(null)
+ // Find `ofRows` by shape rather than by exact parameter types: Spark 4.0+ takes
+ // `classic.SparkSession` as the first parameter instead of the abstract
+ // `sql.SparkSession`, so `getDeclaredMethod(..., classOf[SparkSession], ...)`
+ // won't match. `getMethods` + filter on name + arity + 2nd-param type
+ // (`LogicalPlan`, stable across versions) picks the canonical 2-arg overload
+ // and skips the 3-arg `(session, plan, tracker)` and 4-arg variants introduced
+ // in Spark 4.0+.
+ val method = companion.getClass.getMethods.find { m =>
+ m.getName == "ofRows" && m.getParameterCount == 2 &&
+ m.getParameterTypes()(1) == classOf[LogicalPlan]
+ }.getOrElse(throw new NoSuchMethodException(s"ofRows not found on $cls"))
+ method.setAccessible(true)
+ found = Some((s: SparkSession, p: LogicalPlan) => {
+ method.invoke(companion, s, p).asInstanceOf[DataFrame]
+ })
+ } catch {
+ case _: ClassNotFoundException | _: NoSuchMethodException =>
+ errors += cls
+ }
+ }
+ found.getOrElse(
+ throw new UnsupportedOperationException(
+ s"Could not locate Dataset.ofRows in any of: ${errors.mkString(", ")}. " +
+ "Unsupported Spark version."))
+ }
+
+ def asDataFrame(spark: SparkSession, plan: LogicalPlan): DataFrame =
+ ofRowsInvoker(spark, plan)
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/IndexedNearestJoin.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/IndexedNearestJoin.scala
new file mode 100644
index 000000000..616491d5f
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/IndexedNearestJoin.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{DataFrame, LanceKnnDatasetBridge}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.internal.{LanceFragments, LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric}
+import org.lance.spark.knn.internal.staged.{LanceKnnStagedStrategy, LanceMaterializeLogicalPlan, LanceMergeLogicalPlan, LanceProbeLogicalPlan, ProbedLeftCodec}
+
+/**
+ * Public entry point for the indexed nearest-by join over a Lance dataset.
+ *
+ * Phase 1 — staged RDD pipeline. The previous Phase 0 inline `mapPartitions` is split into a
+ * three-stage pipeline that mirrors the IMPL_PLAN's eventual physical-operator design:
+ *
+ * {{{
+ * left.rdd
+ * -- leftKeyed (zipWithUniqueId) --> leftId stable per row
+ * -- LanceProbeStage --> (leftId, ProbedLeft) per-task probe
+ * -- reduceByKey (Exchange) --> shuffle by hash(leftId) -- refs travel through
+ * -- LanceMergeStage --> (leftId, ProbedLeft) K-way bounded merge
+ * -- LanceMaterializeStage --> Row point-fetch right rows
+ * }}}
+ *
+ * Public API is unchanged from Phase 0 — this is a pure refactor. The shuffle is degenerate in
+ * Phase 1 (single contributor per `leftId` because we still probe the whole dataset from each
+ * task) but is structurally present, so plan-shape inspection sees the staged form.
+ *
+ * == Trade-offs vs. Phase 0 ==
+ *
+ * - Phase 0 had no shuffle: probe + materialize ran in the same `mapPartitions`. Phase 1 adds an
+ * `Exchange` (`reduceByKey`) between the stages, shipping `(leftId, ProbedLeft)` pairs across
+ * the network. With single-task probing this is wasted bandwidth — the IMPL_PLAN's
+ * "refs only ~24B" bandwidth win lands when fragment-grouping is wired in (multiple probe
+ * tasks per right dataset). Until then Phase 1 is strictly slower than Phase 0 on small data.
+ * - The materialize stage now opens its own Lance dataset handle per task (separate task from
+ * the probe stage). One extra manifest read per task — cheap on Lance.
+ *
+ * Limitations carried forward unchanged from Phase 0:
+ * - No left broadcast / no per-fragment partitioning (Phase 1.5).
+ * - No filter pushdown into the probe (Phase 3).
+ * - Uses a synthetic `leftId` from `zipWithUniqueId`, not a user-supplied join key. Means we
+ * can't yet co-partition the left payload alongside `(leftId, refs)` to drop the leftRow
+ * from the shuffle.
+ */
+object IndexedNearestJoin {
+
+ /**
+ * Run an approximate-nearest-neighbor join.
+ *
+ * @param left left DataFrame; one query vector per row in `leftVecCol`
+ * @param rightLanceUri Lance dataset URI for the right side
+ * @param leftVecCol name of the vector column in `left`. Must be `ArrayType[Float]`.
+ * @param rightVecCol name of the indexed (or to-be-searched) vector column on the right
+ * @param k top-K rows per left row
+ * @param metric distance/similarity metric: "l2" / "cosine" / "dot" (and synonyms)
+ * @param rightProjection columns to materialize from the right side. Defaults to all data
+ * columns. The score column is added separately.
+ * @param outerJoin when true, left rows with zero matches are preserved with NULL right-
+ * side columns. Defaults to false (inner-join semantics).
+ * @param scoreCol name of the synthesized score column added to the output. Defaults to
+ * `__score`.
+ * @param overfetch ratio of internal candidates to k for indexed approximate metrics; with
+ * no index this has no effect because Lance returns exact top-K. Defaults
+ * to 1 because the merge stage's value comes from N-task aggregation, not
+ * single-task overfetch.
+ * @param nprobes optional override of Lance's `nprobes` for IVF-PQ indexes
+ * @param version optional Lance version pin; if unset, latest version is used
+ * @param refineFactor IVF-PQ recall knob. When set, Lance fetches `k * refineFactor`
+ * approximate candidates, re-ranks them with exact distance, and trims
+ * back to k. Higher = better recall, more compute. `None` leaves Lance's
+ * default (= 1, no re-rank). Ignored for non-IVF-PQ indexes / unindexed.
+ * @param ef HNSW search depth. Higher = better recall, more compute. `None` leaves
+ * Lance's default (the index's build-time `ef_construction` value).
+ * Ignored for non-HNSW indexes / unindexed.
+ * @param balanceFragmentsByRowCount when true (Phase 1.5 + Phase 3 skew handling), fragment
+ * groups are formed via LPT greedy bin-packing on per-fragment row
+ * counts so groups have roughly equal total work. Default `false` =
+ * round-robin (cheaper, fine when fragments are evenly sized).
+ * @param probeParallelism Phase 1.5 fragment-grouping degree. `1` (default) keeps the Phase 1
+ * path: one task probes the whole dataset per left row, with Lance doing
+ * the cross-fragment merge internally. `> 1` enumerates Lance fragments
+ * on the driver, splits them into N round-robin groups, and replicates
+ * each left row across the N groups so the merge stage actually has work
+ * to do. Capped at the number of Lance fragments — extra groups are
+ * empty and skipped.
+ */
+ def apply(
+ left: DataFrame,
+ rightLanceUri: String,
+ leftVecCol: String,
+ rightVecCol: String,
+ k: Int,
+ metric: String = "l2",
+ rightProjection: Option[Seq[String]] = None,
+ outerJoin: Boolean = false,
+ scoreCol: String = "__score",
+ overfetch: Int = 1,
+ nprobes: Option[Int] = None,
+ version: Option[Long] = None,
+ probeParallelism: Int = 1,
+ refineFactor: Option[Int] = None,
+ ef: Option[Int] = None,
+ balanceFragmentsByRowCount: Boolean = false): DataFrame = {
+
+ require(k > 0, "k must be positive")
+ require(overfetch >= 1, "overfetch must be >= 1")
+ require(probeParallelism >= 1, "probeParallelism must be >= 1")
+
+ val spark = left.sparkSession
+ val parsedMetric = Metric.fromName(metric)
+ val internalK = k * overfetch
+
+ // Snapshot right-side schema on the driver before any executor work happens.
+ val rightSchema: StructType = {
+ val reader = spark.read.format("lance")
+ version.foreach(v => reader.option("version", v.toString))
+ val raw = reader.load(rightLanceUri)
+ val pruned = rightProjection match {
+ case Some(cols) if cols.nonEmpty => raw.select(cols.head, cols.tail: _*)
+ case _ => raw
+ }
+ pruned.schema
+ }
+
+ val outputSchema = buildOutputSchema(left.schema, rightSchema, scoreCol)
+ val leftFieldCount = left.schema.fields.length
+ val leftVecIdx = left.schema.fieldIndex(leftVecCol)
+ val rightProjectionCols: Seq[String] =
+ rightProjection.getOrElse(rightSchema.fieldNames.toSeq)
+
+ val probeConf = LanceProbeStage.Conf(
+ datasetUri = rightLanceUri,
+ fragmentIds = None,
+ vectorColumn = rightVecCol,
+ version = version,
+ metric = parsedMetric,
+ k = internalK,
+ nprobes = nprobes,
+ leftVecIdx = leftVecIdx,
+ refineFactor = refineFactor,
+ ef = ef)
+
+ val mergeConf = LanceMergeStage.Conf(
+ finalK = k,
+ smallerIsBetter = parsedMetric.smallerIsBetter)
+
+ val materializeConf = LanceMaterializeStage.Conf(
+ datasetUri = rightLanceUri,
+ version = version,
+ rightProjection = rightProjectionCols,
+ rightFields = rightSchema.fields.toSeq,
+ leftFieldCount = leftFieldCount,
+ outerJoin = outerJoin)
+
+ // Driver-side fragment-group enumeration for the Phase 1.5 path. Done here so the
+ // probe operator doesn't have to talk to Lance's Java API during planning; the result
+ // is carried in the logical plan as a serialisable field.
+ val fragmentGroups: Option[Seq[Seq[Int]]] = if (probeParallelism > 1) {
+ val rawGroups = if (balanceFragmentsByRowCount) {
+ LanceFragments.enumerateGroupsByRowCount(rightLanceUri, version, probeParallelism)
+ } else {
+ LanceFragments.enumerateGroups(rightLanceUri, version, probeParallelism)
+ }
+ val nonEmpty = rawGroups.filter(_.nonEmpty)
+ if (nonEmpty.size <= 1) None else Some(nonEmpty)
+ } else {
+ None
+ }
+
+ // Three-logical-plan tree:
+ // LanceMaterializeLogicalPlan
+ // ↳ LanceMergeLogicalPlan (requiredChildDistribution=ClusteredDistribution(_leftId)
+ // at the Exec level ⇒ Catalyst inserts `ShuffleExchangeExec` here, AQE engages)
+ // ↳ LanceProbeLogicalPlan
+ // ↳ user's left logical plan
+ //
+ // The `references = child.outputSet` override on Merge/Materialize (see
+ // `StagedPlans.scala`) blocks Catalyst's `ColumnPruning` from inserting
+ // `Project(Nil)` wrappers — that insertion was what caused an early 3-exec
+ // iteration to crash with `AssertionError` / SIGSEGV in
+ // `ProbedLeftCodec.Decoder.decode` reading 0-field UnsafeRows.
+ LanceKnnStagedStrategy.ensureRegistered(spark)
+
+ val leftSchema = left.schema
+ val interStageAttrs = ProbedLeftCodec.interStageAttributes(leftSchema)
+ val finalAttrs = outputSchema.fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+ }.toSeq
+
+ val probeLogical = LanceProbeLogicalPlan(
+ child = left.queryExecution.analyzed,
+ stageConf = probeConf,
+ fragmentGroups = fragmentGroups,
+ leftSchema = leftSchema,
+ interStageOutput = interStageAttrs)
+ val mergeLogical = LanceMergeLogicalPlan(
+ child = probeLogical,
+ stageConf = mergeConf,
+ leftSchema = leftSchema,
+ interStageOutput = interStageAttrs)
+ val materializeLogical = LanceMaterializeLogicalPlan(
+ child = mergeLogical,
+ stageConf = materializeConf,
+ leftSchema = leftSchema,
+ finalSchema = outputSchema,
+ finalOutput = finalAttrs)
+
+ LanceKnnDatasetBridge.asDataFrame(spark, materializeLogical)
+ }
+
+ private def buildOutputSchema(
+ left: StructType,
+ right: StructType,
+ scoreCol: String): StructType = {
+ val rightNullable = right.fields.map(f => f.copy(nullable = true))
+ val score = StructField(scoreCol, FloatType, nullable = true)
+ StructType(left.fields ++ rightNullable :+ score)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala
new file mode 100644
index 000000000..5b349dd85
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/LanceKnnImplicits.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+
+/**
+ * Idiomatic DataFrame extension for the indexed nearest-K join. The Phase 2 SQL syntax
+ * (`APPROX NEAREST K BY DISTANCE ...`) requires Spark 4.2+ because that's where the
+ * `NearestByJoin` operator landed. The DataFrame API path here works on every Spark version
+ * the lance-spark connector supports (3.5, 4.0, 4.1, 4.2+) — it just calls Lance's Java probe
+ * API directly through `IndexedNearestJoin.apply`, no Catalyst rule, no SQL.
+ *
+ * Usage:
+ * {{{
+ * import org.lance.spark.knn.LanceKnnImplicits._
+ *
+ * val docs = spark.read.format("lance").load("/path/to/lance/dataset")
+ * val joined = queries.kNearestJoin(
+ * right = docs,
+ * leftVecCol = "qvec",
+ * rightVecCol = "vec",
+ * k = 10,
+ * metric = "l2")
+ * }}}
+ *
+ * The right DataFrame MUST be a Lance scan — `spark.read.format("lance").load(uri)`. The
+ * extension extracts the underlying URI from the right-side analyzed plan; if it can't find a
+ * `LanceTable` it throws `IllegalArgumentException`. This is intentional: the indexed path
+ * uses Lance's Java API directly to open the dataset, so a non-Lance DataFrame cannot be
+ * substituted (there's no general "any DataFrame" indexed path).
+ *
+ * == Why an extension method, not a builder ==
+ *
+ * Builder-style APIs (`new KNearestJoin(...).build()`) are heavier syntactically than what
+ * users want for what should be a one-line call. The extension method makes the verb
+ * (`kNearestJoin`) hang off the left DataFrame the same way `join` does, so users discover it
+ * via IDE autocomplete and can reach for it without learning a new pattern.
+ */
+object LanceKnnImplicits {
+
+ implicit class LanceKnnDataFrameOps(val df: DataFrame) extends AnyVal {
+
+ /**
+ * Approximate top-K nearest-neighbor join over a Lance-backed right DataFrame. The right
+ * DataFrame must be a `spark.read.format("lance").load(uri)` (any plan that wraps a
+ * `LanceTable` — `Filter`, `SubqueryAlias`, `Project` are unwrapped). For a non-Lance
+ * right side or a derived plan that loses the URI, this method throws.
+ *
+ * @param right Lance-backed right DataFrame
+ * @param leftVecCol name of the vector column on `this` (left)
+ * @param rightVecCol name of the vector column on `right`
+ * @param k number of nearest neighbors per left row
+ * @param metric distance / similarity metric: "l2" | "cosine" | "dot"
+ * @param rightProjection columns to materialize from `right`. `None` = all columns.
+ * @param outerJoin left-outer mode: emit a left row even if zero neighbors found
+ * @param scoreCol name of the appended score column (default `__score`)
+ * @param overfetch multiplier on `k` during the probe before final trim
+ * @param nprobes IVF cluster count to visit per query (None = Lance default)
+ * @param refineFactor IVF-PQ exact-distance re-rank factor (None = no re-rank)
+ * @param ef HNSW search depth (None = Lance default; only meaningful for
+ * HNSW indexes)
+ * @param probeParallelism fragment groups for Phase 1.5 probing. 1 = single task probes
+ * the whole dataset (recommended on a single-machine setup); >1
+ * splits fragments across N tasks for true distributed clusters
+ * @param balanceFragments when probeParallelism > 1, use row-count-aware LPT bin-packing
+ * for fragment groups instead of round-robin
+ */
+ // scalastyle:off parameter.number
+ def kNearestJoin(
+ right: DataFrame,
+ leftVecCol: String,
+ rightVecCol: String,
+ k: Int,
+ metric: String = "l2",
+ rightProjection: Option[Seq[String]] = None,
+ outerJoin: Boolean = false,
+ scoreCol: String = "__score",
+ overfetch: Int = 1,
+ nprobes: Option[Int] = None,
+ refineFactor: Option[Int] = None,
+ ef: Option[Int] = None,
+ probeParallelism: Int = 1,
+ balanceFragments: Boolean = false): DataFrame = {
+ val (uri, version) = LanceKnnImplicits.extractLanceUri(right)
+ IndexedNearestJoin(
+ left = df,
+ rightLanceUri = uri,
+ leftVecCol = leftVecCol,
+ rightVecCol = rightVecCol,
+ k = k,
+ metric = metric,
+ rightProjection = rightProjection,
+ outerJoin = outerJoin,
+ scoreCol = scoreCol,
+ overfetch = overfetch,
+ nprobes = nprobes,
+ version = version,
+ probeParallelism = probeParallelism,
+ refineFactor = refineFactor,
+ ef = ef,
+ balanceFragmentsByRowCount = balanceFragments)
+ }
+ // scalastyle:on parameter.number
+ }
+
+ /**
+ * Walk a DataFrame's analyzed plan looking for a `LanceTable`-backed
+ * `DataSourceV2Relation`. Skips through wrappers that don't change the underlying
+ * relation: `SubqueryAlias`, `View`, `Project`, `Filter`. Returns `(uri, optional version)`
+ * pulled from the relation's options. Throws `IllegalArgumentException` if no Lance scan
+ * is found.
+ *
+ * Lance detection mirrors `IndexedNearestByJoinRule.isLanceTable` —
+ * class-name match (`getClass.getName.contains("Lance")`) — to keep the user-facing
+ * extension working without a hard dependency on the connector's internal types. The
+ * extension only needs to be able to spot a Lance relation; it doesn't operate on it
+ * directly.
+ *
+ * Public for tests.
+ */
+ private[knn] def extractLanceUri(df: DataFrame): (String, Option[Long]) = {
+ val rel = findLanceRelation(df.queryExecution.analyzed).getOrElse {
+ throw new IllegalArgumentException(
+ "kNearestJoin requires the right DataFrame to be a Lance scan " +
+ "(spark.read.format(\"lance\").load(uri)). Plan was:\n" +
+ df.queryExecution.analyzed)
+ }
+ val opts = rel.options
+ val uri = Option(opts.get("path"))
+ .orElse(Option(opts.get("datasetUri")))
+ .getOrElse(throw new IllegalArgumentException(
+ "Lance relation found but no `path` / `datasetUri` option set; cannot extract URI"))
+ val version = Option(opts.get("version")).map(_.toLong)
+ (uri, version)
+ }
+
+ private def findLanceRelation(plan: LogicalPlan): Option[DataSourceV2Relation] = plan match {
+ case rel: DataSourceV2Relation if isLanceTable(rel) => Some(rel)
+ case other =>
+ // Iterator.find avoids 2.13's `nextOption()` so this stays Scala 2.12-compatible.
+ val it = other.children.iterator.map(findLanceRelation).filter(_.isDefined)
+ if (it.hasNext) it.next() else None
+ }
+
+ private def isLanceTable(rel: DataSourceV2Relation): Boolean = {
+ val cls = rel.table.getClass.getName
+ cls.contains("Lance") || cls.contains("lance")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/CohereWikiRecallBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/CohereWikiRecallBenchmark.scala
new file mode 100644
index 000000000..6a2b89f77
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/CohereWikiRecallBenchmark.scala
@@ -0,0 +1,523 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.lance.{Dataset, ReadOptions}
+import org.lance.index.{IndexOptions, IndexParams, IndexType => LanceIndexType}
+import org.lance.index.vector.VectorIndexParams
+import org.lance.spark.LanceRuntime
+import org.lance.spark.knn.IndexedNearestJoin
+import org.lance.spark.knn.internal.Metric
+
+/**
+ * Cohere Wikipedia dim=768 recall benchmark -- production-shape companion to
+ * [[SiftRecallBenchmark]].
+ *
+ * SIFT validates the mechanics at 128-dim with natural image features. Real RAG /
+ * production embedding workloads are dim=768 (`bge-base`, `E5-base`,
+ * `sentence-transformers/all-mpnet-base-v2`) or dim=1024-1536 (`bge-large`, `text-embedding-3`).
+ * IVF-PQ behavior at high dim is materially different: PQ quantization error grows with
+ * dim, centroid Voronoi cells become more uniform, and `refineFactor` matters much more.
+ *
+ * Input: Cohere's `wikipedia-22-12` embeddings, pre-computed with
+ * `Cohere/multilingual-22-12` at dim=768. Hosted on HuggingFace:
+ * https://huggingface.co/datasets/Cohere/wikipedia-22-12
+ * Ships as Parquet files with columns `(id, title, text, url, wiki_id, paragraph_id,
+ * langs, emb)` where `emb` is `list` at dim=768.
+ *
+ * Unlike SIFT, **no ground truth is shipped**. We compute it ourselves via a brute-force
+ * Spark crossJoin + top-K pass on a held-out query sample. The base set is the remainder.
+ *
+ * == Downloading the data ==
+ *
+ * English Wikipedia chunks only (35M rows), partitioned across many Parquet files:
+ *
+ * {{{
+ * pip install huggingface_hub
+ * huggingface-cli download Cohere/wikipedia-22-12 \
+ * --repo-type dataset \
+ * --include 'en/[star].parquet' \
+ * --local-dir /tmp/cohere-wiki
+ * # -> /tmp/cohere-wiki/en/[star].parquet (~100 GB on disk for full EN)
+ * # (Replace [star] with the literal glob `*`. Written this way because Scaladoc's
+ * # parser interprets `[star]/` inside a comment as the end marker.)
+ * }}}
+ *
+ * For initial validation you don't want the full 35M. Point `COHERE_SOURCE_LIMIT` at
+ * 1-10M and we'll sample via Spark.
+ *
+ * == What this benchmark measures ==
+ *
+ * 1. **Index build time** for IVF-PQ / IVF-FLAT on N real Cohere embeddings. Expect
+ * materially slower than SIFT at same N because dim=768 is 6x wider.
+ * 2. **Recall@K** with ground truth computed by brute-force Spark. Distance is L2 by
+ * default; cosine also supported since these are unit-normalized embeddings and
+ * cosine ≡ L2 up to a monotone transform.
+ * 3. **Latency** per query at each (nprobes, refineFactor) point. Useful for picking a
+ * production config given a recall target.
+ *
+ * == Cluster run ==
+ *
+ * {{{
+ * ./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+ * # upload the fat jar + (optionally) the parquet files
+ *
+ * BENCH_CLUSTER_MODE=true \
+ * BENCH_DATA_PATH=s3://bucket/cohere-bench \
+ * COHERE_PARQUET=s3://my-bucket/cohere-wiki/en \
+ * COHERE_SOURCE_LIMIT=1000000 \
+ * COHERE_NUM_QUERIES=1000 \
+ * COHERE_K=10 \
+ * COHERE_NUM_PARTITIONS=1024 \
+ * COHERE_NUM_SUB_VECTORS=96 \
+ * COHERE_NPROBES_LIST=1,4,16,64 \
+ * COHERE_REFINE_LIST=1,4,16 \
+ * spark-submit --class org.lance.spark.knn.benchmark.CohereWikiRecallBenchmark
+ * }}}
+ *
+ * == Env knobs ==
+ *
+ * - `COHERE_PARQUET=` -- source parquet dir (required). Local or object store.
+ * - `COHERE_EMB_COL=emb` -- embedding column name (default `emb`).
+ * - `COHERE_SOURCE_LIMIT=1000000` -- sample this many rows total (base + queries)
+ * before splitting (default 1M). Set `0` for full
+ * dataset.
+ * - `COHERE_NUM_QUERIES=1000` -- held-out query count (default 1000). Brute-force
+ * GT is O(Nqueries × Nbase), so keep this modest.
+ * - `COHERE_K=10` -- top-K to measure (default 10).
+ * - `COHERE_METRIC=l2` -- l2 | cosine | dot (default l2). Cohere embeddings
+ * are unit-normalized so cosine ≡ (1 - dot) and
+ * L2² = 2 - 2·dot; they produce the same top-K ordering.
+ * - `COHERE_NUM_PARTITIONS=1024` -- IVF cluster count (default 1024 for ~1M rows;
+ * rule of thumb: sqrt(N) to N^(2/3)).
+ * - `COHERE_NUM_SUB_VECTORS=96` -- PQ subvectors; must divide 768 evenly
+ * (default 96 = 8 dims per subvector, 8-bit codes).
+ * - `COHERE_NPROBES_LIST=1,4,16,64` -- grid of nprobes to test.
+ * - `COHERE_REFINE_LIST=1,4,16` -- grid of refineFactor (IVF-PQ only).
+ * - `COHERE_INDEX=both` -- `ivfpq` | `ivfflat` | `both` (default `both`).
+ * - `COHERE_SEED=1337` -- RNG seed for base/query split.
+ * - `COHERE_SKIP_PREP=false` -- if "true", assume Lance base + query parquet + GT
+ * parquet already exist at `BENCH_DATA_PATH`.
+ * - `COHERE_SKIP_INDEX=false` -- if "true", skip index build (reuse existing).
+ * - `BENCH_CLUSTER_MODE`, `BENCH_DATA_PATH` -- same as other benchmarks.
+ *
+ * == What this does NOT do ==
+ *
+ * - Build ground truth in parallel with index grid sweeps. GT is computed once upfront,
+ * materialized to parquet under `BENCH_DATA_PATH/gt`, then all grid points compare
+ * against the same GT.
+ * - Support cross-lingual subsets (only English configured by default). Pass
+ * `COHERE_PARQUET` at a different subdir to benchmark German / French / etc.
+ * - Assert on specific recall numbers. Unlike SIFT, published IVF-PQ numbers for real
+ * dim=768 embeddings are less standardized. Use this benchmark to pick a nprobes /
+ * refine point for YOUR recall target, not to validate against a fixed expectation.
+ */
+object CohereWikiRecallBenchmark {
+
+ // -- env knobs ------------------------------------------------------------------------------
+
+ private val ClusterMode: Boolean =
+ sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true"))
+ private val DataPath: String =
+ sys.env.getOrElse("BENCH_DATA_PATH", "/tmp/lance-cohere-wiki")
+ private val ParquetPath: String = sys.env.getOrElse(
+ "COHERE_PARQUET",
+ sys.error("COHERE_PARQUET is required (path to Cohere wiki-22-12 parquet files)"))
+ private val EmbCol: String = sys.env.getOrElse("COHERE_EMB_COL", "emb")
+ private val SourceLimit: Long =
+ sys.env.get("COHERE_SOURCE_LIMIT").map(_.toLong).getOrElse(1000000L)
+ private val NumQueries: Int =
+ sys.env.get("COHERE_NUM_QUERIES").map(_.toInt).getOrElse(1000)
+ private val K: Int = sys.env.get("COHERE_K").map(_.toInt).getOrElse(10)
+ private val MetricName: String = sys.env.getOrElse("COHERE_METRIC", "l2").toLowerCase
+ private val NumPartitions: Int =
+ sys.env.get("COHERE_NUM_PARTITIONS").map(_.toInt).getOrElse(1024)
+ private val NumSubVectors: Int =
+ sys.env.get("COHERE_NUM_SUB_VECTORS").map(_.toInt).getOrElse(96)
+ private val NprobesList: Seq[Int] = sys.env
+ .getOrElse("COHERE_NPROBES_LIST", "1,4,16,64").split(",").map(_.trim.toInt).toSeq
+ private val RefineList: Seq[Int] = sys.env
+ .getOrElse("COHERE_REFINE_LIST", "1,4,16").split(",").map(_.trim.toInt).toSeq
+ private val IndexMode: String =
+ sys.env.getOrElse("COHERE_INDEX", "both").toLowerCase
+ private val Seed: Long = sys.env.get("COHERE_SEED").map(_.toLong).getOrElse(1337L)
+ private val SkipPrep: Boolean =
+ sys.env.get("COHERE_SKIP_PREP").exists(_.equalsIgnoreCase("true"))
+ private val SkipIndex: Boolean =
+ sys.env.get("COHERE_SKIP_INDEX").exists(_.equalsIgnoreCase("true"))
+
+ private lazy val BaseUri = s"$DataPath/base"
+ private lazy val QueriesUri = s"$DataPath/queries_parquet"
+ private lazy val GtUri = s"$DataPath/gt_parquet"
+
+ // -- main -----------------------------------------------------------------------------------
+
+ def main(args: Array[String]): Unit = {
+ val spark = buildSparkSession()
+ try {
+ logBanner(spark)
+
+ if (SkipPrep) {
+ println(s"[cohere] COHERE_SKIP_PREP=true -> reusing existing Lance + queries + GT")
+ } else {
+ prepareDatasets(spark)
+ }
+
+ val dim = detectDim(spark)
+ println(s"[cohere] detected dim=$dim from base dataset")
+ require(
+ dim % NumSubVectors == 0,
+ s"COHERE_NUM_SUB_VECTORS=$NumSubVectors does not divide dim=$dim evenly")
+
+ val indexesToRun: Seq[String] = IndexMode match {
+ case "ivfpq" => Seq("ivfpq")
+ case "ivfflat" => Seq("ivfflat")
+ case "both" => Seq("ivfflat", "ivfpq")
+ case other => sys.error(s"Unknown COHERE_INDEX=$other (expected ivfpq|ivfflat|both)")
+ }
+
+ if (!SkipIndex) indexesToRun.foreach(buildIndex)
+
+ // Load queries + GT once; reuse across the grid sweep.
+ val queriesDf = spark.read.parquet(QueriesUri).cache()
+ val queriesCount = queriesDf.count()
+ val gtDf = spark.read.parquet(GtUri).cache()
+ val gtCount = gtDf.count()
+ println(f"[cohere] queries: $queriesCount%,d ; ground-truth rows: $gtCount%,d " +
+ f"(expected ${queriesCount * K.toLong}%,d)")
+
+ // Build a driver-side GT lookup once: Map[qid -> Set[topK-rid]].
+ val gtByQid: Map[Long, Set[Long]] = gtDf
+ .groupBy("qid")
+ .agg(collect_list(col("rid")).as("rids"))
+ .collect()
+ .map(r =>
+ r.getLong(0) -> r.getSeq[Long](1).take(K).toSet)
+ .toMap
+
+ indexesToRun.foreach { idx =>
+ println()
+ println("=" * 80)
+ println(s" RECALL GRID: index=$idx, dim=$dim, metric=$MetricName, K=$K, " +
+ s"base sample=$SourceLimit, queries=$queriesCount")
+ println("=" * 80)
+ val refineGrid: Seq[Int] = if (idx == "ivfpq") RefineList else Seq(1)
+ println(
+ f"${"nprobes"}%8s ${"refine"}%6s ${"recall@K"}%10s ${"mean_ms"}%10s ${"queries"}%8s")
+ for (nprobes <- NprobesList; refine <- refineGrid) {
+ val (recall, meanMs) = runRecallGrid(
+ spark,
+ queriesDf,
+ gtByQid,
+ nprobes = nprobes,
+ refineFactor = if (idx == "ivfpq") Some(refine) else None)
+ println(f"$nprobes%8d $refine%6d $recall%10.4f $meanMs%10.2f $queriesCount%8d")
+ }
+ }
+ } finally {
+ spark.stop()
+ }
+ }
+
+ // -- Spark --------------------------------------------------------------------------------
+
+ private def buildSparkSession(): SparkSession = {
+ val b = SparkSession.builder().appName("CohereWikiRecallBenchmark")
+ if (!ClusterMode) {
+ b.master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ }
+ b.getOrCreate()
+ }
+
+ private def logBanner(spark: SparkSession): Unit = {
+ println("=" * 80)
+ println("CohereWikiRecallBenchmark")
+ println("=" * 80)
+ println(f" Spark version: ${spark.version}")
+ println(f" master: ${spark.sparkContext.master}")
+ println(f" cluster mode: $ClusterMode")
+ println(f" source parquet: $ParquetPath")
+ println(f" emb column: $EmbCol")
+ println(f" source sample limit: $SourceLimit%,d (0 = no limit)")
+ println(f" num queries held out: $NumQueries%,d")
+ println(f" K: $K")
+ println(f" metric: $MetricName")
+ println(f" num IVF partitions: $NumPartitions")
+ println(f" num PQ subvectors: $NumSubVectors")
+ println(f" nprobes grid: ${NprobesList.mkString(",")}")
+ println(f" refine grid: ${RefineList.mkString(",")}")
+ println(f" index: $IndexMode")
+ println(f" data path: $DataPath")
+ println(f" seed: $Seed")
+ println("=" * 80)
+ println()
+ }
+
+ // -- preparation -------------------------------------------------------------------------
+
+ /**
+ * Reads the Cohere parquet, normalises schema, samples to `SourceLimit`, splits into
+ * held-out queries + base, writes base as Lance, writes queries as parquet, computes
+ * brute-force top-K ground truth and writes as parquet. All three artifacts land under
+ * `DataPath/` and are reused on subsequent runs with `COHERE_SKIP_PREP=true`.
+ */
+ private def prepareDatasets(spark: SparkSession): Unit = {
+ println(s"[cohere] reading source parquet: $ParquetPath")
+ val raw = spark.read.parquet(ParquetPath)
+ require(
+ raw.schema.fieldNames.contains(EmbCol),
+ s"Source parquet does not contain $EmbCol column; fields: ${raw.schema.fieldNames.mkString(",")}")
+
+ // Keep only an `id` (if present) and `emb` column. Rename to `rid` / `vec` for the
+ // downstream Lance schema. `id` in the Cohere dataset is a string; we synthesise a
+ // stable `rid: long` via monotonically_increasing_id to match our `_rowid`-oriented
+ // world.
+ val embColExpr: org.apache.spark.sql.Column = col(EmbCol).cast(ArrayType(FloatType))
+ val withId = raw.select(embColExpr.as("vec"))
+ .withColumn("rid", monotonically_increasing_id())
+ .select("rid", "vec")
+
+ val limited = if (SourceLimit > 0) withId.limit(SourceLimit.toInt) else withId
+
+ // Split: first `NumQueries` rows (after a shuffle for randomness) become queries;
+ // the rest is base. orderBy(rand(Seed)) is the cheap way to get a deterministic random
+ // split without a union-all join dance later.
+ val shuffled = limited.orderBy(rand(Seed))
+ val queries = shuffled.limit(NumQueries).cache()
+ val base = shuffled.exceptAll(queries)
+
+ println(s"[cohere] writing base to Lance: $BaseUri")
+ // Lance requires the vector column to be `FixedSizeList` for indexing.
+ // Passing via `.option("vec.arrow.fixed-size-list.size", ...)` on DataFrameWriter
+ // does NOT propagate to the written Lance schema (it's only honoured by the
+ // TBLPROPERTIES path in `CREATE TABLE`). The `cast(ArrayType(FloatType))` above
+ // also strips any field metadata the parquet reader might have surfaced.
+ //
+ // Working shape (mirrors BaseVectorCreateTableTest:182): build the DataFrame from
+ // fresh Rows using a StructType whose `vec` StructField has the
+ // `arrow.fixed-size-list.size` metadata. The driver-side round-trip is unavoidable
+ // at our dims — `cast(...).withColumn(...)` drops field metadata, and spark-sql has
+ // no expression API to re-tag.
+ //
+ // Also drops mode("overwrite") — Lance catalog's overwrite path calls drop-then-
+ // create and throws NoSuchTableException when the target path has never existed.
+ // Default (ErrorIfExists) is correct for first-run; set COHERE_SKIP_PREP=true on
+ // rerun to reuse the existing dataset.
+ val dim = detectDimFromFirstRow(base)
+ val embMeta = new MetadataBuilder()
+ .putLong("arrow.fixed-size-list.size", dim.toLong)
+ .build()
+ val taggedSchema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "vec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ embMeta)))
+ val baseRows = base.collect()
+ println(f"[cohere] collected ${baseRows.length}%,d rows to driver for schema retagging")
+ val javaRows = new java.util.ArrayList[Row](baseRows.length)
+ var i = 0
+ while (i < baseRows.length) {
+ val r = baseRows(i)
+ val s = r.getAs[scala.collection.Seq[Float]]("vec")
+ val arr = new Array[Float](s.length)
+ var j = 0
+ while (j < s.length) { arr(j) = s(j); j += 1 }
+ javaRows.add(org.apache.spark.sql.RowFactory.create(
+ java.lang.Long.valueOf(r.getLong(0)),
+ arr))
+ i += 1
+ }
+ val taggedBase = spark.createDataFrame(javaRows, taggedSchema)
+ taggedBase.write.format("lance").save(BaseUri)
+
+ println(s"[cohere] writing queries to parquet: $QueriesUri")
+ // Renumber query rids so they start at 0 -- the GT rids from the BASE set must not
+ // collide with query rids.
+ val qidCol = row_number().over(Window.orderBy("rid")).cast(LongType) - 1L
+ val queriesOut = queries.withColumn("qid", qidCol).select("qid", "vec")
+ queriesOut.write.mode("overwrite").parquet(QueriesUri)
+
+ println(s"[cohere] computing brute-force ground truth (k=$K, metric=$MetricName) -> $GtUri")
+ computeAndWriteGroundTruth(spark, queriesOut, base)
+ }
+
+ /**
+ * Compute recall-1.0 ground truth by brute-force crossJoin + distance + top-K window.
+ * Cost: `O(Nqueries × Nbase)` distance evaluations. At Nqueries=1000 × Nbase=1M ×
+ * dim=768, that's ~1 TB of computations; still minutes on a modest cluster.
+ *
+ * Output schema: `(qid: long, rid: long, rank: int)` with `rank` in [0, K). Written as
+ * parquet for reuse across grid sweeps.
+ */
+ private def computeAndWriteGroundTruth(
+ spark: SparkSession,
+ queriesDf: DataFrame,
+ baseDf: DataFrame): Unit = {
+ val distanceExpr = MetricName match {
+ case "l2" => l2DistSq(col("q.vec"), col("b.vec")).as("dist")
+ case "cosine" =>
+ // Assuming unit-normalized vectors (true for Cohere embeddings), cosine ≡ 1 - dot.
+ // Since 1 - dot is monotone in -dot, we use negative dot as the distance and sort
+ // ASC below -- same top-K as cosine distance ASC.
+ (-dotProduct(col("q.vec"), col("b.vec"))).as("dist")
+ case "dot" =>
+ // For raw dot, "nearest" = largest dot. Negate to use ASC sort semantics uniformly.
+ (-dotProduct(col("q.vec"), col("b.vec"))).as("dist")
+ case other => sys.error(s"Unsupported COHERE_METRIC=$other (expected l2|cosine|dot)")
+ }
+
+ val crossed = queriesDf.as("q")
+ .crossJoin(baseDf.as("b"))
+ .select(col("q.qid"), col("b.rid"), distanceExpr)
+
+ val w = Window.partitionBy("qid").orderBy(col("dist").asc)
+ val topK = crossed
+ .withColumn("rank", row_number().over(w) - 1)
+ .where(col("rank") < K)
+ .select("qid", "rid", "rank")
+
+ topK.write.mode("overwrite").parquet(GtUri)
+ }
+
+ /**
+ * L2² distance via element-wise subtract + square + sum. Spark 3.5+ has
+ * `vector_l2_distance` but we can't count on version; keep it portable.
+ */
+ private def l2DistSq(
+ a: org.apache.spark.sql.Column,
+ b: org.apache.spark.sql.Column): org.apache.spark.sql.Column = {
+ aggregate(
+ zip_with(
+ a,
+ b,
+ (x, y) => {
+ val d = x - y
+ d * d
+ }),
+ lit(0.0f),
+ (acc, v) => acc + v)
+ }
+
+ /** Inner product: element-wise multiply + sum. */
+ private def dotProduct(
+ a: org.apache.spark.sql.Column,
+ b: org.apache.spark.sql.Column): org.apache.spark.sql.Column = {
+ aggregate(
+ zip_with(a, b, (x, y) => x * y),
+ lit(0.0f),
+ (acc, v) => acc + v)
+ }
+
+ // -- schema / dim detection ---------------------------------------------------------------
+
+ private def detectDim(spark: SparkSession): Int = {
+ val base = spark.read.format("lance").load(BaseUri)
+ detectDimFromFirstRow(base)
+ }
+
+ private def detectDimFromFirstRow(df: DataFrame): Int = {
+ val first: Row = df.select("vec").head(1).head
+ first.getSeq[Float](0).length
+ }
+
+ // -- index build --------------------------------------------------------------------------
+
+ private def buildIndex(kind: String): Unit = {
+ println(s"[cohere] building $kind index on $BaseUri " +
+ s"(numPartitions=$NumPartitions" +
+ (if (kind == "ivfpq") s", numSubVectors=$NumSubVectors" else "") + ")")
+ val t0 = System.nanoTime()
+ val ds = Dataset.open().uri(BaseUri).allocator(LanceRuntime.allocator())
+ .readOptions(new ReadOptions.Builder().build()).build()
+ try {
+ val metric = MetricName match {
+ case "l2" => Metric.L2
+ case "cosine" => Metric.Cosine
+ case "dot" => Metric.Dot
+ case other => sys.error(s"Unsupported COHERE_METRIC=$other")
+ }
+ val vectorParams = kind match {
+ case "ivfpq" =>
+ VectorIndexParams.ivfPq(NumPartitions, NumSubVectors, 8, metric.lanceType, 50)
+ case "ivfflat" =>
+ VectorIndexParams.ivfFlat(NumPartitions, metric.lanceType)
+ }
+ val idxParams = IndexParams.builder().setVectorIndexParams(vectorParams).build()
+ val opts = IndexOptions.builder(
+ java.util.Collections.singletonList("vec"),
+ LanceIndexType.VECTOR,
+ idxParams).build()
+ ds.createIndex(opts)
+ } finally ds.close()
+ val sec = (System.nanoTime() - t0) / 1e9
+ println(f"[cohere] $kind build complete in $sec%.1f s")
+ }
+
+ // -- recall evaluation --------------------------------------------------------------------
+
+ /**
+ * Run all queries through the indexed nearest-join at one (nprobes, refine) point,
+ * compute mean recall@K vs `gtByQid`. Returns (meanRecall, meanLatencyMs).
+ */
+ private def runRecallGrid(
+ spark: SparkSession,
+ queriesDf: DataFrame,
+ gtByQid: Map[Long, Set[Long]],
+ nprobes: Int,
+ refineFactor: Option[Int]): (Double, Double) = {
+ // Normalise the queries DF schema to what kNearestJoin expects (a left vector column
+ // named `qvec`). The stored queries parquet has `(qid, vec)`; rename.
+ val left = queriesDf.select(col("qid").as("lid"), col("vec").as("qvec"))
+
+ val t0 = System.nanoTime()
+ val joined = IndexedNearestJoin(
+ left = left,
+ rightLanceUri = BaseUri,
+ leftVecCol = "qvec",
+ rightVecCol = "vec",
+ k = K,
+ metric = MetricName,
+ rightProjection = Some(Seq("rid")),
+ nprobes = Some(nprobes),
+ refineFactor = refineFactor)
+ val collected = joined.collect()
+ val elapsedMs = (System.nanoTime() - t0) / 1e6
+ val nQueries = left.count()
+
+ // Schema of `collected`: [lid, qvec, rid, __score]. Positions: 0,1,2,3.
+ val actualByQid = collected
+ .groupBy(_.getLong(0))
+ .map { case (qid, rowsArr) =>
+ qid -> rowsArr.toSeq.sortBy(_.getFloat(3)).take(K).map(_.getLong(2)).toSet
+ }
+
+ var recallSum = 0.0
+ gtByQid.keys.foreach { qid =>
+ val expected = gtByQid(qid)
+ val got = actualByQid.getOrElse(qid, Set.empty[Long])
+ recallSum += got.intersect(expected).size.toDouble / K
+ }
+ val meanRecall = recallSum / gtByQid.size
+ val meanLatencyMs = elapsedMs / nQueries
+ (meanRecall, meanLatencyMs)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinBenchmark.scala
new file mode 100644
index 000000000..ce704f48f
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinBenchmark.scala
@@ -0,0 +1,700 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{DataFrame, Row, RowFactory, SparkSession}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.IndexedNearestJoin
+import org.lance.spark.knn.LanceKnnImplicits._
+
+import java.nio.file.{Files, Paths}
+import java.util.{Locale, Random}
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+/**
+ * Benchmark comparing the indexed nearest-by-join paths against vanilla Spark's cross-product
+ * baseline. Works in both local and cluster mode.
+ *
+ * == Local run ==
+ *
+ * {{{
+ * cd /path/to/lance-spark
+ * ./mvnw -pl lance-spark-knn_2.12 install -DskipTests -Pbenchmark # build fat JAR
+ * MAVEN_OPTS="-Xmx12g" ./mvnw -pl lance-spark-knn_2.12 \
+ * exec:java -Pbenchmark \
+ * -Dexec.mainClass="org.lance.spark.knn.benchmark.IndexedNearestJoinBenchmark"
+ * }}}
+ *
+ * == Cluster run (YARN / K8s / managed-Spark distributions) ==
+ *
+ * Build the benchmark fat JAR first:
+ * {{{
+ * ./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+ * # → target/lance-spark-knn_2.12--benchmark.jar
+ * }}}
+ *
+ * Then upload and submit via your cluster's job API. Set environment variables in the job:
+ * - `BENCH_CLUSTER_MODE=true` — skips setting `.master()` and bind-address configs
+ * - `BENCH_DATA_PATH=` — shared path for synthetic datasets (s3://, hdfs://, etc.)
+ * - `BENCHMARK_SCALE=small` — `small`, `medium`, or `both` (default)
+ *
+ * The baseline Spark crossJoin is only run at small scale (O(|L|×|R|) is impractical at medium).
+ *
+ * == Environment variables ==
+ *
+ * | Variable | Default | Description |
+ * |--------------------|-------------------|--------------------------------------------------|
+ * | `BENCH_CLUSTER_MODE` | `false` | Set to `true` to skip `.master()` + bind addrs |
+ * | `BENCH_DATA_PATH` | tmp dir (local) | URI for synthetic Lance datasets |
+ * | `BENCHMARK_SCALE` | `both` | `small`, `medium`, or `both` |
+ *
+ * == What this measures ==
+ *
+ * Six configurations, run at two scales (small: 100K×100, medium: 1M×1000):
+ *
+ * A) Vanilla Spark cross-product — `crossJoin` + custom L2 UDF + `row_number` window.
+ * The textbook way a user would express nearest-by-join in Spark 3.5 (no
+ * `vector_l2_distance` until 4.2). Strictly slower than what `RewriteNearestByJoin`
+ * actually does on Spark 4.2 (heap-K via `min_by_k`). Kept as the historical
+ * headline-naive comparison.
+ * A2) Heap-K-shape baseline — `crossJoin` + L2 UDF + `groupBy(lid).agg(slice(
+ * sort_array(collect_list(struct(dist, rid)), asc=true), 1, K))` + `inline()`. The
+ * closest Spark 3.5 SQL expression of what Spark 4.2's `RewriteNearestByJoin` lowers
+ * `NearestByJoin` to. Still O(|R| log |R|) per group on 3.5 (`min_by_k`'s O(|R| log K)
+ * heap is 4.2-only); narrows the gap vs A but doesn't fully match 4.2-native.
+ * B) Phase 0/1 single-task probe — `df.kNearestJoin(probeParallelism = 1)`. One task
+ * probes the whole right dataset per partition; Lance does the cross-fragment merge
+ * internally.
+ * C) Phase 1.5 with 4 groups — `probeParallelism = 4`. Four parallel probe tasks,
+ * each handling a quarter of the right dataset's fragments. The merge stage actually
+ * aggregates contributions for the first time.
+ * D) Phase 1.5 with 8 groups — `probeParallelism = 8`.
+ * E) Phase 1.5 with 8 + skew bal — `probeParallelism = 8`, `balanceFragments = true`.
+ * LPT bin-packing on per-fragment row counts. With evenly-sized synthetic fragments this
+ * should match (D); the win lands on real-world skewed data.
+ *
+ * The B–E configs use the `df.kNearestJoin(rightDf, ...)` extension method, which is the
+ * idiomatic DataFrame API form (the URI-based `IndexedNearestJoin.apply` still works and
+ * does the same thing). The pipeline lowers to the 3-exec Catalyst-visible staged plan
+ * (`LanceProbeExec → ShuffleExchangeExec → LanceMergeExec → LanceMaterializeExec` under
+ * `AdaptiveSparkPlanExec`); `df.explain()` shows all four nodes and AQE coalesces the
+ * merge shuffle (`AQEShuffleRead coalesced`). See `IMPL_PLAN.md` "3-exec staged split
+ * — root cause and fix" for the ColumnPruning + `references = child.outputSet` detail
+ * that makes this shape safe against `count()`-style consumers.
+ *
+ * == What this does NOT measure ==
+ *
+ * IVF-PQ approximate vs. exact recall trade-off — requires building a vector index via Lance
+ * Java DDL, which lance-spark-knn's test setup doesn't yet do. Without an index Lance
+ * brute-force-scans each fragment, so all our paths return exact (recall = 1.0) results. The
+ * "X-x faster than vanilla Spark" headline is real; the additional 10-100x speedup from index
+ * lookups is a Phase 3.x demo.
+ */
+object IndexedNearestJoinBenchmark {
+
+ private val Dim: Int = 128
+ private val K: Int = 10
+ private val Seed: Long = 1337L
+
+ /**
+ * Each scale: (numRight, numLeft, numFragments, runBaseline). The vanilla-Spark crossJoin
+ * baseline is `O(|L|×|R|)`. At medium scale (1M × 1000 = 1B pairs) it's measured in tens
+ * of minutes per run on a typical 8-core executor; on a wider cluster (8 × 8c/32g, the
+ * current sizing) it's tractable but still slow — flip `runBaseline = true` and budget
+ * extra cluster time accordingly. The default is "true at both scales" so the published
+ * speedup is always against an actual measured number rather than an extrapolated one.
+ */
+ private case class Scale(
+ name: String,
+ numRight: Int,
+ numLeft: Int,
+ numFragments: Int,
+ runBaseline: Boolean) {
+ override def toString: String = s"$name (|R|=$numRight, |L|=$numLeft, frags=$numFragments)"
+ }
+ private val Small =
+ Scale("small", numRight = 100000, numLeft = 100, numFragments = 4, runBaseline = true)
+ private val Medium =
+ Scale("medium", numRight = 1000000, numLeft = 1000, numFragments = 8, runBaseline = true)
+
+ /**
+ * Sampling sweep + ground-truth scales for the medium-scale baseline extrapolation
+ * methodology. Cross-product `O(|L|·|R|·dim)` is genuinely linear in |R| and |L|
+ * separately (distance compute over a fixed number of pairs), so a sampled |R| sweep
+ * with fixed |L|, plus one |R|=full × |L|=reduced ground-truth, lets us extrapolate the
+ * full medium baseline number cheaply (~30 min cluster total) without spending
+ * 1+ hour/iter on the full 1B-pair cross-product.
+ */
+ private val SampleR10K =
+ Scale("sample_r10k", numRight = 10000, numLeft = 1000, numFragments = 8, runBaseline = true)
+ private val SampleR50K =
+ Scale("sample_r50k", numRight = 50000, numLeft = 1000, numFragments = 8, runBaseline = true)
+ private val SampleR100K =
+ Scale("sample_r100k", numRight = 100000, numLeft = 1000, numFragments = 8, runBaseline = true)
+ private val SampleR200K =
+ Scale("sample_r200k", numRight = 200000, numLeft = 1000, numFragments = 8, runBaseline = true)
+
+ /** Ground-truth at full |R|=1M but reduced |L|=100 — keeps |L|·|R| at 100M pairs (10× small). */
+ private val MediumL100 =
+ Scale("medium_l100", numRight = 1000000, numLeft = 100, numFragments = 8, runBaseline = true)
+
+ /** A single timing result. */
+ private case class Result(scale: String, config: String, medianMs: Long, runs: Seq[Long]) {
+ def speedupVs(baseline: Long): Double =
+ if (medianMs <= 0) Double.NaN else baseline.toDouble / medianMs
+ }
+
+ def main(args: Array[String]): Unit = {
+ val scales = sys.env.getOrElse("BENCHMARK_SCALE", "both").toLowerCase(Locale.ROOT) match {
+ case "small" => Seq(Small)
+ case "medium" => Seq(Medium)
+ case "baseline_sweep" =>
+ // Sampling sweep for cross-product O(|R|) extrapolation. Each scale at fixed |L|=1000
+ // so the linear-in-|R| coefficient comes out clean. Add MediumL100 as the ground
+ // truth: full |R|=1M with reduced |L|=100, which validates the extrapolation
+ // independently (it should match `extrap @ |R|=1M / 10` since |L| linearly affects
+ // wall-clock too).
+ Seq(SampleR10K, SampleR50K, SampleR100K, SampleR200K, MediumL100)
+ case "medium_l100" => Seq(MediumL100)
+ case _ => Seq(Small, Medium)
+ }
+ val clusterMode = sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true"))
+ val dataDirOpt = sys.env.get("BENCH_DATA_PATH")
+
+ println(banner("Indexed Nearest-By-Join Benchmark"))
+ val masterDesc = if (clusterMode) "cluster (BENCH_CLUSTER_MODE=true)" else "local[*]"
+ println(s"Spark master: $masterDesc Dim: $Dim K: $K Seed: $Seed")
+ println(s"Scales: ${scales.map(_.name).mkString(", ")}")
+ dataDirOpt.foreach(p => println(s"Data root: $p"))
+ println()
+
+ // BENCH_DISABLE_AQE=true turns off AQE for the whole benchmark run. Useful when
+ // benchmarking the cross-product baseline at scale, because AQE coalesces the
+ // post-shuffle partition count down (advisoryPartitionSizeInBytes default 64MB →
+ // a small shuffle gets squeezed to ~8 partitions even on a 64-core cluster), which
+ // throttles parallelism for downstream compute-heavy stages like the per-`lid`
+ // top-K aggregation. With AQE off, post-shuffle parallelism stays at the configured
+ // `spark.sql.shuffle.partitions`, fully utilizing cluster cores. Indexed-path runs
+ // do want AQE on (the merge-side shuffle benefits from coalesce/skew handling), so
+ // this is opt-in per-run.
+ val disableAqe = sys.env.get("BENCH_DISABLE_AQE").exists(_.equalsIgnoreCase("true"))
+
+ val builder = SparkSession
+ .builder()
+ .appName("indexed-nearest-by-join-benchmark")
+ .config("spark.sql.crossJoin.enabled", "true")
+ // shuffle.partitions left as a runtime override (submit-script default for cluster runs;
+ // Spark default 200 for local). Was hardcoded here to 32 historically — that was throttling
+ // post-shuffle parallelism on wider clusters. For local runs we set it to a sane local
+ // value below so we don't shuffle to 200 partitions on a laptop.
+ if (!clusterMode) {
+ builder.config("spark.sql.shuffle.partitions", "32")
+ builder
+ .master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ }
+ if (disableAqe) {
+ builder.config("spark.sql.adaptive.enabled", "false")
+ }
+ val spark = builder.getOrCreate()
+ spark.sparkContext.setLogLevel("WARN")
+ if (disableAqe) println("[bench] AQE DISABLED for this run (BENCH_DISABLE_AQE=true)")
+
+ // Use user-supplied shared path in cluster mode, otherwise a local temp dir.
+ val (ownedTmpDir, dataRoot) = dataDirOpt match {
+ case Some(p) => (None, p)
+ case None =>
+ val tmp = Files.createTempDirectory("knn-bench-")
+ (Some(tmp.toFile), tmp.toString)
+ }
+
+ val results = scala.collection.mutable.ArrayBuffer.empty[Result]
+ try {
+ scales.foreach { scale =>
+ println(banner(s"Scale: $scale"))
+ val (leftDf, rightUri) = setupScale(spark, scale, dataRoot)
+ val configs = makeConfigs(leftDf, rightUri, scale.runBaseline)
+
+ // Sanity check: every config — INCLUDING the Spark crossJoin baseline — returns the
+ // same top-K row IDs as the in-memory brute-force oracle on a 16-row left subset.
+ // This is what makes the timing comparison meaningful: an 18×/608× number is hollow if
+ // the paths disagree on output. Run at every scale on a small subset so the baseline's
+ // O(|L|×|R|) crossJoin only does 16 × |R| work — sub-second even at medium scale.
+ verifyAllConfigsAgainstOracle(spark, leftDf, rightUri)
+
+ configs.foreach { case (name, run) =>
+ val r = timeIt(scale.name, name, run)
+ results += r
+ println(formatResult(r))
+ }
+ println()
+ }
+
+ println(banner("Summary"))
+ printSummaryTable(results.toSeq)
+ } finally {
+ spark.stop()
+ // Only clean up a locally-created temp dir; leave user-supplied paths untouched.
+ ownedTmpDir.foreach(deleteRecursively)
+ }
+ }
+
+ // -- workload setup ----------------------------------------------------------------------
+
+ private def setupScale(
+ spark: SparkSession,
+ scale: Scale,
+ tmpRoot: String): (DataFrame, String) = {
+ val rng = new Random(Seed)
+ println(s" Generating ${scale.numLeft} left rows × dim $Dim ...")
+ val leftDf = buildLeft(spark, rng, scale.numLeft).cache()
+ leftDf.count()
+
+ val rightUri = Paths.get(tmpRoot, s"right_${scale.name}").toString
+ val t0 = System.nanoTime()
+ println(s" Writing ${scale.numRight} right rows × dim $Dim across ${scale.numFragments} " +
+ s"Spark partitions to $rightUri ...")
+ writeRight(spark, rng, scale.numRight, scale.numFragments, rightUri)
+ println(s" ... done in ${TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - t0)}s")
+ (leftDf, rightUri)
+ }
+
+ private def buildLeft(spark: SparkSession, rng: Random, n: Int): DataFrame = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i), randomVector(rng, Dim))
+ }
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ private def writeRight(
+ spark: SparkSession,
+ rng: Random,
+ n: Int,
+ fragments: Int,
+ uri: String): Unit = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ // Build rows on the driver in a streaming-ish pattern to keep memory bounded for medium scale.
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000000), randomVector(rng, Dim))
+ }
+ val df = spark.createDataFrame(rows.asJava, schema).repartition(fragments)
+ df.write.format("lance").save(uri)
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ // -- configurations ----------------------------------------------------------------------
+
+ private type Runnable = () => DataFrame
+
+ private def makeConfigs(
+ left: DataFrame,
+ rightUri: String,
+ runBaseline: Boolean): Seq[(String, Runnable)] = {
+ val spark = left.sparkSession
+ // Lance-backed right DataFrame for the new `df.kNearestJoin` extension. The extension
+ // pulls the URI back out of the right DataFrame's analyzed plan internally — same probe
+ // pipeline, just a more idiomatic call site that mirrors `df.join(other, ...)`.
+ val rightDf = spark.read.format("lance").load(rightUri)
+
+ val baseline: Runnable = () => crossProductTopK(spark, left, rightUri, K)
+ val baselineMinByK: Runnable = () => crossProductMinByK(spark, left, rightUri, K)
+ val phase01: Runnable = () =>
+ left.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 1)
+ val phase15_4: Runnable = () =>
+ left.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 4)
+ val phase15_8: Runnable = () =>
+ left.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 8)
+ val phase15_8_skew: Runnable = () =>
+ left.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 8,
+ balanceFragments = true)
+
+ val baseSeq = Seq(
+ "B: Phase 0/1 (probeParallelism=1)" -> phase01,
+ "C: Phase 1.5 (probeParallelism=4)" -> phase15_4,
+ "D: Phase 1.5 (probeParallelism=8)" -> phase15_8,
+ "E: Phase 1.5 (G=8, skew-balanced)" -> phase15_8_skew)
+ // BENCHMARK_INCLUDE_BASELINE_A=true to include the row_number-window baseline. Default
+ // off because at medium scale that plan shuffles |L|×|R| rows through a per-lid
+ // window — structurally bad for top-K (no partial aggregation, single-task per
+ // shuffle partition handles tens of GB of sorted data, runs hours). A2 (the heap-K
+ // shape) gets per-task partial aggregation and is the realistic baseline at scale.
+ val includeRowNumberBaseline =
+ sys.env.get("BENCHMARK_INCLUDE_BASELINE_A").exists(_.equalsIgnoreCase("true"))
+ if (runBaseline) {
+ val a2Seq = Seq("A2: crossJoin + L2 UDF + groupBy/sort_array(K)" -> baselineMinByK)
+ val aSeq =
+ if (includeRowNumberBaseline) Seq("A: crossJoin + L2 UDF + row_number window" -> baseline)
+ else Seq.empty
+ aSeq ++ a2Seq ++ baseSeq
+ } else {
+ baseSeq
+ }
+ }
+
+ /**
+ * Naive vanilla-Spark baseline (config A): cross product + L2 UDF + `row_number` window per
+ * `lid`. The textbook way a user might first express nearest-by-join on Spark 3.5 (which
+ * doesn't have `vector_l2_distance` — that's a 4.2 addition). Strictly slower than the
+ * `min_by_k` heap-K shape that Spark 4.2's `RewriteNearestByJoin` actually produces; kept
+ * here as the headline-naive comparison and as a stable apples-to-apples reference vs.
+ * earlier benchmark runs.
+ *
+ * `repartitionRightForBaseline` expands the right-side partitioning beyond the natural
+ * Lance-fragment count so the cross-join compute stage gets enough tasks to use all
+ * cluster cores (otherwise stages fuse on Lance's 8-fragment partitioning and only 8
+ * tasks run at a time).
+ */
+ private def crossProductTopK(
+ spark: SparkSession,
+ left: DataFrame,
+ rightUri: String,
+ k: Int): DataFrame = {
+ val l2 = l2UdfFactory()
+ val right =
+ repartitionRightForBaseline(spark.read.format("lance").load(rightUri).select("rid", "rvec"))
+ val crossed = left.crossJoin(right).withColumn("__dist", l2(col("lvec"), col("rvec")))
+ val w = Window.partitionBy("lid").orderBy(col("__dist"))
+ crossed.withColumn("__rank", row_number().over(w)).filter(col("__rank") <= k).select(
+ "lid",
+ "rid",
+ "__dist")
+ }
+
+ /**
+ * Expand right-side partitioning for the cross-product baselines. Lance reads produce
+ * one Spark partition per fragment (default 8 here); when the cross-join + UDF stage
+ * gets fused into a single stage it inherits that partitioning, capping wall-clock
+ * parallelism at fragment-count regardless of cluster cores. Repartition target is
+ * env-driven so the same code can run on different cluster widths; default 64 matches
+ * the 8 × 8c cluster shape (= cores).
+ */
+ private def repartitionRightForBaseline(df: DataFrame): DataFrame = {
+ val target = sys.env.get("BENCH_BASELINE_RIGHT_PARTITIONS").map(_.toInt).getOrElse(64)
+ if (target > 0) df.repartition(target) else df
+ }
+
+ /**
+ * Closer-to-RewriteNearestByJoin baseline (config A2): cross product + L2 UDF + groupBy
+ * + `sort_array(collect_list(struct(rid, dist)))` + `slice(_, 1, K)` + `inline()`. Spark
+ * 4.2's `RewriteNearestByJoin` lowers `NearestByJoin` to roughly:
+ *
+ * Project(j.output)
+ * +- Generate(Inline(_matches))
+ * +- Aggregate [__qid], first(left.*) ++ min_by(struct(right.*), expr, K)
+ * +- LEFT OUTER Join (no condition) — cross product
+ *
+ * `min_by(struct, expr, K)` (`MaxMinByK`, SPARK-55322) is Spark 4.2-only; it does
+ * `O(|R| log K)` per group via a bounded heap. On Spark 3.5 the closest expressible
+ * shape is `slice(sort_array(collect_list(struct(dist, rid)), asc=true), 1, K)`, which
+ * is `O(|R| log |R|)` per group — strictly slower than the 4.2-native lowering. Quoted
+ * here so the speedup-vs-baseline number reflects what's actually possible to express
+ * in Spark 3.5 SQL today, not the naive row_number form.
+ *
+ * We sort `struct(__dist, rid)` (distance first) so `sort_array` orders by distance
+ * ascending; the K smallest distances bubble to the front; `inline()` expands the
+ * K-element array into K rows.
+ */
+ private def crossProductMinByK(
+ spark: SparkSession,
+ left: DataFrame,
+ rightUri: String,
+ k: Int): DataFrame = {
+ val l2 = l2UdfFactory()
+ val right =
+ repartitionRightForBaseline(spark.read.format("lance").load(rightUri).select("rid", "rvec"))
+ val crossed = left.crossJoin(right).withColumn("__dist", l2(col("lvec"), col("rvec")))
+ crossed.groupBy("lid")
+ .agg(
+ slice(
+ sort_array(collect_list(struct(col("__dist"), col("rid"))), asc = true),
+ 1,
+ k).as("__matches"))
+ .select(col("lid"), inline(col("__matches")).as(Seq("__dist", "rid")))
+ .select("lid", "rid", "__dist")
+ }
+
+ private def l2UdfFactory(): org.apache.spark.sql.expressions.UserDefinedFunction =
+ udf((a: Seq[Float], b: Seq[Float]) => {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ })
+
+ // -- timing harness ----------------------------------------------------------------------
+
+ private val WarmupRuns = 1
+ private val MeasurementRuns = 3
+
+ /**
+ * Execute the plan fully and discard output — Spark's canonical benchmark sink, same
+ * shape as `WikipediaKnnPerfBenchmark.runFull`. Avoids the `count()`-bias where the
+ * crossJoin baseline skips result-row assembly while the indexed path runs
+ * `LanceMaterialize` in full (the `references = child.outputSet` override on the
+ * materialize logical plan blocks ColumnPruning from removing the join-row columns).
+ */
+ private def runFull(df: DataFrame): Unit =
+ df.write.format("noop").mode("overwrite").save()
+
+ /** Run `f` once for warmup, then 3x for measurement. Median wall-clock in ms. */
+ private def timeIt(scale: String, config: String, f: Runnable): Result = {
+ print(s" $config ... ")
+ System.out.flush()
+ var i = 0
+ while (i < WarmupRuns) { runFull(f()); i += 1 }
+ val runs = (0 until MeasurementRuns).map { _ =>
+ val t0 = System.nanoTime()
+ runFull(f())
+ TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0)
+ }
+ val sortedRuns = runs.sorted
+ val median = sortedRuns(sortedRuns.length / 2)
+ println(s"runs=${runs.mkString("[", ",", "]")} ms, median=$median ms")
+ Result(scale, config, median, runs)
+ }
+
+ // -- oracle equivalence -----------------------------------------------------------------
+
+ /**
+ * Sanity check: confirm Phase 0/1 and Phase 1.5 paths agree with a brute-force oracle on a
+ * 16-row subset of the left side. If this disagrees with the indexed path the benchmark
+ * numbers are meaningless, so we bail before spending minutes on incorrect timings.
+ */
+ private def verifyOracleEquivalence(
+ spark: SparkSession,
+ leftDf: DataFrame,
+ rightUri: String): Unit = {
+ println(" Sanity check: indexed-path top-K matches brute-force oracle on a 16-row subset ...")
+ val leftSubset = leftDf.limit(16).cache()
+ leftSubset.count()
+ val rightDf = spark.read.format("lance").load(rightUri)
+ val joined = leftSubset.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid", "rvec")),
+ probeParallelism = 4)
+
+ val byLid = joined.collect().groupBy(_.getAs[Int]("lid"))
+ val rightVecs = readRightVectors(spark, rightUri)
+ val rightIds = readRightIds(spark, rightUri)
+ val leftRows = leftSubset.collect()
+
+ leftRows.foreach { lr =>
+ val lid = lr.getAs[Int]("lid")
+ val leftVec = lr.getAs[Seq[Float]]("lvec").toArray
+ val oracleIds = rightVecs.indices
+ .map(i => (rightIds(i), l2(leftVec, rightVecs(i))))
+ .sortBy(_._2)
+ .take(K)
+ .map(_._1)
+ .toSet
+ val actualIds = byLid(lid).map(_.getAs[Int]("rid")).toSet
+ if (oracleIds != actualIds) {
+ sys.error(
+ s"ORACLE MISMATCH at lid=$lid:\n oracle: $oracleIds\n actual: $actualIds")
+ }
+ }
+ leftSubset.unpersist()
+ println(" ... oracle equivalence holds.")
+ }
+
+ /**
+ * Run EVERY config (including the Spark crossJoin baseline) on a 16-row left subset and
+ * compare each result against an in-memory brute-force oracle. Running on a subset keeps
+ * the slow baseline tractable (16 × |R| pair evaluations is sub-second even at medium scale)
+ * while still validating that all paths produce the SAME top-K as the ground truth.
+ *
+ * Compared as Sets to tolerate tied-distance ordering. Random data makes exact ties rare in
+ * practice, but the comparison is robust either way.
+ */
+ private def verifyAllConfigsAgainstOracle(
+ spark: SparkSession,
+ leftDf: DataFrame,
+ rightUri: String): Unit = {
+ println(" Sanity check: indexed-path configs match brute-force oracle on a 16-row subset ...")
+ val left16 = leftDf.limit(16).cache()
+ left16.count()
+ val leftIds = left16.select("lid").collect().map(_.getInt(0)).toSet
+
+ // Brute-force oracle in plain Scala — the ground truth.
+ val rightVecs = readRightVectors(spark, rightUri)
+ val rightIds = readRightIds(spark, rightUri)
+ val leftRows = left16.collect()
+ val oracleByLid: Map[Int, Set[Int]] = leftRows.map { r =>
+ val lid = r.getAs[Int]("lid")
+ val lvec = r.getAs[Seq[Float]]("lvec").toArray
+ val topKRids = rightVecs.indices
+ .map(i => (rightIds(i), l2(lvec, rightVecs(i))))
+ .sortBy(_._2)
+ .take(K)
+ .map(_._1)
+ .toSet
+ lid -> topKRids
+ }.toMap
+
+ // Build mini-configs that close over `left16`, not the full left.
+ // Validate B/C/D/E against the in-memory brute-force oracle. We deliberately do NOT run
+ // config A (Spark crossJoin) here — its output IS brute force by construction, so
+ // comparing Spark's crossJoin to the in-memory brute force is a tautology, while the
+ // window-function pipeline can be slow enough on 16 × |R| pairs to dominate the
+ // benchmark's wall-clock. The semantic question "is the indexed path correct" is fully
+ // answered by checking B/C/D/E against the oracle directly.
+ val miniConfigs = makeConfigs(left16, rightUri, runBaseline = false)
+
+ miniConfigs.foreach { case (name, run) =>
+ val rows = run().collect()
+ val byLid = rows.groupBy(_.getAs[Int]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Int]("rid")).toSet }
+ leftIds.foreach { lid =>
+ val expected = oracleByLid(lid)
+ val actual = byLid.getOrElse(lid, Set.empty[Int])
+ if (expected != actual) {
+ sys.error(
+ s"ORACLE MISMATCH for $name at lid=$lid:\n oracle: $expected\n actual: $actual")
+ }
+ }
+ }
+ left16.unpersist()
+ println(
+ s" ... all ${miniConfigs.size} indexed configs match the oracle " +
+ s"(sample size: ${leftIds.size}).")
+ }
+
+ private def readRightVectors(spark: SparkSession, uri: String): Array[Array[Float]] =
+ spark.read.format("lance").load(uri).orderBy("rid").collect().map { r =>
+ r.getAs[Seq[Float]]("rvec").toArray
+ }
+
+ private def readRightIds(spark: SparkSession, uri: String): Array[Int] =
+ spark.read.format("lance").load(uri).orderBy("rid").collect().map(_.getAs[Int]("rid"))
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ }
+
+ // -- output formatting ------------------------------------------------------------------
+
+ private def banner(s: String): String = s"\n=== $s " + ("=" * (76 - s.length - 5))
+
+ private def formatResult(r: Result): String =
+ f" -> ${r.config}%-40s median=${r.medianMs}%6d ms"
+
+ private def printSummaryTable(results: Seq[Result]): Unit = {
+ val byScale = results.groupBy(_.scale).map { case (k, vs) => k -> vs.sortBy(_.config) }
+ val scaleOrder = Seq("small", "medium").filter(byScale.contains)
+ println()
+ val configWidth = 38
+ val numWidth = 13
+ val header = "%-".concat(s"${configWidth}s") +
+ scaleOrder.map(_ => s"%${numWidth}s").mkString
+ val divider = "-" * (configWidth + scaleOrder.size * numWidth)
+ println(divider)
+ val args1 = "Configuration" +: scaleOrder.map(s => s"$s (ms)")
+ println(header.format(args1: _*))
+ val args2 = "" +: scaleOrder.map(s => s"speedup ×")
+ println(header.format(args2: _*))
+ println(divider)
+
+ val configs = scaleOrder.flatMap(byScale(_).map(_.config)).distinct
+ // Use config "A:" (row_number window) as the headline reference, since it's the value
+ // historical numbers were quoted against. A2 (sort_array(K)) is reported as a separate
+ // row whose speedup-vs-A also shows in the table — that's how much an `min_by_k`-style
+ // rewrite would shrink the gap on Spark 3.5 SQL.
+ val baselineByScale = scaleOrder.flatMap { s =>
+ byScale(s).find(_.config.startsWith("A:")).map(b => s -> b.medianMs)
+ }.toMap
+ configs.foreach { config =>
+ val cellsMs = scaleOrder.map { s =>
+ byScale(s).find(_.config == config).map(_.medianMs.toString).getOrElse("-")
+ }
+ println(header.format((config +: cellsMs): _*))
+ val cellsSpeedup = scaleOrder.map { s =>
+ val mineMs = byScale(s).find(_.config == config).map(_.medianMs).getOrElse(0L)
+ val baseMs = baselineByScale.getOrElse(s, 0L)
+ if (config.startsWith("A:")) "1.00x"
+ else if (mineMs <= 0 || baseMs <= 0) "(no base)"
+ else f"${baseMs.toDouble / mineMs}%.2fx"
+ }
+ println(header.format(("" +: cellsSpeedup): _*))
+ }
+ println(divider)
+ println(
+ "Speedup is `baseline(A) / config`. Higher = faster than the row_number-window baseline.")
+ println("A2 is the closer-to-RewriteNearestByJoin shape (sort_array(K) instead of full sort);")
+ println("compare A vs A2 to see how much the heap-K rewrite alone narrows the gap.")
+ }
+
+ private def deleteRecursively(f: java.io.File): Unit = {
+ if (f.isDirectory) f.listFiles().foreach(deleteRecursively)
+ f.delete()
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinSoakTest.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinSoakTest.scala
new file mode 100644
index 000000000..6e7dd2ecb
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/IndexedNearestJoinSoakTest.scala
@@ -0,0 +1,433 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{DataFrame, Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.LanceKnnImplicits._
+
+import java.lang.management.ManagementFactory
+import java.util.Random
+import java.util.concurrent.{ConcurrentLinkedQueue, Executors, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.JavaConverters._
+
+/**
+ * Concurrent sustained-load soak test for `IndexedNearestJoin` / `df.kNearestJoin`.
+ *
+ * Production-readiness validation #2 ("sustained concurrent load") from the must-validate
+ * list: run N concurrent queries/second for M minutes and watch for:
+ *
+ * - Memory growth (JVM heap + off-heap direct buffers + Arrow allocator)
+ * - File handle leaks (executor-side Lance dataset handles)
+ * - GC pressure trends (Old-gen growth over time ⇒ a leak)
+ * - Per-query latency drift (early queries vs late queries ⇒ resource contention)
+ * - Failure rate (any query throwing indicates correctness regression under load)
+ *
+ * The unit benchmark ([[IndexedNearestJoinBenchmark]]) measures single-query throughput;
+ * this benchmark measures behavior under many simultaneous queries.
+ *
+ * == Cluster run ==
+ *
+ * Build the benchmark fat JAR first (same profile as the throughput benchmark):
+ *
+ * {{{
+ * ./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+ * # → target/lance-spark-knn_2.12--benchmark.jar
+ * }}}
+ *
+ * Submit via your cluster's job API, passing these environment variables:
+ *
+ * - `BENCH_CLUSTER_MODE=true` — skips `.master()` and driver bind-address config
+ * - `BENCH_DATA_PATH=` — shared path for the synthetic right-side Lance dataset
+ * (s3://, abfs://, gs://, hdfs://)
+ * - `SOAK_RIGHT_ROWS=100000000` — size of the right-side Lance dataset (default: 10M)
+ * - `SOAK_LEFT_ROWS=1000` — left rows per query (default: 100)
+ * - `SOAK_DIM=128` — vector dimension (default: 128)
+ * - `SOAK_K=10` — top-K (default: 10)
+ * - `SOAK_DURATION_MIN=60` — total soak wall-clock minutes (default: 5 for smoke)
+ * - `SOAK_CONCURRENCY=8` — concurrent queries in flight (default: 4)
+ * - `SOAK_QPS_TARGET=2` — target queries per second per driver thread (default: 2)
+ * - `SOAK_PROBE_PARALLELISM=1` — `probeParallelism` per query (default: 1)
+ * - `SOAK_SEED=1337` — RNG seed for reproducibility (default: 1337)
+ * - `SOAK_SETUP_ONLY=false` — if "true", writes the right-side dataset and exits
+ * (so you can pre-warm once and run the soak many times)
+ * - `SOAK_SKIP_SETUP=false` — if "true", assumes the dataset at `BENCH_DATA_PATH`
+ * already exists (for re-running the soak without rewriting)
+ *
+ * Report at end: p50/p95/p99/max latency, queries/sec, failure count, heap & off-heap
+ * snapshots at start / midpoint / end, GC counts & times. Streams per-query timings to
+ * stdout every 60 s so you can `tail -f` driver logs during a long run.
+ *
+ * == What this does NOT measure ==
+ *
+ * - Executor-side resource trends — those need cluster-level monitoring (Grafana /
+ * Spark UI / JMX). The driver-side snapshots here are an upper-bound check; a
+ * worker-side leak won't show up here.
+ * - Correctness under load — each query uses random left vectors; we don't validate
+ * against an oracle per query (too expensive). Any thrown exception IS counted and
+ * the stack trace is logged.
+ * - Very long-running behavior (days). Default is 5 minutes for smoke; set
+ * `SOAK_DURATION_MIN` higher for real soak (4-24h recommended for production
+ * qualification).
+ */
+object IndexedNearestJoinSoakTest {
+
+ private val Dim: Int = sys.env.get("SOAK_DIM").map(_.toInt).getOrElse(128)
+ private val K: Int = sys.env.get("SOAK_K").map(_.toInt).getOrElse(10)
+ private val RightRows: Long =
+ sys.env.get("SOAK_RIGHT_ROWS").map(_.toLong).getOrElse(10000000L)
+ private val LeftRows: Int = sys.env.get("SOAK_LEFT_ROWS").map(_.toInt).getOrElse(100)
+ private val DurationMin: Int =
+ sys.env.get("SOAK_DURATION_MIN").map(_.toInt).getOrElse(5)
+ private val Concurrency: Int =
+ sys.env.get("SOAK_CONCURRENCY").map(_.toInt).getOrElse(4)
+ private val QpsTarget: Double =
+ sys.env.get("SOAK_QPS_TARGET").map(_.toDouble).getOrElse(2.0)
+ private val ProbeParallelism: Int =
+ sys.env.get("SOAK_PROBE_PARALLELISM").map(_.toInt).getOrElse(1)
+ private val Seed: Long = sys.env.get("SOAK_SEED").map(_.toLong).getOrElse(1337L)
+ private val SetupOnly: Boolean =
+ sys.env.get("SOAK_SETUP_ONLY").exists(_.equalsIgnoreCase("true"))
+ private val SkipSetup: Boolean =
+ sys.env.get("SOAK_SKIP_SETUP").exists(_.equalsIgnoreCase("true"))
+
+ private val ClusterMode: Boolean =
+ sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true"))
+ private val DataPath: String = sys.env.getOrElse(
+ "BENCH_DATA_PATH",
+ "/tmp/lance-knn-soak")
+
+ // Stats collected across all queries.
+ private val completed = new AtomicInteger(0)
+ private val failed = new AtomicInteger(0)
+ private val latenciesNs = new ConcurrentLinkedQueue[Long]()
+
+ def main(args: Array[String]): Unit = {
+ val spark = buildSparkSession()
+ try {
+ logBanner(spark)
+
+ val rightUri = s"$DataPath/right_soak_${RightRows}_${Dim}"
+ if (SkipSetup) {
+ println(s"[soak] SOAK_SKIP_SETUP=true → using existing dataset at $rightUri")
+ } else {
+ setupRightDataset(spark, rightUri)
+ }
+
+ if (SetupOnly) {
+ println("[soak] SOAK_SETUP_ONLY=true → dataset written, exiting before load phase.")
+ return
+ }
+
+ runSoak(spark, rightUri)
+ printFinalReport(spark)
+ } finally {
+ spark.stop()
+ }
+ }
+
+ // -- setup ----------------------------------------------------------------------------------
+
+ private def buildSparkSession(): SparkSession = {
+ val b = SparkSession.builder().appName("IndexedNearestJoin-SoakTest")
+ if (!ClusterMode) {
+ b.master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ }
+ // Schedule FIFO→FAIR so concurrent queries actually run concurrently on the same
+ // SparkContext. Without FAIR, queries serialise behind each other regardless of
+ // SOAK_CONCURRENCY.
+ b.config("spark.scheduler.mode", "FAIR")
+ .getOrCreate()
+ }
+
+ private def logBanner(spark: SparkSession): Unit = {
+ println("=" * 80)
+ println(s"IndexedNearestJoinSoakTest")
+ println("=" * 80)
+ println(f" Spark version: ${spark.version}")
+ println(f" master: ${spark.sparkContext.master}")
+ println(f" applicationId: ${spark.sparkContext.applicationId}")
+ println(f" default parallelism: ${spark.sparkContext.defaultParallelism}")
+ println(f" cluster mode: $ClusterMode")
+ println(f" data path: $DataPath")
+ println(" -- load knobs --")
+ println(f" right rows: $RightRows%,d")
+ println(f" left rows/query: $LeftRows%,d")
+ println(f" dim: $Dim")
+ println(f" K: $K")
+ println(f" probeParallelism: $ProbeParallelism")
+ println(f" concurrency: $Concurrency")
+ println(f" qps target: $QpsTarget%.2f queries/sec")
+ println(f" duration: $DurationMin minutes")
+ println(f" seed: $Seed")
+ println("=" * 80)
+ println()
+ }
+
+ private def setupRightDataset(spark: SparkSession, uri: String): Unit = {
+ println(s"[soak] writing right-side Lance dataset: $RightRows rows × dim=$Dim → $uri")
+ val t0 = System.nanoTime()
+
+ // Build the right DF via `spark.range` + a UDF-ish map so the data generation is
+ // distributed. For very large RightRows this is the critical scalability path.
+ val schema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ val capturedDim = Dim
+ val capturedSeed = Seed
+ val rdd = spark.sparkContext
+ .range(0L, RightRows, 1L, math.max(spark.sparkContext.defaultParallelism * 4, 16))
+ .mapPartitionsWithIndex { case (partIdx, iter) =>
+ val rng = new Random(capturedSeed + partIdx.toLong)
+ iter.map { i =>
+ val v = new Array[Float](capturedDim)
+ var j = 0
+ while (j < capturedDim) { v(j) = rng.nextFloat(); j += 1 }
+ // Pass the Array[Float] directly - Spark's ArrayType encoder on 2.12 expects
+ // either an Array or a scala.collection.Seq, NOT a Java List (which is what
+ // `.toSeq.asJava` produces and causes `Wrappers$SeqWrapper incompatible with
+ // scala.collection.Seq` at encode time).
+ RowFactory.create(java.lang.Long.valueOf(i), v): Row
+ }
+ }
+ val df = spark.createDataFrame(rdd, schema)
+ // Lance's catalog treats mode("overwrite") as "drop then create", which throws
+ // NoSuchTableException when the path is new. Default (ErrorIfExists) is correct for
+ // first-run setup; reuse across runs should set SOAK_SKIP_SETUP=true instead.
+ df.write.format("lance").save(uri)
+
+ val elapsedSec = (System.nanoTime() - t0) / 1e9
+ println(f"[soak] right dataset written in $elapsedSec%.1f s")
+ println()
+ }
+
+ // -- soak loop ------------------------------------------------------------------------------
+
+ private def runSoak(spark: SparkSession, rightUri: String): Unit = {
+ val right = spark.read.format("lance").load(rightUri)
+
+ // Per-query left DataFrames are generated in the driver; cheap at LeftRows ≤ few-K.
+ // Cache the schema and random generation once.
+ val leftSchema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ val deadline = System.currentTimeMillis() + DurationMin.toLong * 60L * 1000L
+ val pool = Executors.newFixedThreadPool(Concurrency)
+ val ticker = Executors.newSingleThreadScheduledExecutor()
+ val halfway = System.currentTimeMillis() + (DurationMin.toLong * 30L * 1000L)
+
+ val snapStart = snapshotProcess(spark)
+ var snapMid: ProcessSnapshot = null
+
+ // Per-60-s progress printer.
+ ticker.scheduleAtFixedRate(
+ new Runnable {
+ override def run(): Unit = {
+ val done = completed.get()
+ val bad = failed.get()
+ val snap = snapshotProcess(spark)
+ println(f"[soak] t=${elapsedSec()}%6.1fs completed=$done%,d failed=$bad%,d " +
+ f"heap=${snap.heapUsedMb}%,d MB directMem=${snap.directUsedMb}%,d MB " +
+ f"gcCount=${snap.gcCount} gcTimeMs=${snap.gcTimeMs}")
+ }
+ },
+ 60L,
+ 60L,
+ TimeUnit.SECONDS)
+
+ val intervalMs = math.max(1L, (1000.0 / QpsTarget).toLong)
+ val startTime = System.currentTimeMillis()
+ var queriesSubmitted = 0L
+
+ try {
+ while (System.currentTimeMillis() < deadline) {
+ val qid = queriesSubmitted
+ queriesSubmitted += 1
+ val task: Runnable = new Runnable {
+ override def run(): Unit = runOneQuery(spark, right, leftSchema, qid)
+ }
+ pool.submit(task)
+
+ if (snapMid == null && System.currentTimeMillis() >= halfway) {
+ snapMid = snapshotProcess(spark)
+ println(f"[soak] midpoint snapshot: heap=${snapMid.heapUsedMb}%,d MB " +
+ f"directMem=${snapMid.directUsedMb}%,d MB gcCount=${snapMid.gcCount}")
+ }
+
+ // Throttle submission rate. Without this the pool fills with backlog and the
+ // `SOAK_QPS_TARGET` knob is meaningless.
+ val targetSubmitTime = startTime + queriesSubmitted * intervalMs
+ val sleepMs = targetSubmitTime - System.currentTimeMillis()
+ if (sleepMs > 0) Thread.sleep(sleepMs)
+ }
+ } finally {
+ ticker.shutdown()
+ pool.shutdown()
+ pool.awaitTermination(2, TimeUnit.MINUTES)
+ ticker.awaitTermination(5, TimeUnit.SECONDS)
+ }
+ }
+
+ private def runOneQuery(
+ spark: SparkSession,
+ right: DataFrame,
+ leftSchema: StructType,
+ qid: Long): Unit = {
+ val t0 = System.nanoTime()
+ try {
+ val rng = new Random(Seed + qid)
+ val rows = new java.util.ArrayList[Row](LeftRows)
+ var i = 0
+ while (i < LeftRows) {
+ val v = new Array[Float](Dim)
+ var j = 0
+ while (j < Dim) { v(j) = rng.nextFloat(); j += 1 }
+ // Same Array[Float]-not-Java-List rule as setupRightDataset.
+ rows.add(RowFactory.create(java.lang.Long.valueOf(i.toLong), v))
+ i += 1
+ }
+ val left = spark.createDataFrame(rows, leftSchema)
+
+ val joined = left.kNearestJoin(
+ right = right,
+ leftVecCol = "qvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ probeParallelism = ProbeParallelism)
+
+ // Consume via count() — exercises the ColumnPruning-sensitive path AND avoids
+ // pulling all rows to the driver. Both are important under load.
+ val n = joined.count()
+ if (n != LeftRows.toLong * K.toLong) {
+ throw new IllegalStateException(
+ s"query $qid returned $n rows; expected ${LeftRows * K}")
+ }
+ latenciesNs.add(System.nanoTime() - t0)
+ completed.incrementAndGet()
+ } catch {
+ case t: Throwable =>
+ failed.incrementAndGet()
+ println(s"[soak] query $qid FAILED: ${t.getClass.getSimpleName}: ${t.getMessage}")
+ t.printStackTrace()
+ }
+ }
+
+ // -- reporting ------------------------------------------------------------------------------
+
+ private def printFinalReport(spark: SparkSession): Unit = {
+ val snapEnd = snapshotProcess(spark)
+ val done = completed.get()
+ val bad = failed.get()
+ val lats = latenciesNs.asScala.toVector.sorted
+
+ println()
+ println("=" * 80)
+ println("SOAK TEST FINAL REPORT")
+ println("=" * 80)
+ println(f" completed queries: $done%,d")
+ println(f" failed queries: $bad%,d")
+ if (done > 0) {
+ println(f" failure rate: ${bad.toDouble / (done + bad) * 100.0}%.3f%%")
+ }
+ println()
+
+ if (lats.nonEmpty) {
+ val p50 = lats(lats.length / 2) / 1e6
+ val p95 = lats((lats.length * 95) / 100) / 1e6
+ val p99 = lats((lats.length * 99) / 100) / 1e6
+ val max = lats.last / 1e6
+ val mean = lats.sum.toDouble / lats.length / 1e6
+ println(" LATENCY (ms)")
+ println(f" p50: $p50%,10.2f")
+ println(f" p95: $p95%,10.2f")
+ println(f" p99: $p99%,10.2f")
+ println(f" max: $max%,10.2f")
+ println(f" mean: $mean%,10.2f")
+ println()
+
+ // Early-vs-late drift. Splitting the sorted latencies doesn't work for this — we
+ // want time-order. Instead split by query-id order. Requires the latencies in
+ // submission order, which we don't have (we pushed in completion order). Skip for
+ // now; a proper drift metric needs per-query (qid, nanos) pairs. Left as follow-up.
+ }
+ println()
+ println(f" DRIVER RESOURCE TOTALS (end of run)")
+ println(f" heap used: ${snapEnd.heapUsedMb}%,d MB / ${snapEnd.heapMaxMb}%,d MB")
+ println(f" direct memory: ${snapEnd.directUsedMb}%,d MB")
+ println(f" GC count: ${snapEnd.gcCount}")
+ println(f" GC time: ${snapEnd.gcTimeMs}%,d ms")
+ println()
+
+ // Exit code hint for CI: non-zero if any query failed.
+ if (bad > 0) {
+ System.err.println(s"[soak] FAILURE: $bad queries failed; see stderr above")
+ System.exit(2)
+ }
+ }
+
+ // -- process snapshots ----------------------------------------------------------------------
+
+ private case class ProcessSnapshot(
+ heapUsedMb: Long,
+ heapMaxMb: Long,
+ directUsedMb: Long,
+ gcCount: Long,
+ gcTimeMs: Long)
+
+ /**
+ * Snapshot of driver-side memory + GC. Off-heap `direct` bytes are tracked via the
+ * JDK's `BufferPoolMXBean` for the "direct" pool — this catches `ByteBuffer.allocateDirect`
+ * but NOT native allocations (Arrow allocator, JNI). For native accounting, cluster-level
+ * JMX / /proc-scrape is required; this is a best-effort driver snapshot.
+ */
+ private def snapshotProcess(spark: SparkSession): ProcessSnapshot = {
+ val memoryMx = ManagementFactory.getMemoryMXBean
+ val heap = memoryMx.getHeapMemoryUsage
+ val heapUsedMb = heap.getUsed / (1024L * 1024L)
+ val heapMaxMb = (if (heap.getMax > 0) heap.getMax else heap.getCommitted) / (1024L * 1024L)
+
+ val directUsedMb: Long = ManagementFactory
+ .getPlatformMXBeans(classOf[java.lang.management.BufferPoolMXBean])
+ .asScala
+ .find(_.getName == "direct")
+ .map(_.getMemoryUsed / (1024L * 1024L))
+ .getOrElse(0L)
+
+ val gcBeans = ManagementFactory.getGarbageCollectorMXBeans.asScala
+ val gcCount = gcBeans.map(_.getCollectionCount).sum
+ val gcTimeMs = gcBeans.map(_.getCollectionTime).sum
+
+ ProcessSnapshot(heapUsedMb, heapMaxMb, directUsedMb, gcCount, gcTimeMs)
+ }
+
+ private def elapsedSec(): Double =
+ (System.currentTimeMillis() - appStartMillis) / 1000.0
+
+ private lazy val appStartMillis: Long = System.currentTimeMillis()
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/SiftRecallBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/SiftRecallBenchmark.scala
new file mode 100644
index 000000000..785561dff
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/SiftRecallBenchmark.scala
@@ -0,0 +1,478 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.lance.{Dataset, ReadOptions}
+import org.lance.index.IndexParams
+import org.lance.index.vector.{IvfBuildParams, PQBuildParams, VectorIndexParams}
+import org.lance.spark.LanceRuntime
+import org.lance.spark.knn.IndexedNearestJoin
+import org.lance.spark.knn.internal.Metric
+
+import java.io.{BufferedInputStream, DataInputStream, File, FileInputStream, IOException}
+import java.nio.{ByteBuffer, ByteOrder}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * SIFT1M / SIFT10K recall benchmark.
+ *
+ * Validates the IVF-PQ and IVF-FLAT indexed paths against the canonical SIFT1M ground
+ * truth from http://corpus-texmex.irisa.fr. For each query the dataset ships, we compute
+ * the top-K nearest in the base set using Lance's indexed nearest-join and compare
+ * against the shipped ground truth. Reports recall@K for each configuration.
+ *
+ * This is production-readiness validation #3 ("real-embeddings recall validation"). Where
+ * [[IndexedNearestJoinIvfPqRecallTest]] uses synthetic random vectors and so only proves
+ * the mechanics, this benchmark uses the standard ANN-benchmark corpus — so the recall
+ * numbers are comparable to published IVF-PQ results from the HNSWlib / FAISS papers.
+ *
+ * == Downloading the data ==
+ *
+ * SIFT1M (small -- 168 MB compressed, 500 MB uncompressed):
+ * {{{
+ * curl -L -o /tmp/sift.tar.gz ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz
+ * tar -xzf /tmp/sift.tar.gz -C /tmp/
+ * # produces /tmp/sift/{sift_base.fvecs, sift_query.fvecs, sift_groundtruth.ivecs,
+ * # sift_learn.fvecs}
+ * }}}
+ *
+ * SIFT10K (tiny -- 10 MB, useful for smoke testing):
+ * {{{
+ * curl -L -o /tmp/siftsmall.tar.gz ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz
+ * tar -xzf /tmp/siftsmall.tar.gz -C /tmp/
+ * }}}
+ *
+ * Formats:
+ * - `.fvecs` — [int32 dim, float32 * dim, int32 dim, float32 * dim, ...] (little-endian).
+ * Dim header repeats per vector; we read the first to establish dim and assume all
+ * match.
+ * - `.ivecs` — same but int32 payloads (used for ground-truth top-K indices).
+ *
+ * == Cluster run ==
+ *
+ * {{{
+ * ./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+ * # upload target/lance-spark-knn_2.12--benchmark.jar + the unpacked SIFT dir
+ *
+ * BENCH_CLUSTER_MODE=true \
+ * BENCH_DATA_PATH=s3://bucket/path \
+ * SIFT_DIR=/tmp/sift \
+ * SIFT_K=10 \
+ * SIFT_NUM_PARTITIONS=256 \
+ * SIFT_NUM_SUB_VECTORS=16 \
+ * SIFT_NPROBES_LIST=1,4,16,64 \
+ * SIFT_REFINE_LIST=1,4,8 \
+ * SIFT_NUM_QUERIES=1000 \
+ * spark-submit --class org.lance.spark.knn.benchmark.SiftRecallBenchmark
+ * }}}
+ *
+ * == Env knobs ==
+ *
+ * - `SIFT_DIR=/path/to/sift` -- directory containing the extracted .fvecs/.ivecs
+ * (required). Expected files:
+ * `sift_base.fvecs`, `sift_query.fvecs`,
+ * `sift_groundtruth.ivecs`.
+ * For siftsmall, files are prefixed `siftsmall_`;
+ * set `SIFT_PREFIX=siftsmall` to use them.
+ * - `SIFT_PREFIX=sift` -- file prefix (default: `sift`; set `siftsmall` for
+ * the 10K subset).
+ * - `SIFT_K=10` -- top-K to measure recall@K against ground truth
+ * (default 10). SIFT ships 100 ground-truth neighbors
+ * per query, so K ≤ 100.
+ * - `SIFT_NUM_QUERIES=1000` -- how many queries to run (default: all 10000). Lower
+ * numbers give faster feedback loops.
+ * - `SIFT_NUM_PARTITIONS=256` -- IVF cluster count for IVF-PQ (default: sqrt(N)).
+ * - `SIFT_NUM_SUB_VECTORS=16` -- PQ sub-vector count; must divide 128 (SIFT dim).
+ * Default: 16 (=> 8-byte PQ codes).
+ * - `SIFT_NPROBES_LIST=1,4,16,64` -- comma list of `nprobes` values to test.
+ * - `SIFT_REFINE_LIST=1,4,8` -- comma list of `refineFactor` values to test.
+ * - `SIFT_INDEX=ivfpq` -- `ivfpq` | `ivfflat` | `both` (default: `both`).
+ * - `SIFT_SKIP_WRITE=false` -- if "true", assumes the Lance dataset already
+ * exists at `BENCH_DATA_PATH`/sift.
+ * - `SIFT_SKIP_INDEX=false` -- if "true", assumes an index already exists. Useful
+ * for running a grid sweep against a pre-built index.
+ * - `BENCH_CLUSTER_MODE`, `BENCH_DATA_PATH` -- same as other benchmarks.
+ *
+ * == Expected numbers ==
+ *
+ * On SIFT1M × 128-dim × 10K queries, with IVF-PQ(256 clusters, 16 subvectors, 8 bits) at
+ * K=10, published ANN-benchmark numbers for FAISS IVF-PQ are in this ballpark:
+ *
+ * nprobes=1: ~0.20 recall@10, ~1 ms/query (barely touches the right centroid)
+ * nprobes=4: ~0.55 recall@10
+ * nprobes=16: ~0.85 recall@10, ~5 ms/query
+ * nprobes=64: ~0.97 recall@10, ~15 ms/query
+ * nprobes=256: 1.00 recall@10 (visits every centroid -> exact)
+ *
+ * With `refineFactor=8` the numbers shift up meaningfully: PQ compresses distance to
+ * 8-byte codes so Voronoi selection is accurate but ranking-within-cluster is lossy; the
+ * refine pass re-ranks top-K*8 by exact distance. Good recall targets with refine:
+ *
+ * nprobes=4, refine=8: ~0.85 recall@10
+ * nprobes=16, refine=8: ~0.97 recall@10
+ *
+ * If the numbers this benchmark prints are dramatically lower than published figures,
+ * that's a signal the indexed-probe path or IVF-PQ construction has a bug.
+ */
+object SiftRecallBenchmark {
+
+ // -- env knobs ------------------------------------------------------------------------------
+
+ private val ClusterMode: Boolean =
+ sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true"))
+ private val DataPath: String = sys.env.getOrElse("BENCH_DATA_PATH", "/tmp/lance-sift")
+ private val SiftDir: String =
+ sys.env.getOrElse("SIFT_DIR", sys.error("SIFT_DIR is required (path to unpacked sift/*.fvecs)"))
+ private val SiftPrefix: String = sys.env.getOrElse("SIFT_PREFIX", "sift")
+ private val K: Int = sys.env.get("SIFT_K").map(_.toInt).getOrElse(10)
+ private val NumQueries: Int =
+ sys.env.get("SIFT_NUM_QUERIES").map(_.toInt).getOrElse(10000)
+ private val NumPartitions: Int =
+ sys.env.get("SIFT_NUM_PARTITIONS").map(_.toInt).getOrElse(256)
+ private val NumSubVectors: Int =
+ sys.env.get("SIFT_NUM_SUB_VECTORS").map(_.toInt).getOrElse(16)
+ private val NprobesList: Seq[Int] =
+ sys.env.getOrElse("SIFT_NPROBES_LIST", "1,4,16,64").split(",").map(_.trim.toInt).toSeq
+ private val RefineList: Seq[Int] =
+ sys.env.getOrElse("SIFT_REFINE_LIST", "1,4,8").split(",").map(_.trim.toInt).toSeq
+ private val IndexType: String =
+ sys.env.getOrElse("SIFT_INDEX", "both").toLowerCase
+ private val SkipWrite: Boolean =
+ sys.env.get("SIFT_SKIP_WRITE").exists(_.equalsIgnoreCase("true"))
+ private val SkipIndex: Boolean =
+ sys.env.get("SIFT_SKIP_INDEX").exists(_.equalsIgnoreCase("true"))
+
+ // -- main -----------------------------------------------------------------------------------
+
+ def main(args: Array[String]): Unit = {
+ val spark = buildSparkSession()
+ try {
+ logBanner(spark)
+
+ val baseFile = s"$SiftDir/${SiftPrefix}_base.fvecs"
+ val queryFile = s"$SiftDir/${SiftPrefix}_query.fvecs"
+ val gtFile = s"$SiftDir/${SiftPrefix}_groundtruth.ivecs"
+ Seq(baseFile, queryFile, gtFile).foreach { path =>
+ if (!new File(path).exists()) {
+ sys.error(s"Missing SIFT file: $path. See scaladoc for download instructions.")
+ }
+ }
+
+ val lanceUri = s"$DataPath/${SiftPrefix}_base"
+
+ if (SkipWrite) {
+ println(s"[sift] SOAK_SKIP_WRITE=true -> using existing Lance dataset at $lanceUri")
+ } else {
+ writeBaseAsLance(spark, baseFile, lanceUri)
+ }
+
+ val indexesToRun: Seq[String] = IndexType match {
+ case "ivfpq" => Seq("ivfpq")
+ case "ivfflat" => Seq("ivfflat")
+ case "both" => Seq("ivfflat", "ivfpq")
+ case other => sys.error(s"Unknown SIFT_INDEX=$other (expected ivfpq|ivfflat|both)")
+ }
+
+ if (!SkipIndex) {
+ indexesToRun.foreach(idx => buildIndex(lanceUri, idx))
+ } else {
+ println("[sift] SOAK_SKIP_INDEX=true -> using existing index(es); no build")
+ }
+
+ val queries = loadFvecs(queryFile, limit = NumQueries)
+ val groundTruth = loadIvecs(gtFile, limit = NumQueries)
+ println(f"[sift] loaded ${queries.size}%,d queries × dim ${queries.head.length}, " +
+ f"ground truth with ${groundTruth.head.length}%,d neighbors per query")
+ require(
+ groundTruth.head.length >= K,
+ s"K=$K > ground-truth neighbors-per-query (${groundTruth.head.length}). " +
+ s"Lower K or rebuild ground truth.")
+
+ indexesToRun.foreach { idx =>
+ println()
+ println("=" * 80)
+ println(s" RECALL GRID: index=$idx, K=$K, nprobes=${NprobesList.mkString(",")}" +
+ (if (idx == "ivfpq") s", refineFactor=${RefineList.mkString(",")}" else ""))
+ println("=" * 80)
+ // IVF-FLAT ignores refineFactor (no PQ codes to re-rank), so only iterate `nprobes`.
+ val refineGrid: Seq[Int] = if (idx == "ivfpq") RefineList else Seq(1)
+ println(
+ f"${"nprobes"}%8s ${"refine"}%6s ${"recall@K"}%10s ${"mean_ms"}%10s ${"queries"}%8s")
+ for (nprobes <- NprobesList; refine <- refineGrid) {
+ val (recall, meanMs) = runRecallGrid(
+ spark,
+ lanceUri,
+ queries,
+ groundTruth,
+ nprobes = nprobes,
+ refineFactor = if (idx == "ivfpq") Some(refine) else None)
+ println(f"$nprobes%8d $refine%6d $recall%10.4f $meanMs%10.2f ${queries.size}%8d")
+ }
+ }
+ } finally {
+ spark.stop()
+ }
+ }
+
+ // -- Spark session --------------------------------------------------------------------------
+
+ private def buildSparkSession(): SparkSession = {
+ val b = SparkSession.builder().appName("SiftRecallBenchmark")
+ if (!ClusterMode) {
+ b.master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ }
+ b.getOrCreate()
+ }
+
+ private def logBanner(spark: SparkSession): Unit = {
+ println("=" * 80)
+ println(s"SiftRecallBenchmark")
+ println("=" * 80)
+ println(f" Spark version: ${spark.version}")
+ println(f" master: ${spark.sparkContext.master}")
+ println(f" cluster mode: $ClusterMode")
+ println(f" SIFT dir: $SiftDir")
+ println(f" SIFT prefix: $SiftPrefix")
+ println(f" data path: $DataPath")
+ println(f" index type: $IndexType")
+ println(f" num IVF parts: $NumPartitions")
+ println(f" num PQ subvectors: $NumSubVectors")
+ println(f" K: $K")
+ println(f" num queries: $NumQueries")
+ println(f" nprobes grid: ${NprobesList.mkString(",")}")
+ println(f" refine grid: ${RefineList.mkString(",")}")
+ println("=" * 80)
+ println()
+ }
+
+ // -- fvecs / ivecs I/O ----------------------------------------------------------------------
+
+ /**
+ * Read an .fvecs / .ivecs file into a Seq of arrays. The file format is a concatenation
+ * of (int32 dim, payload[dim]) records, little-endian. `limit` caps the number of
+ * records read; the loader stops gracefully at EOF. `readElement` reads one 4-byte
+ * payload element (float via `intBitsToFloat` for fvecs, raw int for ivecs).
+ */
+ private def loadVecs[T: scala.reflect.ClassTag](
+ path: String,
+ limit: Int,
+ readElement: Int => T): IndexedSeq[Array[T]] = {
+ val in = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))
+ try {
+ val out = new mutable.ArrayBuffer[Array[T]]()
+ var continue = true
+ var i = 0
+ while (continue && i < limit) {
+ val buf = new Array[Byte](4)
+ val n = in.read(buf)
+ if (n < 4) { continue = false }
+ else {
+ val dim = ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getInt
+ if (dim < 0) throw new IOException(s"Invalid dim $dim at vector $i in $path")
+ val v = new Array[T](dim)
+ var j = 0
+ while (j < dim) {
+ v(j) = readElement(readLeInt(in))
+ j += 1
+ }
+ out += v
+ i += 1
+ }
+ }
+ out.toIndexedSeq
+ } finally in.close()
+ }
+
+ private def loadFvecs(path: String, limit: Int = Int.MaxValue): IndexedSeq[Array[Float]] =
+ loadVecs[Float](path, limit, java.lang.Float.intBitsToFloat)
+
+ private def loadIvecs(path: String, limit: Int = Int.MaxValue): IndexedSeq[Array[Int]] =
+ loadVecs[Int](path, limit, identity)
+
+ /** Read one little-endian int32 from the stream. Throws `IOException` on short read. */
+ private def readLeInt(in: DataInputStream): Int = {
+ val buf = new Array[Byte](4)
+ val n = in.read(buf)
+ if (n != 4) throw new IOException(s"Short read: expected 4 bytes, got $n")
+ ByteBuffer.wrap(buf).order(ByteOrder.LITTLE_ENDIAN).getInt
+ }
+
+ // -- Lance dataset write + index build ------------------------------------------------------
+
+ /**
+ * Write SIFT base vectors to a Lance dataset. Uses distributed Spark write -- avoids
+ * loading all 1M × 128-dim × 4B = 512 MB into driver memory.
+ */
+ private def writeBaseAsLance(spark: SparkSession, baseFile: String, lanceUri: String): Unit = {
+ println(s"[sift] writing base dataset to Lance: $baseFile -> $lanceUri")
+ val t0 = System.nanoTime()
+
+ // Load base vectors on driver (file is ~500 MB for SIFT1M; fine for driver heap of 4 GB+).
+ val baseVecs = loadFvecs(baseFile, limit = Int.MaxValue)
+ val dim = baseVecs.head.length
+ println(f"[sift] base has ${baseVecs.size}%,d vectors × dim $dim")
+
+ val schema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "vec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+
+ // Parallelize to cluster default parallelism * 4 for good write fan-out.
+ val parallelism = math.max(spark.sparkContext.defaultParallelism * 4, 16)
+ val rdd = spark.sparkContext
+ .parallelize(baseVecs.indices, parallelism)
+ .map { i =>
+ // Pass Array[Float] directly; Scala 2.12 `.toSeq.asJava` produces a Java List via
+ // Wrappers$SeqWrapper that Spark's ArrayType encoder rejects.
+ RowFactory.create(java.lang.Long.valueOf(i.toLong), baseVecs(i)): Row
+ }
+ // mode("overwrite") on a non-existent Lance path throws NoSuchTableException via the
+ // catalog's drop-before-create path. Default ErrorIfExists is correct for first-run;
+ // set SIFT_SKIP_WRITE=true to reuse an existing dataset.
+ spark.createDataFrame(rdd, schema)
+ .write.format("lance").save(lanceUri)
+
+ val sec = (System.nanoTime() - t0) / 1e9
+ println(f"[sift] write complete in $sec%.1f s")
+ }
+
+ /**
+ * Build an IVF-PQ or IVF-FLAT index on the Lance dataset. Must be called on the driver
+ * (Lance's Java SDK builds the index in-process using the opened dataset handle).
+ */
+ private def buildIndex(lanceUri: String, kind: String): Unit = {
+ println(s"[sift] building $kind index on $lanceUri (numPartitions=$NumPartitions" +
+ (if (kind == "ivfpq") s", numSubVectors=$NumSubVectors" else "") + ")")
+ val t0 = System.nanoTime()
+ val ds = Dataset.open().uri(lanceUri).allocator(LanceRuntime.allocator())
+ .readOptions(new ReadOptions.Builder().build()).build()
+ try {
+ val vectorParams = kind match {
+ case "ivfpq" =>
+ // Use the explicit builder form — the 5-arg positional `ivfPq(partitions,
+ // subvectors, bits, metric, maxIters)` path had the bits/subvectors params swapped
+ // somewhere between Scala call site and Rust side, yielding
+ // "num_bits 16 not supported" on a call that passed bits=8. Named setters
+ // eliminate that ambiguity.
+ val ivf = new IvfBuildParams.Builder()
+ .setNumPartitions(NumPartitions)
+ .setMaxIters(50)
+ .build()
+ val pq = new PQBuildParams.Builder()
+ .setNumSubVectors(NumSubVectors)
+ .setNumBits(8)
+ .setMaxIters(50)
+ .build()
+ VectorIndexParams.withIvfPqParams(Metric.L2.lanceType, ivf, pq)
+ case "ivfflat" =>
+ VectorIndexParams.ivfFlat(NumPartitions, Metric.L2.lanceType)
+ }
+ val idxParams = IndexParams.builder().setVectorIndexParams(vectorParams).build()
+ // Give each index a kind-specific name so `SIFT_INDEX=both` can build IVF-FLAT and
+ // IVF-PQ on the same `vec` column. Lance's default is `_idx` which collides
+ // when a second index is built on the same column.
+ val opts = org.lance.index.IndexOptions.builder(
+ java.util.Collections.singletonList("vec"),
+ org.lance.index.IndexType.VECTOR,
+ idxParams)
+ .withIndexName(s"vec_${kind}_idx")
+ .build()
+ ds.createIndex(opts)
+ } finally ds.close()
+ val sec = (System.nanoTime() - t0) / 1e9
+ println(f"[sift] $kind build complete in $sec%.1f s")
+ }
+
+ // -- recall evaluation ----------------------------------------------------------------------
+
+ /**
+ * Run all `queries` against the indexed Lance dataset and compute recall@K against the
+ * shipped ground truth. Returns (meanRecall, meanLatencyMs).
+ *
+ * All queries are submitted as a single `IndexedNearestJoin.apply` call (left DF has one
+ * row per query). Lance parallelises the probes internally + via Spark's per-task
+ * `LanceProbe`. Total wall-clock divided by query count gives mean latency.
+ */
+ private def runRecallGrid(
+ spark: SparkSession,
+ lanceUri: String,
+ queries: IndexedSeq[Array[Float]],
+ groundTruth: IndexedSeq[Array[Int]],
+ nprobes: Int,
+ refineFactor: Option[Int]): (Double, Double) = {
+ val dim = queries.head.length
+
+ val leftSchema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+
+ val rows = new java.util.ArrayList[Row](queries.size)
+ var i = 0
+ while (i < queries.size) {
+ // Same Array[Float]-not-Java-List rule as writeBaseAsLance.
+ rows.add(RowFactory.create(java.lang.Long.valueOf(i.toLong), queries(i)))
+ i += 1
+ }
+ val left = spark.createDataFrame(rows, leftSchema)
+
+ val right = spark.read.format("lance").load(lanceUri)
+
+ val t0 = System.nanoTime()
+ val joined = IndexedNearestJoin(
+ left = left,
+ rightLanceUri = lanceUri,
+ leftVecCol = "qvec",
+ rightVecCol = "vec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ nprobes = Some(nprobes),
+ refineFactor = refineFactor)
+ // Bring to driver. Keeping it small (NumQueries × K rows × a couple of cols).
+ val collected = joined.collect()
+ val elapsedMs = (System.nanoTime() - t0) / 1e6
+
+ // Output schema is `left.fields ++ right.fields :+ score` (see
+ // IndexedNearestJoin.buildOutputSchema). Left is [lid:Long, qvec:Array], right
+ // projection is [rid:Long], so indices are [0:lid, 1:qvec, 2:rid, 3:__score].
+ val actual = collected.groupBy(_.getLong(0)).map { case (lid, rowsArr) =>
+ lid -> rowsArr.toSeq.sortBy(_.getFloat(3)).take(K).map(_.getLong(2).toInt).toSet
+ }
+
+ var recallSum = 0.0
+ var q = 0
+ while (q < queries.size) {
+ val expected = groundTruth(q).take(K).toSet
+ val got = actual.getOrElse(q.toLong, Set.empty[Int])
+ recallSum += got.intersect(expected).size.toDouble / K
+ q += 1
+ }
+ val meanRecall = recallSum / queries.size
+ val meanLatencyMs = elapsedMs / queries.size
+ (meanRecall, meanLatencyMs)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/WikipediaKnnPerfBenchmark.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/WikipediaKnnPerfBenchmark.scala
new file mode 100644
index 000000000..01b438b94
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/benchmark/WikipediaKnnPerfBenchmark.scala
@@ -0,0 +1,637 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.sql.{DataFrame, Row, RowFactory, SparkSession}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.LanceKnnImplicits._
+
+import java.util.concurrent.TimeUnit
+
+/**
+ * Performance benchmark: indexed kNN-join vs vanilla Spark crossJoin on Cohere Wikipedia
+ * embeddings (dim=1024). Production-shape counterpart to
+ * [[IndexedNearestJoinBenchmark]], which uses synthetic random vectors at dim=128.
+ *
+ * Shape matches `IndexedNearestJoinBenchmark`: 5 configs, median of 3 runs + 1 warmup,
+ * same 4 `probeParallelism` flavours on the indexed side:
+ *
+ * - A: Vanilla Spark cross-product + L2 UDF + row_number window (baseline).
+ * - B: Phase 0/1 single-task probe (probeParallelism=1).
+ * - C: Phase 1.5 fragment-grouped probe (probeParallelism=4).
+ * - D: Phase 1.5 fragment-grouped probe (probeParallelism=8).
+ * - E: Phase 1.5 + skew-balanced fragment grouping.
+ *
+ * Unlike [[CohereWikiRecallBenchmark]], this does NOT build a vector index on the right
+ * side. Lance therefore brute-force-scans each fragment for every probe call — the
+ * speedup over Spark's crossJoin is entirely from the native SIMD distance kernel +
+ * columnar Arrow iteration, not from an ANN index. Adding an IVF-PQ/IVF-FLAT index would
+ * produce a further 10-100× speedup on top of what this measures.
+ *
+ * == Why dim=1024 matters ==
+ *
+ * The existing `IndexedNearestJoinBenchmark` runs at dim=128. At dim=128 the per-pair
+ * distance kernel is tiny (~8-16 cycles), so the speedup is dominated by JVM-vs-native
+ * constant overhead. At dim=1024 the kernel cost is 8× higher in absolute terms, so:
+ * - Spark's crossJoin baseline gets dramatically slower in absolute wall-clock
+ * - Lance's SIMD kernel advantage widens (AVX2/AVX-512 process 8-16 floats/cycle)
+ * - The speedup factor gets larger, not smaller
+ * Which is the opposite of what one might naively expect. The production-shape number
+ * this benchmark produces is more credible than the SIFT-style dim=128 result.
+ *
+ * == Cluster run ==
+ *
+ * {{{
+ * ./mvnw -pl lance-spark-knn_2.12 package -Pbenchmark -DskipTests
+ * # upload the fat jar + the Cohere parquet shard(s)
+ *
+ * BENCH_CLUSTER_MODE=true \
+ * BENCH_DATA_PATH=file:///valve-binaries/wiki-perf-data \
+ * WIKI_PARQUET=/valve-binaries/wiki-*.parquet \
+ * WIKI_NUM_RIGHT=100000 \
+ * WIKI_NUM_LEFT=100 \
+ * WIKI_NUM_FRAGMENTS=8 \
+ * WIKI_RUN_BASELINE=true \
+ * spark-submit --class org.lance.spark.knn.benchmark.WikipediaKnnPerfBenchmark
+ * }}}
+ *
+ * == Env knobs ==
+ *
+ * - `WIKI_PARQUET=` -- Cohere parquet glob (required). Same source as
+ * `CohereWikiRecallBenchmark`. Only the `emb` column
+ * is used.
+ * - `WIKI_EMB_COL=emb` -- embedding column name (default `emb`).
+ * - `WIKI_NUM_RIGHT=100000` -- base-set size (default 100K). The remaining rows in
+ * the parquet are ignored. Larger values make the
+ * crossJoin baseline increasingly impractical.
+ * - `WIKI_NUM_LEFT=100` -- query rows held out from the base (default 100).
+ * The crossJoin baseline is O(|L|×|R|); at |L|=100,
+ * |R|=100K that's 10M pair evaluations.
+ * - `WIKI_NUM_FRAGMENTS=8` -- number of Lance fragments to split base into. Drives
+ * `probeParallelism` caps on the indexed side.
+ * - `WIKI_K=10` -- top-K neighbours per query (default 10).
+ * - `WIKI_RUN_BASELINE=true` -- set false to skip the crossJoin baseline. Useful for
+ * medium/large scales where O(|L|×|R|) is > 30 minutes
+ * per run.
+ * - `WIKI_WARMUP_RUNS=1` / `WIKI_MEASURE_RUNS=3` -- timing iterations.
+ * - `WIKI_SKIP_SETUP=false` -- if "true", reuse the Lance dataset at
+ * `BENCH_DATA_PATH/right` written by a prior run.
+ * - `BENCH_CLUSTER_MODE`, `BENCH_DATA_PATH` -- same semantics as other benchmarks.
+ *
+ * == What this does NOT measure ==
+ *
+ * - ANN-index speedup. The right side is written without a vector index; Lance does
+ * brute-force distance computation. [[CohereWikiRecallBenchmark]] is the right tool
+ * for IVF-FLAT / IVF-PQ recall × latency tradeoffs.
+ * - Warm vs cold cache effects. Warmup runs prime the JVM / native code caches; the
+ * reported median is a steady-state number.
+ * - Driver-side latency (left-row creation, result materialization). These are inside
+ * the timing loop but dominated by the probe at any non-trivial |L|.
+ *
+ * == Known: Cohere parquet → Lance fixed-size-list ==
+ *
+ * The Cohere parquet emits `emb` as variable-length `list`; Lance's nearest-scan
+ * needs `FixedSizeList`. `DataFrameWriter.option("vec.arrow.fixed-size-list.size",
+ * dim)` does NOT propagate through the writer path — the option is only honoured by the
+ * `CREATE TABLE` + TBLPROPERTIES route. Working shape (same as
+ * [[CohereWikiRecallBenchmark]]'s fix): collect to driver, rebuild fresh rows with a
+ * `StructType` that has `arrow.fixed-size-list.size` metadata on the vec StructField,
+ * `createDataFrame` + write. Costs ~400 MB driver heap per 100K rows × dim=1024.
+ */
+object WikipediaKnnPerfBenchmark {
+
+ private val K: Int = sys.env.get("WIKI_K").map(_.toInt).getOrElse(10)
+ private val NumRight: Int = sys.env.get("WIKI_NUM_RIGHT").map(_.toInt).getOrElse(100000)
+ private val NumLeft: Int = sys.env.get("WIKI_NUM_LEFT").map(_.toInt).getOrElse(100)
+ private val NumFragments: Int =
+ sys.env.get("WIKI_NUM_FRAGMENTS").map(_.toInt).getOrElse(8)
+ private val RunBaseline: Boolean = sys.env.get("WIKI_RUN_BASELINE")
+ .map(_.equalsIgnoreCase("true")).getOrElse(true)
+ private val WarmupRuns: Int = sys.env.get("WIKI_WARMUP_RUNS").map(_.toInt).getOrElse(1)
+ private val MeasurementRuns: Int =
+ sys.env.get("WIKI_MEASURE_RUNS").map(_.toInt).getOrElse(3)
+ private val SkipSetup: Boolean =
+ sys.env.get("WIKI_SKIP_SETUP").exists(_.equalsIgnoreCase("true"))
+ private val ParquetPath: String = sys.env.getOrElse(
+ "WIKI_PARQUET",
+ sys.error("WIKI_PARQUET is required (path to Cohere wiki parquet files)"))
+ private val EmbCol: String = sys.env.getOrElse("WIKI_EMB_COL", "emb")
+ private val ClusterMode: Boolean =
+ sys.env.get("BENCH_CLUSTER_MODE").exists(_.equalsIgnoreCase("true"))
+ private val DataPath: String = sys.env.getOrElse(
+ "BENCH_DATA_PATH",
+ "file:///tmp/wiki-perf-lance")
+
+ private case class Result(config: String, medianMs: Long, runs: Seq[Long])
+ private type RunFn = () => DataFrame
+
+ def main(args: Array[String]): Unit = {
+ val spark = buildSparkSession()
+ try {
+ println(banner("Wikipedia KNN-Join Perf Benchmark"))
+ val masterDesc = if (ClusterMode) "cluster (BENCH_CLUSTER_MODE=true)" else "local[*]"
+ println(s" master: $masterDesc")
+ println(s" parquet: $ParquetPath")
+ println(s" data path: $DataPath")
+ println(s" |R|=$NumRight |L|=$NumLeft fragments=$NumFragments K=$K")
+ println(s" warmup/measure: $WarmupRuns/$MeasurementRuns runBaseline=$RunBaseline")
+ println()
+
+ val rightUri = s"$DataPath/right"
+ val (leftDf, dim) = if (SkipSetup) {
+ val d = detectDim(spark, rightUri)
+ println(s"[wiki-perf] WIKI_SKIP_SETUP=true -> reusing $rightUri (dim=$d)")
+ buildLeftFromLanceBase(spark, rightUri, d)
+ } else {
+ setupDatasets(spark, rightUri)
+ }
+ println(s"[wiki-perf] left=${leftDf.count()} right=$NumRight dim=$dim")
+ println()
+
+ val configs = makeConfigs(leftDf, rightUri)
+
+ // Correctness gate: every config — including the crossJoin baseline — must agree
+ // with a brute-force oracle on a 16-row left subset before we quote any speedup.
+ // Without this, the `count()`/`noop` timing loop only checks cardinality, not
+ // content — a bug that emits |L|×K garbage rows would still "validate." This is
+ // the same check `IndexedNearestJoinBenchmark.verifyAllConfigsAgainstOracle`
+ // runs on synthetic data; here it runs on real Cohere Wikipedia embeddings so we
+ // also catch any dim=1024-specific correctness regressions.
+ verifyAllConfigsAgainstOracle(spark, leftDf, rightUri)
+
+ val results = scala.collection.mutable.ArrayBuffer.empty[Result]
+ configs.foreach { case (name, run) =>
+ val r = timeIt(name, run)
+ results += r
+ println(formatResult(r))
+ }
+ println()
+ println(banner("Summary"))
+ printSummary(results.toSeq)
+ } finally {
+ spark.stop()
+ }
+ }
+
+ // -- Spark session --------------------------------------------------------------------------
+
+ private def buildSparkSession(): SparkSession = {
+ val disableAqe = sys.env.get("BENCH_DISABLE_AQE").exists(_.equalsIgnoreCase("true"))
+ val b = SparkSession.builder().appName("wikipedia-knn-perf")
+ .config("spark.sql.crossJoin.enabled", "true")
+ // shuffle.partitions: cluster runs use the submit-time value (default 128 on the
+ // current sizing); local runs get 32 to avoid 200-partition fanout on a laptop.
+ if (!ClusterMode) {
+ b.config("spark.sql.shuffle.partitions", "32")
+ .master("local[*]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ }
+ if (disableAqe) b.config("spark.sql.adaptive.enabled", "false")
+ val s = b.getOrCreate()
+ s.sparkContext.setLogLevel("WARN")
+ if (disableAqe) println("[wiki-perf] AQE DISABLED for this run (BENCH_DISABLE_AQE=true)")
+ s
+ }
+
+ // -- setup: Cohere parquet -> Lance right + sampled left ---------------------------------
+
+ /**
+ * Load the Cohere parquet, pull NumRight+NumLeft rows to the driver, tag the schema
+ * with `arrow.fixed-size-list.size`, write the base set to Lance, keep the query rows
+ * as an in-memory DataFrame. Returns (leftDf, dim).
+ *
+ * Driver-side collect is required because `DataFrame.write.format("lance")` does NOT
+ * honour the `vec.arrow.fixed-size-list.size` option on the writer path — without
+ * proper field metadata, Lance writes the column as variable-length `list` and
+ * the nearest-scan path returns `_rowid` only (no `_distance`), tripping
+ * `LanceProbe.readScored`'s "did not return a score column" check. See
+ * `CohereWikiRecallBenchmark` for the same workaround.
+ */
+ private def setupDatasets(spark: SparkSession, rightUri: String): (DataFrame, Int) = {
+ println(s"[wiki-perf] reading source parquet: $ParquetPath")
+ val raw = spark.read.parquet(ParquetPath)
+ require(
+ raw.schema.fieldNames.contains(EmbCol),
+ s"Source parquet does not contain $EmbCol; fields: ${raw.schema.fieldNames.mkString(",")}")
+
+ val total = NumRight + NumLeft
+ // Stable: select ordered, then limit. Avoids rand()-driven non-determinism across runs.
+ val embColExpr = col(EmbCol).cast(ArrayType(FloatType))
+ val sliced = raw.select(embColExpr.as("vec")).limit(total)
+
+ val allRows = sliced.collect()
+ require(
+ allRows.length >= NumLeft + 1,
+ s"Parquet yielded ${allRows.length} rows; need at least ${NumLeft + 1} for a " +
+ s"left/right split.")
+ if (allRows.length < total) {
+ // Shard smaller than requested |R|+|L|. Shrink the right side to whatever's left
+ // after holding out NumLeft queries. Emit a warning so the user sees the scale
+ // downshift (otherwise speedup numbers can look too good at a smaller |R|).
+ println(f"[wiki-perf] WARN: parquet only yielded ${allRows.length}%,d rows; " +
+ f"needed $total%,d. Right side shrunk from $NumRight%,d to " +
+ f"${allRows.length - NumLeft}%,d.")
+ }
+ val effectiveRight = math.min(NumRight, allRows.length - NumLeft)
+ val dim = allRows(0).getAs[scala.collection.Seq[Float]]("vec").length
+ println(f"[wiki-perf] collected ${allRows.length}%,d rows at dim=$dim; " +
+ f"using |R|=$effectiveRight%,d |L|=$NumLeft%,d")
+
+ // Split first NumRight rows -> right (base), remaining NumLeft -> left (queries).
+ val embMeta = new MetadataBuilder()
+ .putLong("arrow.fixed-size-list.size", dim.toLong)
+ .build()
+
+ // Right side: (rid: Long, rvec: Array [fixed-size]).
+ val rightSchema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ embMeta)))
+ val rightRows = new java.util.ArrayList[Row](effectiveRight)
+ var i = 0
+ while (i < effectiveRight) {
+ val s = allRows(i).getAs[scala.collection.Seq[Float]]("vec")
+ val arr = new Array[Float](s.length)
+ var j = 0
+ while (j < s.length) { arr(j) = s(j); j += 1 }
+ rightRows.add(RowFactory.create(java.lang.Long.valueOf(i.toLong), arr))
+ i += 1
+ }
+ println(s"[wiki-perf] writing right (base) to Lance: $rightUri")
+ val t0 = System.nanoTime()
+ spark.createDataFrame(rightRows, rightSchema)
+ .repartition(NumFragments)
+ .write.format("lance").save(rightUri)
+ println(f"[wiki-perf] wrote in ${TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - t0)}%d s")
+
+ // Left side: (lid: Long, lvec: Array [fixed-size]).
+ val leftSchema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ embMeta)))
+ val leftRows = new java.util.ArrayList[Row](NumLeft)
+ var li = 0
+ while (li < NumLeft) {
+ val r = allRows(effectiveRight + li)
+ val s = r.getAs[scala.collection.Seq[Float]]("vec")
+ val arr = new Array[Float](s.length)
+ var j = 0
+ while (j < s.length) { arr(j) = s(j); j += 1 }
+ leftRows.add(RowFactory.create(java.lang.Long.valueOf(li.toLong), arr))
+ li += 1
+ }
+ val leftDf = spark.createDataFrame(leftRows, leftSchema).cache()
+ leftDf.count() // force cache materialization
+
+ (leftDf, dim)
+ }
+
+ /**
+ * When `WIKI_SKIP_SETUP=true`, synthesise a left side from the already-written Lance
+ * base dataset. Takes the first `NumLeft` rows by `rid` (deterministic across reruns).
+ * The base set keeps those rows — the query rows are "in" the base — which means the
+ * nearest-neighbour for each query is the query itself at distance 0. Fine for a
+ * pure perf benchmark (we're not measuring recall), but note in your write-up.
+ */
+ private def buildLeftFromLanceBase(
+ spark: SparkSession,
+ rightUri: String,
+ dim: Int): (DataFrame, Int) = {
+ val embMeta = new MetadataBuilder()
+ .putLong("arrow.fixed-size-list.size", dim.toLong)
+ .build()
+ val leftSchema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ embMeta)))
+ val baseRows = spark.read.format("lance").load(rightUri)
+ .orderBy("rid").limit(NumLeft).collect()
+ val javaRows = new java.util.ArrayList[Row](baseRows.length)
+ var i = 0
+ while (i < baseRows.length) {
+ val s = baseRows(i).getAs[scala.collection.Seq[Float]]("rvec")
+ val arr = new Array[Float](s.length)
+ var j = 0
+ while (j < s.length) { arr(j) = s(j); j += 1 }
+ javaRows.add(RowFactory.create(java.lang.Long.valueOf(i.toLong), arr))
+ i += 1
+ }
+ val leftDf = spark.createDataFrame(javaRows, leftSchema).cache()
+ leftDf.count()
+ (leftDf, dim)
+ }
+
+ private def detectDim(spark: SparkSession, rightUri: String): Int = {
+ val r = spark.read.format("lance").load(rightUri).select("rvec").limit(1).collect()
+ require(r.nonEmpty, s"Empty Lance dataset at $rightUri")
+ r(0).getAs[scala.collection.Seq[Float]]("rvec").length
+ }
+
+ // -- configs --------------------------------------------------------------------------------
+
+ private def makeConfigs(
+ leftDf: DataFrame,
+ rightUri: String,
+ runBaseline: Boolean = RunBaseline): Seq[(String, RunFn)] = {
+ val spark = leftDf.sparkSession
+ val rightDf = spark.read.format("lance").load(rightUri)
+
+ val baseline: RunFn = () => crossProductTopK(spark, leftDf, rightUri, K)
+ val baselineMinByK: RunFn = () => crossProductMinByK(spark, leftDf, rightUri, K)
+ val phase01: RunFn = () =>
+ leftDf.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 1)
+ val phase15_4: RunFn = () =>
+ leftDf.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 4)
+ val phase15_8: RunFn = () =>
+ leftDf.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 8)
+ val phase15_8_skew: RunFn = () =>
+ leftDf.kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 8,
+ balanceFragments = true)
+
+ val indexed = Seq(
+ "B: Phase 0/1 (probeParallelism=1)" -> phase01,
+ "C: Phase 1.5 (probeParallelism=4)" -> phase15_4,
+ "D: Phase 1.5 (probeParallelism=8)" -> phase15_8,
+ "E: Phase 1.5 (G=8, skew-balanced)" -> phase15_8_skew)
+ // WIKI_INCLUDE_BASELINE_A=true to include the row_number-window baseline. Default off
+ // because at medium scale (|R|=100K+, |L|=1000+) the row_number plan shuffles |L|×|R|
+ // rows through a per-lid window with no partial aggregation; runs hours. A2 (heap-K
+ // shape) gets per-task partial aggregation and is the realistic baseline at scale.
+ val includeRowNumberBaseline =
+ sys.env.get("WIKI_INCLUDE_BASELINE_A").exists(_.equalsIgnoreCase("true"))
+ if (runBaseline) {
+ val a2 = Seq("A2: crossJoin + L2 UDF + groupBy/sort_array(K)" -> baselineMinByK)
+ val a =
+ if (includeRowNumberBaseline)
+ Seq("A: crossJoin + L2 UDF + row_number window" -> baseline)
+ else Seq.empty
+ a ++ a2 ++ indexed
+ } else {
+ indexed
+ }
+ }
+
+ /**
+ * Naive vanilla-Spark baseline (config A): cross product + L2 UDF + `row_number` window
+ * per `lid`. The textbook way a user would express nearest-by-join in Spark 3.5 (no
+ * `vector_l2_distance` until 4.2). Strictly slower than what `RewriteNearestByJoin`
+ * actually does on Spark 4.2 (heap-K via `min_by_k`); kept as the historical
+ * headline-naive comparison and as a stable apples-to-apples reference vs. earlier
+ * benchmark runs. Same shape as `IndexedNearestJoinBenchmark.crossProductTopK`; only
+ * difference is dim=1024 vs 128 in the synthetic benchmark.
+ */
+ private def crossProductTopK(
+ spark: SparkSession,
+ left: DataFrame,
+ rightUri: String,
+ k: Int): DataFrame = {
+ val l2 = l2UdfFactory()
+ val right =
+ repartitionRightForBaseline(spark.read.format("lance").load(rightUri).select("rid", "rvec"))
+ val crossed = left.crossJoin(right).withColumn("__dist", l2(col("lvec"), col("rvec")))
+ val w = Window.partitionBy("lid").orderBy(col("__dist"))
+ crossed.withColumn("__rank", row_number().over(w))
+ .filter(col("__rank") <= k)
+ .select("lid", "rid", "__dist")
+ }
+
+ /**
+ * Expand right-side partitioning for the cross-product baselines so the cross-join
+ * compute stage gets enough tasks to use all cluster cores. Lance reads produce one
+ * partition per fragment; without repartitioning, the fused cross-join + UDF stage
+ * inherits that count and only fragment-many tasks run at a time.
+ */
+ private def repartitionRightForBaseline(df: DataFrame): DataFrame = {
+ val target = sys.env.get("BENCH_BASELINE_RIGHT_PARTITIONS").map(_.toInt).getOrElse(64)
+ if (target > 0) df.repartition(target) else df
+ }
+
+ /**
+ * Closer-to-RewriteNearestByJoin baseline (config A2): cross product + L2 UDF + groupBy
+ * + `sort_array(collect_list(struct(dist, rid)))` + `slice(_, 1, K)` + `inline()`. Spark
+ * 4.2's `RewriteNearestByJoin` lowers `NearestByJoin` to roughly:
+ *
+ * Project(j.output)
+ * +- Generate(Inline(_matches))
+ * +- Aggregate [__qid], first(left.*) ++ min_by(struct(right.*), expr, K)
+ * +- LEFT OUTER Join (no condition) — cross product
+ *
+ * `min_by(struct, expr, K)` (`MaxMinByK`, SPARK-55322) is Spark 4.2-only; it does
+ * `O(|R| log K)` per group via a bounded heap. On Spark 3.5 the closest expressible
+ * shape is `slice(sort_array(collect_list(struct(dist, rid)), asc=true), 1, K)`, which
+ * is `O(|R| log |R|)` per group — strictly slower than the 4.2-native lowering. Quoted
+ * here so the speedup-vs-baseline number reflects what's actually possible to express
+ * in Spark 3.5 SQL today, not the naive row_number form.
+ *
+ * Same shape as `IndexedNearestJoinBenchmark.crossProductMinByK`; only difference is
+ * dim=1024 vs 128 in the synthetic benchmark.
+ */
+ private def crossProductMinByK(
+ spark: SparkSession,
+ left: DataFrame,
+ rightUri: String,
+ k: Int): DataFrame = {
+ val l2 = l2UdfFactory()
+ val right =
+ repartitionRightForBaseline(spark.read.format("lance").load(rightUri).select("rid", "rvec"))
+ val crossed = left.crossJoin(right).withColumn("__dist", l2(col("lvec"), col("rvec")))
+ crossed.groupBy("lid")
+ .agg(
+ slice(
+ sort_array(collect_list(struct(col("__dist"), col("rid"))), asc = true),
+ 1,
+ k).as("__matches"))
+ .select(col("lid"), inline(col("__matches")).as(Seq("__dist", "rid")))
+ .select("lid", "rid", "__dist")
+ }
+
+ private def l2UdfFactory(): org.apache.spark.sql.expressions.UserDefinedFunction =
+ udf((a: Seq[Float], b: Seq[Float]) => {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ })
+
+ // -- oracle equivalence --------------------------------------------------------------------
+
+ /**
+ * Run EVERY config (including the crossJoin baseline) on a 16-row left subset and compare
+ * each result against an in-memory brute-force oracle. Running on a subset keeps the
+ * baseline tractable (16 × |R| pair evaluations is sub-second even at dim=1024 × 100K)
+ * while still validating that all paths agree on top-K row IDs.
+ *
+ * Compared as `Set[Long]` per `lid` to tolerate tied-distance ordering. Real embeddings
+ * rarely produce exact ties but the comparison is robust either way.
+ *
+ * Why this matters: the `timeIt` harness uses `write.format("noop")` which materializes
+ * every row but discards output. Cardinality alone isn't a correctness proof — a bug
+ * that emits `|L|×K` garbage rows would still produce the expected count. This oracle
+ * check closes that loop on real Cohere data at dim=1024.
+ */
+ private def verifyAllConfigsAgainstOracle(
+ spark: SparkSession,
+ leftDf: DataFrame,
+ rightUri: String): Unit = {
+ println(" Sanity check: all configs match brute-force oracle on a 16-row subset ...")
+ val left16 = leftDf.limit(16).cache()
+ left16.count()
+ val leftRows = left16.collect()
+
+ // Brute-force oracle in plain Scala — the ground truth.
+ val rightDf = spark.read.format("lance").load(rightUri).select("rid", "rvec").collect()
+ val rightVecs = rightDf.map(r => r.getAs[scala.collection.Seq[Float]]("rvec").toArray)
+ val rightIds = rightDf.map(_.getAs[Long]("rid"))
+ val oracleByLid: Map[Long, Set[Long]] = leftRows.map { r =>
+ val lid = r.getAs[Long]("lid")
+ val lvec = r.getAs[scala.collection.Seq[Float]]("lvec").toArray
+ val topKRids = rightVecs.indices
+ .map(i => (rightIds(i), l2(lvec, rightVecs(i))))
+ .sortBy(_._2)
+ .take(K)
+ .map(_._1)
+ .toSet
+ lid -> topKRids
+ }.toMap
+
+ // Validate B/C/D/E against the oracle (A is brute force by construction — comparing
+ // Spark's crossJoin to in-memory brute force would be a tautology, and the window
+ // pipeline is slow enough on 16 × |R| dim=1024 pairs to add minutes per run).
+ val miniConfigs = makeConfigs(left16, rightUri, runBaseline = false)
+ miniConfigs.foreach { case (name, run) =>
+ val rows = run().collect()
+ val byLid = rows.groupBy(_.getAs[Long]("lid"))
+ .map { case (lid, rs) => lid -> rs.map(_.getAs[Long]("rid")).toSet }
+ leftRows.map(_.getAs[Long]("lid")).foreach { lid =>
+ val expected = oracleByLid(lid)
+ val actual = byLid.getOrElse(lid, Set.empty[Long])
+ if (expected != actual) {
+ sys.error(
+ s"ORACLE MISMATCH for $name at lid=$lid:\n oracle: $expected\n actual: $actual")
+ }
+ }
+ }
+ left16.unpersist()
+ println(
+ s" ... all ${miniConfigs.size} indexed configs match the oracle " +
+ s"(sample size: ${leftRows.length}, K=$K).")
+ }
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ }
+
+ // -- timing ---------------------------------------------------------------------------------
+
+ /**
+ * Execute the plan fully and discard the result rows — Spark's canonical benchmark
+ * sink. Unlike `count()`, this forces both paths to materialize every join row through
+ * the projected columns, closing the `count()`-bias gap where the crossJoin baseline
+ * skips result-row assembly while the indexed path runs `LanceMaterialize` in full
+ * (due to the `references = child.outputSet` override on `LanceMaterializeLogicalPlan`).
+ * No network round-trip to driver either, so the measurement is pure pipeline wall-clock.
+ */
+ private def runFull(df: DataFrame): Unit =
+ df.write.format("noop").mode("overwrite").save()
+
+ private def timeIt(config: String, f: RunFn): Result = {
+ print(s" $config ... ")
+ System.out.flush()
+ var i = 0
+ while (i < WarmupRuns) { runFull(f()); i += 1 }
+ val runs = (0 until MeasurementRuns).map { _ =>
+ val t0 = System.nanoTime()
+ runFull(f())
+ TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0)
+ }
+ val sorted = runs.sorted
+ val median = sorted(sorted.length / 2)
+ println(s"runs=${runs.mkString("[", ",", "]")} ms, median=$median ms")
+ Result(config, median, runs)
+ }
+
+ // -- output ---------------------------------------------------------------------------------
+
+ private def banner(s: String): String = s"\n=== $s " + ("=" * math.max(0, 76 - s.length - 5))
+
+ private def formatResult(r: Result): String =
+ f" -> ${r.config}%-40s median=${r.medianMs}%8d ms"
+
+ private def printSummary(results: Seq[Result]): Unit = {
+ val configWidth = 40
+ val numWidth = 14
+ val divider = "-" * (configWidth + numWidth * 2)
+ println(divider)
+ println(("%-" + configWidth + "s%" + numWidth + "s%" + numWidth + "s")
+ .format("Configuration", "median (ms)", "speedup ×"))
+ println(divider)
+ val baselineMs = results.find(_.config.startsWith("A:")).map(_.medianMs).getOrElse(0L)
+ results.foreach { r =>
+ val speedup =
+ if (r.config.startsWith("A:")) "1.00x"
+ else if (r.medianMs <= 0 || baselineMs <= 0) "(no base)"
+ else f"${baselineMs.toDouble / r.medianMs}%.2fx"
+ println(("%-" + configWidth + "s%" + numWidth + "d%" + numWidth + "s")
+ .format(r.config, r.medianMs, speedup))
+ }
+ println(divider)
+ println(
+ "Speedup = baseline(A) / config. Higher = faster. Baseline is vanilla Spark " +
+ "crossJoin + L2 UDF + row_number window, the lowering Spark's RewriteNearestByJoin " +
+ "applies to SQL APPROX NEAREST when the indexed rule is off.")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceFragments.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceFragments.scala
new file mode 100644
index 000000000..484723cbc
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceFragments.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.lance.{Dataset, ReadOptions}
+import org.lance.spark.LanceRuntime
+
+import scala.collection.JavaConverters._
+
+/**
+ * Driver-side helper that enumerates Lance fragment IDs and partitions them into balanced groups.
+ * Foundation of Phase 1.5 fragment-grouped probing — once we know all fragment IDs and how they
+ * should be split across probe tasks, the probe stage's per-task `fragmentIds` becomes a function
+ * of the partition index and the group count.
+ *
+ * Round-robin assignment is the simplest balanced split: fragments tend to be similarly sized in
+ * a healthy Lance dataset, so straight round-robin gives groups within ~1 fragment of each other
+ * in count. For uneven datasets, [[enumerateGroupsByRowCount]] uses LPT greedy bin-packing on
+ * per-fragment row counts (Phase 3 skew handling).
+ */
+object LanceFragments {
+
+ /**
+ * Open the dataset, enumerate fragment IDs, and round-robin them into `groupCount` groups.
+ * If `groupCount > numFragments`, the trailing groups will be empty — call sites must tolerate
+ * an empty group rather than fail (the probe stage skips them, since an empty fragment list
+ * means "no rows to probe").
+ *
+ * @return a list of length exactly `groupCount`. Each inner list is the fragment IDs assigned
+ * to that group. Concatenating all groups in order yields the full fragment-id set
+ * (each fragment appears in exactly one group).
+ */
+ def enumerateGroups(
+ datasetUri: String,
+ version: Option[Long],
+ groupCount: Int): Seq[Seq[Int]] = {
+ require(groupCount > 0, s"groupCount must be positive, got $groupCount")
+ val dataset = openDataset(datasetUri, version)
+ try {
+ val ids = dataset.getFragments.asScala.iterator.map(_.getId).toIndexedSeq
+ roundRobin(ids, groupCount)
+ } finally dataset.close()
+ }
+
+ /**
+ * Phase 3 skew handling — group fragments such that each group's total row count is balanced.
+ * Falls back to round-robin if the row-count metadata isn't available for a fragment (defensive
+ * default of 1 row per missing fragment so the assignment doesn't degenerate).
+ *
+ * Uses the classic "Longest Processing Time" greedy heuristic: sort fragments by row count
+ * descending, then assign each to whichever group currently has the smallest total. Worst-case
+ * makespan within 4/3 of optimal — sufficient for fragment-grouping where the goal is "no
+ * task does dramatically more work than another", not perfect balance.
+ *
+ * Use this over `enumerateGroups` when fragments are known to be uneven (e.g., produced by an
+ * unbalanced upstream write). For evenly-sized fragments the simpler round-robin gives the
+ * same result with less overhead.
+ */
+ def enumerateGroupsByRowCount(
+ datasetUri: String,
+ version: Option[Long],
+ groupCount: Int): Seq[Seq[Int]] = {
+ require(groupCount > 0, s"groupCount must be positive, got $groupCount")
+ val dataset = openDataset(datasetUri, version)
+ try {
+ val weighted = dataset.getFragments.asScala.iterator.map { f =>
+ // Some fragment metadata implementations return -1 / 0 if not populated. Treat any
+ // non-positive value as "1 row" so it still occupies a slot and gets assigned somewhere.
+ val rows = scala.math.max(1L, f.metadata.getNumRows)
+ (f.getId, rows)
+ }.toIndexedSeq
+ greedyBalance(weighted, groupCount)
+ } finally dataset.close()
+ }
+
+ /**
+ * Public for testing without spinning up a Lance dataset. Round-robins `ids` into `groupCount`
+ * sub-sequences while preserving relative order within each group.
+ */
+ private[knn] def roundRobin(ids: Seq[Int], groupCount: Int): Seq[Seq[Int]] = {
+ val groups = Array.fill(groupCount)(scala.collection.mutable.ArrayBuffer.empty[Int])
+ var i = 0
+ while (i < ids.size) {
+ groups(i % groupCount) += ids(i)
+ i += 1
+ }
+ groups.toSeq.map(_.toSeq)
+ }
+
+ /**
+ * Public for testing. LPT (Longest Processing Time) greedy bin-packing: sort by weight desc,
+ * assign each item to the currently lightest group. 4/3-approximation of optimal makespan.
+ */
+ private[knn] def greedyBalance(weighted: Seq[(Int, Long)], groupCount: Int): Seq[Seq[Int]] = {
+ val groups = Array.fill(groupCount)(scala.collection.mutable.ArrayBuffer.empty[Int])
+ val totals = Array.fill(groupCount)(0L)
+ val sorted = weighted.sortBy { case (_, w) => -w } // descending by weight
+ sorted.foreach { case (id, w) =>
+ var minIdx = 0
+ var i = 1
+ while (i < groupCount) {
+ if (totals(i) < totals(minIdx)) minIdx = i
+ i += 1
+ }
+ groups(minIdx) += id
+ totals(minIdx) += w
+ }
+ groups.toSeq.map(_.toSeq)
+ }
+
+ private def openDataset(datasetUri: String, version: Option[Long]): Dataset = {
+ val readOpts = {
+ val b = new ReadOptions.Builder()
+ version.foreach(v => b.setVersion(v))
+ b.build()
+ }
+ Dataset
+ .open()
+ .uri(datasetUri)
+ .allocator(LanceRuntime.allocator())
+ .readOptions(readOpts)
+ .build()
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMaterializeStage.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMaterializeStage.scala
new file mode 100644
index 000000000..0a66dbe82
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMaterializeStage.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructField
+
+import scala.collection.mutable
+
+/**
+ * Materialize stage. Per task, opens the Lance dataset and point-fetches right rows by
+ * `_rowaddr`. Joins against the carried left payload to emit final join rows.
+ *
+ * Lance's row-address-IN filter is the natural index point-fetch path (validated by
+ * `LanceProbeValidationTest`). Materialize re-opens the Lance dataset rather than reusing the
+ * probe stage's open because the two stages are now in separate Spark tasks (across the shuffle).
+ * The cost is one Lance manifest read per task — Lance's metadata is mmap-friendly and that's
+ * the same trade-off lance-spark already accepts for fragment scans.
+ *
+ * The materialize-only `LanceProbe` instance never calls `probe()`, so `vectorColumn` is unused on
+ * this path; we pass an empty string. Refactoring `LanceProbe`'s constructor to drop the param
+ * for materialize-only use is left for a follow-up, since the current shape is the validated one.
+ */
+object LanceMaterializeStage {
+
+ final case class Conf(
+ datasetUri: String,
+ version: Option[Long],
+ rightProjection: Seq[String],
+ rightFields: Seq[StructField],
+ leftFieldCount: Int,
+ outerJoin: Boolean)
+ extends Serializable
+
+ def run(merged: RDD[(Long, ProbedLeft)], conf: Conf): RDD[Row] = {
+ merged.mapPartitions(iter => materializePartition(iter, conf))
+ }
+
+ private def materializePartition(
+ iter: Iterator[(Long, ProbedLeft)],
+ conf: Conf): Iterator[Row] = {
+ if (iter.isEmpty) return Iterator.empty
+
+ val probe =
+ new LanceProbe(conf.datasetUri, fragmentIds = None, version = conf.version)
+ val out = mutable.ArrayBuffer.empty[Row]
+ try {
+ iter.foreach { case (_, pl) =>
+ if (pl.refs.isEmpty && conf.outerJoin) {
+ out += assembleRow(
+ pl.leftRow,
+ conf.leftFieldCount,
+ conf.rightFields,
+ rightValues = null,
+ score = null)
+ } else if (pl.refs.nonEmpty) {
+ // Build a `rowAddr -> materialized row` map. If `pl.refs` ever contains duplicate
+ // `rowAddr`s (same row referenced by multiple probe contributions) the map collapses
+ // them to one entry; the `pl.refs.foreach` loop below still emits one output row per
+ // ref, all sharing that materialized payload. Phase 1.5's `TopKHeap.merge` does not
+ // dedupe across contributions, so this collapse is intentional rather than a bug —
+ // duplicate refs would mean the same right row is the K-th nearest along multiple
+ // fragment-group paths, which is genuinely "the same hit" and should appear once
+ // per ref in the output.
+ val materialized: Map[Long, Map[String, Any]] = probe
+ .materialize(pl.refs.iterator.map(_.rowAddr).toSeq, conf.rightProjection)
+ .map(m => extractRowAddr(m) -> m)
+ .toMap
+ pl.refs.foreach { ref =>
+ val rightMap = materialized.getOrElse(ref.rowAddr, null)
+ out += assembleRow(
+ pl.leftRow,
+ conf.leftFieldCount,
+ conf.rightFields,
+ rightMap,
+ ref.score)
+ }
+ }
+ }
+ } finally probe.close()
+ out.iterator
+ }
+
+ private def extractRowAddr(m: Map[String, Any]): Long =
+ m.get(LanceProbe.RowIdColumn) match {
+ case Some(l: java.lang.Long) => l.longValue()
+ case Some(l: Long) => l
+ case Some(other) => other.toString.toLong
+ case None =>
+ throw new IllegalStateException(
+ s"Materialized row missing ${LanceProbe.RowIdColumn}; " +
+ s"got keys: ${m.keys.mkString(", ")}")
+ }
+
+ private def assembleRow(
+ leftRow: Row,
+ leftFieldCount: Int,
+ rightFields: Seq[StructField],
+ rightValues: Map[String, Any],
+ score: Any): Row = {
+ val arr = new Array[Any](leftFieldCount + rightFields.size + 1)
+ var i = 0
+ while (i < leftFieldCount) { arr(i) = leftRow.get(i); i += 1 }
+ var j = 0
+ while (j < rightFields.size) {
+ arr(leftFieldCount + j) =
+ if (rightValues == null) null else rightValues.getOrElse(rightFields(j).name, null)
+ j += 1
+ }
+ arr(leftFieldCount + rightFields.size) = score
+ Row.fromSeq(arr.toSeq)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMergeStage.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMergeStage.scala
new file mode 100644
index 000000000..b1fa88ae9
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceMergeStage.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+/**
+ * Merge-stage configuration. Aggregates per-`leftId` contributions from N probe tasks via a
+ * bounded `TopKHeap` and trims the result to `finalK`. The aggregation runs inside
+ * [[org.lance.spark.knn.internal.staged.LanceMergeExec.doExecute]] — this object carries
+ * only the parameters.
+ */
+object LanceMergeStage {
+
+ final case class Conf(finalK: Int, smallerIsBetter: Boolean) extends Serializable
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbe.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbe.scala
new file mode 100644
index 000000000..2af4f4b8c
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbe.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.{BigIntVector, FieldVector, Float4Vector, Float8Vector, UInt8Vector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.lance.{Dataset, ReadOptions}
+import org.lance.ipc.{LanceScanner, Query, ScanOptions}
+import org.lance.spark.{LanceConstant, LanceRuntime}
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+/**
+ * Per-task vector-index probe primitive. Opens a Lance dataset once and runs many `probe()` calls
+ * against a fixed set of fragments. Returns row references + scores only — no payload. Late
+ * materialization happens elsewhere (`LanceMaterialize`).
+ *
+ * This is the core primitive Phase 0 of the indexed nearest-by design depends on. Validating its
+ * cost profile is the first thing to do on a new Lance build:
+ * - dataset open should be one-time cost
+ * - per-probe cost should be index traversal + small overhead, not full fragment scan
+ * - returning top-K row addrs should match Lance's native nearest search recall
+ *
+ * Lifecycle: instantiate per task, call `probe(...)` repeatedly, close at end.
+ *
+ * @param datasetUri Lance dataset URI (passed straight to `Dataset.open`).
+ * @param fragmentIds Fragments this probe is restricted to. Pass `None` for whole-dataset search.
+ * @param version Optional Lance version to pin. Required when used inside a join, so all
+ * probe / materialize stages see the same snapshot.
+ * @param allocator Arrow allocator. Defaults to lance-spark's shared `LanceRuntime.allocator()`.
+ */
+final class LanceProbe(
+ datasetUri: String,
+ fragmentIds: Option[Seq[Int]],
+ version: Option[Long] = None,
+ allocator: BufferAllocator = LanceRuntime.allocator())
+ extends AutoCloseable {
+
+ // Open the dataset once. Lance's Java binding caches index metadata against the Dataset handle,
+ // so reusing it across probes keeps subsequent calls index-warm.
+ private val dataset: Dataset = openDataset()
+
+ private val javaFragmentIds: Option[util.List[Integer]] = fragmentIds.map { ids =>
+ val javaList = new util.ArrayList[Integer](ids.size)
+ ids.foreach(i => javaList.add(Integer.valueOf(i)))
+ javaList: util.List[Integer]
+ }
+
+ private def openDataset(): Dataset = {
+ val readOpts = {
+ val b = new ReadOptions.Builder()
+ version.foreach(v => b.setVersion(v))
+ b.build()
+ }
+ Dataset.open()
+ .uri(datasetUri)
+ .allocator(allocator)
+ .readOptions(readOpts)
+ .build()
+ }
+
+ /**
+ * Run a single nearest-neighbor query. Returns up to `k` row references for the configured
+ * fragments, ordered best-first by `metric`.
+ *
+ * Implementation note: lance-spark mandates `prefilter = true` for fragmented vector queries
+ * (see `LanceFragmentScanner.create`). We mirror that here — Lance's index probe semantics
+ * require it when fragment scope is restricted.
+ *
+ * `vectorColumn` is a per-call argument (not a constructor field) because the same
+ * `LanceProbe` instance also serves the materialize stage via [[materialize]], which
+ * doesn't reference any vector column. Keeping it on the call sidesteps the smell of
+ * passing a placeholder string when constructing for materialize-only use.
+ *
+ * `prefilter` is a Lance SQL filter string (DataFusion-flavored). Lance applies it BEFORE the
+ * vector index lookup when `prefilter = true` (which we always set), so the top-K is computed
+ * over only the rows matching the filter — exactly what a `Filter(cond, lance) RIGHT JOIN ...
+ * APPROX NEAREST K` should do. Without prefilter pushdown, a per-fragment vector probe could
+ * return K rows that are all later filtered out post-join, masking truly-nearest-but-also-
+ * matching rows further down the index — a recall bug. The translator in
+ * `IndexedNearestByJoinRule` is responsible for producing only safely-translated SQL; here we
+ * just hand it through.
+ */
+ def probe(
+ vectorColumn: String,
+ query: Array[Float],
+ k: Int,
+ metric: Metric,
+ nprobes: Option[Int] = None,
+ refineFactor: Option[Int] = None,
+ ef: Option[Int] = None,
+ prefilter: Option[String] = None): Seq[ScoredRowRef] = {
+ require(vectorColumn != null && vectorColumn.nonEmpty, "vectorColumn must be non-empty")
+ require(query != null && query.length > 0, "Query vector must be non-empty")
+ require(k > 0, "k must be positive")
+
+ val q = {
+ val b = new Query.Builder()
+ .setColumn(vectorColumn)
+ .setKey(query)
+ .setK(k)
+ .setDistanceType(metric.lanceType)
+ nprobes.foreach(b.setNprobes(_))
+ // refineFactor: IVF-PQ recall knob. Lance fetches `k * refineFactor` approximate
+ // candidates, then re-ranks them with exact distance and trims to k. Bigger factor =
+ // better recall, more compute. None leaves Lance's default (= 1, no re-rank).
+ refineFactor.foreach(b.setRefineFactor(_))
+ // ef: HNSW search depth. Higher = better recall, more compute. None leaves Lance's
+ // index-default. Only meaningful for HNSW indexes; ignored for IVF-PQ.
+ ef.foreach(b.setEf(_))
+ b.build()
+ }
+
+ val opts = new ScanOptions.Builder()
+ .nearest(q)
+ .prefilter(true)
+ // Use `_rowid` rather than `_rowaddr` for the row identity. Lance's INDEXED nearest
+ // search path doesn't materialize `_rowaddr` (errors with "No field named _rowaddr"),
+ // while `_rowid` works on both the indexed and non-indexed paths. The materialize stage
+ // filters `_rowid IN (...)` to fetch the surviving rows.
+ .withRowId(true)
+ // Project only what we need into the result. The vector column is implied by `nearest`;
+ // requesting an empty user column list keeps the Arrow batch narrow (just the rowid +
+ // distance metadata). Materialization fetches payload columns later.
+ .columns(java.util.Collections.emptyList[String]())
+
+ prefilter.filter(_.nonEmpty).foreach(opts.filter)
+ javaFragmentIds.foreach(opts.fragmentIds)
+
+ val scanner: LanceScanner = LanceScanner.create(dataset, opts.build(), allocator)
+ try {
+ readScored(scanner.scanBatches())
+ } finally {
+ scanner.close()
+ }
+ }
+
+ /**
+ * Drain the Arrow stream from a nearest-search scan into `(rowId, score)` pairs.
+ *
+ * Expected schema:
+ * - `_rowid` : UInt8 / BigInt — Lance logical row identifier
+ * - `_distance` (or score column added by `nearest`) : Float4 / Float8 — ranking value
+ *
+ * We resolve columns by name to be encoding-version-agnostic; the underlying primitive type
+ * (UInt8 vs BigInt for the id, Float4 vs Float8 for score) varies across Arrow / Lance combos
+ * and we tolerate both.
+ */
+ private def readScored(reader: ArrowReader): Seq[ScoredRowRef] = {
+ val out = mutable.ArrayBuffer.empty[ScoredRowRef]
+ try {
+ while (reader.loadNextBatch()) {
+ val root = reader.getVectorSchemaRoot
+ val addrVec: FieldVector = root.getVector(LanceProbe.RowIdColumn)
+ val scoreVec: FieldVector = LanceProbe.ScoreColumns.iterator
+ .map(name => Option(root.getVector(name)).orNull)
+ .find(_ != null)
+ .getOrElse(throw new IllegalStateException(
+ s"Lance nearest scan did not return a score column. Got: " +
+ root.getSchema.getFields.asScala.map(_.getName).mkString(", ")))
+
+ val n = root.getRowCount
+ var i = 0
+ while (i < n) {
+ val addr = addrVec match {
+ case v: UInt8Vector => v.get(i)
+ case v: BigIntVector => v.get(i)
+ case other =>
+ throw new IllegalStateException(
+ s"Unexpected row-address vector type: ${other.getClass.getName}")
+ }
+ val score = scoreVec match {
+ case v: Float4Vector => v.get(i)
+ case v: Float8Vector => v.get(i).toFloat
+ case other =>
+ throw new IllegalStateException(
+ s"Unexpected score vector type: ${other.getClass.getName}")
+ }
+ out += ScoredRowRef(addr, score)
+ i += 1
+ }
+ }
+ } finally {
+ reader.close()
+ }
+ out.toSeq
+ }
+
+ /**
+ * Materialize a set of right-side rows by their `_rowaddr`s. Used by the join's materialize
+ * stage to fetch full payloads after the probe + merge has decided which rows survive.
+ *
+ * The row addresses are pushed down as a `_rowaddr IN (...)` filter, which Lance executes via
+ * its row-address index — the natural point-fetch path. The result is unordered with respect
+ * to the input list; the caller re-aligns by `_rowaddr`.
+ *
+ * @param rowAddrs list of Lance `_rowid` values (parameter name retained for source
+ * compatibility with callers — semantically these are now row IDs).
+ * @param projection projected column list. `Seq.empty` means "all columns".
+ * @return a sequence of materialized rows, each represented as a `Map[String, Any]` for the
+ * projected columns plus an entry under `LanceProbe.RowIdColumn` so the caller can
+ * re-key. Returning a Map keeps this primitive Spark-agnostic; conversion to
+ * `InternalRow` happens in the API layer.
+ */
+ def materialize(
+ rowAddrs: Seq[Long],
+ projection: Seq[String] = Seq.empty): Seq[Map[String, Any]] = {
+ if (rowAddrs.isEmpty) return Seq.empty
+
+ val opts = new ScanOptions.Builder().withRowId(true)
+ if (projection.nonEmpty) {
+ opts.columns(projection.toList.asJava)
+ }
+ // `_rowid IN (a, b, c)` — Lance lowers this to its row-id lookup path. Same point-fetch
+ // semantics as `_rowaddr IN (...)` previously used here, but `_rowid` is the universal
+ // identifier (works on indexed + non-indexed scan paths alike).
+ //
+ // Each row ID is rendered as `arrow_cast('', 'UInt64')` for two
+ // compounding reasons:
+ //
+ // 1. Lance row IDs are 64-bit UNSIGNED; storing them as Java signed `long` means
+ // values >= 2^63 come back negative. `mkString(", ")` would render them as
+ // negative integer literals and Lance/DataFusion would reject (`Int64(-...)
+ // cannot convert to UInt64`).
+ // 2. Even after `Long.toUnsignedString` produces a positive 20-digit decimal,
+ // DataFusion's SQL parser tries `Int64` first, overflows, then falls back to
+ // `Float64`. `Float64` loses precision past 2^53 — the literal becomes a
+ // different number — and DataFusion then can't downcast `Float64` to `UInt64`.
+ //
+ // `arrow_cast(string, 'UInt64')` bypasses both: the string literal goes through
+ // `arrow_cast`'s own coercion, which is precision-preserving for UInt64.
+ //
+ // At 100K rows row IDs stay below 2^53 and both layers of the bug are invisible; at
+ // 1M+ rows they bite. Caught when the DataFrame benchmark hit 1M-row scale.
+ val rowIdLiterals = rowAddrs.iterator
+ .map(addr => s"arrow_cast('${java.lang.Long.toUnsignedString(addr)}', 'UInt64')")
+ .mkString(", ")
+ opts.filter(s"${LanceProbe.RowIdColumn} IN ($rowIdLiterals)")
+ javaFragmentIds.foreach(opts.fragmentIds)
+
+ val scanner: LanceScanner = LanceScanner.create(dataset, opts.build(), allocator)
+ try {
+ readRows(scanner.scanBatches())
+ } finally {
+ scanner.close()
+ }
+ }
+
+ private def readRows(reader: ArrowReader): Seq[Map[String, Any]] = {
+ val out = mutable.ArrayBuffer.empty[Map[String, Any]]
+ try {
+ while (reader.loadNextBatch()) {
+ val root: VectorSchemaRoot = reader.getVectorSchemaRoot
+ val n = root.getRowCount
+ var i = 0
+ while (i < n) {
+ val rowMap = mutable.LinkedHashMap.empty[String, Any]
+ val fields = root.getSchema.getFields.asScala
+ var f = 0
+ while (f < fields.size) {
+ val name = fields(f).getName
+ val v = root.getVector(name)
+ rowMap(name) = if (v.isNull(i)) null else LanceProbe.toSparkValue(v.getObject(i))
+ f += 1
+ }
+ out += rowMap.toMap
+ i += 1
+ }
+ }
+ } finally {
+ reader.close()
+ }
+ out.toSeq
+ }
+
+ override def close(): Unit = dataset.close()
+}
+
+object LanceProbe {
+
+ /**
+ * Lance row-identity virtual column name. We use `_rowid` rather than `_rowaddr` because
+ * Lance's INDEXED nearest-search path materializes `_rowid` but not `_rowaddr`, while
+ * non-indexed scans materialize both. `_rowid` therefore works on every code path that
+ * calls `probe()` (with or without a vector index built on the column). Sourced from
+ * `LanceConstant` to keep the literal defined in exactly one place.
+ */
+ val RowIdColumn: String = LanceConstant.ROW_ID
+
+ /**
+ * Candidate names for the score column in a Lance nearest-search result. Lance's vector indexes
+ * have used `_distance` historically; tolerate `_score` too in case future versions rename it.
+ * The lookup is name-based so the consumer is agnostic to where Lance puts the column in its
+ * output schema.
+ */
+ val ScoreColumns: Seq[String] = Seq("_distance", "_score")
+
+ /**
+ * Convert an Arrow-returned cell value into something Spark's encoders accept when stuffed
+ * into a `Row`. Arrow's `FieldVector.getObject` returns Java types (boxed primitives,
+ * `JsonStringArrayList` for list cells, `Text` for utf8) which Spark's `RowEncoder` does not
+ * always understand directly — most painfully, a `java.util.ArrayList` can't satisfy a Spark
+ * `ArrayType` slot, which expects a `scala.collection.Seq`.
+ *
+ * Conversion rules, in order:
+ * - `java.util.List` → recursively-converted `Seq`
+ * - `java.util.Map` → recursively-converted Scala `Map`
+ * - `org.apache.arrow.vector.util.Text` → `String`
+ * - `Number` boxed primitives → returned as-is (Spark handles them)
+ * - everything else → returned as-is (caller's responsibility)
+ *
+ * Recursive on lists/maps to handle nested types (arrays of structs, etc.) without surprises
+ * for callers.
+ */
+ def toSparkValue(value: Any): Any = value match {
+ case null => null
+ case list: java.util.List[_] =>
+ val out = scala.collection.mutable.ArrayBuffer.empty[Any]
+ val it = list.iterator
+ while (it.hasNext) out += toSparkValue(it.next())
+ out.toSeq
+ case map: java.util.Map[_, _] =>
+ val out = scala.collection.mutable.LinkedHashMap.empty[Any, Any]
+ val it = map.entrySet().iterator
+ while (it.hasNext) {
+ val e = it.next()
+ out(toSparkValue(e.getKey)) = toSparkValue(e.getValue)
+ }
+ out.toMap
+ case t: org.apache.arrow.vector.util.Text => t.toString
+ case other => other
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbeStage.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbeStage.scala
new file mode 100644
index 000000000..29e039629
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/LanceProbeStage.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.apache.spark.HashPartitioner
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+
+import scala.collection.mutable
+
+/**
+ * Probe stage of the indexed nearest-by pipeline. Per task, opens a Lance dataset once and runs
+ * `LanceProbe.probe(...)` per left row, restricted to the task's assigned fragments. Emits
+ * `(leftId, ProbedLeft)` keyed by `leftId` so the downstream Exchange can hash-shuffle by it.
+ *
+ * Map-side combine via [[TopKHeap]] only kicks in when a single task probes multiple fragments
+ * directly. With `fragmentIds = None` Lance does the cross-fragment merge internally and a
+ * single `probe` call already returns the right K — the per-row heap collapses to a passthrough.
+ * Splitting fragments across tasks via `runWithFragmentGroups` (Phase 1.5) gives each task a
+ * subset of the dataset's fragments and lets the downstream merge stage aggregate contributions.
+ *
+ * The stage materializes its output into an `ArrayBuffer` before closing the probe handle. This
+ * is the same closing-iterator pattern used in Phase 0 — necessary because Spark's iterator
+ * model lazily pulls from `mapPartitions`, which would otherwise let the consumer outlive the
+ * `try`/`finally`.
+ */
+object LanceProbeStage {
+
+ /**
+ * Driver-side configuration shipped to every probe task. Kept minimal so adding new probe knobs
+ * (filter pushdown, refine factor, etc.) stays local to this object.
+ */
+ final case class Conf(
+ datasetUri: String,
+ fragmentIds: Option[Seq[Int]],
+ vectorColumn: String,
+ version: Option[Long],
+ metric: Metric,
+ k: Int,
+ nprobes: Option[Int],
+ leftVecIdx: Int,
+ refineFactor: Option[Int] = None,
+ ef: Option[Int] = None,
+ prefilter: Option[String] = None)
+ extends Serializable
+
+ def run(leftKeyed: RDD[(Long, Row)], conf: Conf): RDD[(Long, ProbedLeft)] = {
+ leftKeyed.mapPartitions(iter => probePartition(iter, conf))
+ }
+
+ /**
+ * Phase 1.5 — fragment-grouped probe. Replicates each left row across `fragmentGroups.size`
+ * partitions, so each task probes a single group's fragments only and ALL left rows produce
+ * `fragmentGroups.size` contributions per `leftId`. Downstream `LanceMergeStage` then has real
+ * work to do — its `reduceByKey` aggregates across groups via `TopKHeap.merge`.
+ *
+ * Topology:
+ *
+ * {{{
+ * leftKeyed: RDD[(Long, Row)]
+ * -- flatMap (leftId, row) -> (groupIdx, (leftId, row)) replicate G times
+ * -- partitionBy(HashPartitioner(G)) one partition per group
+ * -- mapPartitionsWithIndex { (idx, iter) =>
+ * openLanceProbe(fragmentGroups(idx))
+ * probe each row }
+ * -- emits (leftId, ProbedLeft) G entries per leftId
+ * }}}
+ *
+ * The `flatMap` + `partitionBy` together form one shuffle. The merge stage's `reduceByKey`
+ * adds a second shuffle (since the output here is keyed by `leftId` but the partitioning is by
+ * `groupIdx`). Two shuffles is the cost of fragment-grouping.
+ *
+ * Empty groups (when `fragmentGroups.size > numFragments`) are skipped — the partition's
+ * iterator yields zero output. Callers don't need to special-case this.
+ *
+ * @param leftKeyed left rows keyed by stable `leftId`
+ * @param conf probe config; `fragmentIds` on the conf is IGNORED — this method
+ * overrides it per group
+ * @param fragmentGroups fragment ID assignment; one entry per group
+ */
+ def runWithFragmentGroups(
+ leftKeyed: RDD[(Long, Row)],
+ conf: Conf,
+ fragmentGroups: Seq[Seq[Int]]): RDD[(Long, ProbedLeft)] = {
+ require(fragmentGroups.nonEmpty, "fragmentGroups must not be empty")
+ val groupCount = fragmentGroups.size
+ val groupsBcast = leftKeyed.context.broadcast(fragmentGroups)
+
+ val replicated: RDD[(Int, (Long, Row))] = leftKeyed.flatMap {
+ case (leftId, leftRow) =>
+ (0 until groupCount).iterator.map(g => (g, (leftId, leftRow)))
+ }
+ val byGroup = replicated.partitionBy(new HashPartitioner(groupCount))
+
+ byGroup.mapPartitionsWithIndex(
+ { (partIdx, iter) =>
+ if (!iter.hasNext) Iterator.empty
+ else {
+ val groups = groupsBcast.value
+ // partIdx maps directly to groupIdx because HashPartitioner places key i in
+ // partition `i % groupCount`, and our keys are 0..groupCount-1.
+ val frags = groups(partIdx)
+ if (frags.isEmpty) Iterator.empty
+ else {
+ val groupConf = conf.copy(fragmentIds = Some(frags))
+ probePartition(iter.map(_._2), groupConf)
+ }
+ }
+ },
+ preservesPartitioning = false)
+ }
+
+ private def probePartition(
+ iter: Iterator[(Long, Row)],
+ conf: Conf): Iterator[(Long, ProbedLeft)] = {
+ if (iter.isEmpty) return Iterator.empty
+
+ val probe =
+ new LanceProbe(conf.datasetUri, conf.fragmentIds, conf.version)
+ val out = mutable.ArrayBuffer.empty[(Long, ProbedLeft)]
+ try {
+ iter.foreach { case (leftId, leftRow) =>
+ val q = extractVector(leftRow, conf.leftVecIdx)
+ val refs =
+ if (q == null) Array.empty[ScoredRowRef]
+ else probe.probe(
+ conf.vectorColumn,
+ q,
+ conf.k,
+ conf.metric,
+ conf.nprobes,
+ conf.refineFactor,
+ conf.ef,
+ conf.prefilter)
+ .toArray
+ out += ((leftId, ProbedLeft(leftRow, refs)))
+ }
+ } finally probe.close()
+ out.iterator
+ }
+
+ /**
+ * Pull a query vector out of a Spark `Row`'s ArrayType column. Mirrors the matching logic from
+ * Phase 0 — the Scala 2.13 `Seq` gotcha is real: `Row.get` on `ArrayType` returns
+ * `mutable.ArraySeq`, which `case s: Seq[_]` only matches against the root `scala.collection.Seq`
+ * trait (the default `Seq` alias is `immutable.Seq` on 2.13).
+ */
+ private[knn] def extractVector(row: Row, idx: Int): Array[Float] = {
+ if (row.isNullAt(idx)) return null
+ row.get(idx) match {
+ case s: scala.collection.Seq[_] =>
+ s.iterator.map {
+ case f: java.lang.Float => f.floatValue()
+ case f: Float => f
+ case d: java.lang.Double => d.doubleValue().toFloat
+ case d: Double => d.toFloat
+ case other =>
+ throw new IllegalStateException(
+ s"Unsupported vector element type: ${other.getClass.getName}")
+ }.toArray
+ case arr: Array[Float] => arr
+ case arr: Array[java.lang.Float] => arr.map(_.floatValue())
+ case other =>
+ throw new IllegalStateException(
+ s"Unsupported vector column representation: ${other.getClass.getName}")
+ }
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/Metric.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/Metric.scala
new file mode 100644
index 000000000..798bcedd4
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/Metric.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.lance.index.DistanceType
+
+/**
+ * Vector distance / similarity metric. Mirrors `org.lance.index.DistanceType` but exposed as a
+ * Scala enumeration so callers don't have to import Lance internals. Each metric fixes the
+ * "best-first" direction used during merge:
+ *
+ * - L2: smaller score is better (distance)
+ * - Cosine / Dot: larger score is better (similarity)
+ */
+sealed trait Metric {
+
+ /** The Lance distance type used when configuring a `Query`. */
+ def lanceType: DistanceType
+
+ /** True if smaller scores rank better (distance), false if larger (similarity). */
+ def smallerIsBetter: Boolean
+}
+
+object Metric {
+
+ case object L2 extends Metric {
+ val lanceType: DistanceType = DistanceType.L2
+ val smallerIsBetter: Boolean = true
+ }
+
+ case object Cosine extends Metric {
+ val lanceType: DistanceType = DistanceType.Cosine
+ val smallerIsBetter: Boolean = false
+ }
+
+ case object Dot extends Metric {
+ val lanceType: DistanceType = DistanceType.Dot
+ val smallerIsBetter: Boolean = false
+ }
+
+ /**
+ * Parse a metric name. Accepts the same set of names Lance accepts plus a few synonyms commonly
+ * used in Spark vector functions:
+ *
+ * - "l2" | "euclidean" → L2
+ * - "cosine" → Cosine
+ * - "dot" | "inner" | "ip" → Dot
+ */
+ def fromName(name: String): Metric = name.trim.toLowerCase match {
+ case "l2" | "euclidean" => L2
+ case "cosine" => Cosine
+ case "dot" | "inner" | "ip" => Dot
+ case other =>
+ throw new IllegalArgumentException(
+ s"Unknown metric '$other'. Expected one of: l2, cosine, dot.")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ProbedLeft.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ProbedLeft.scala
new file mode 100644
index 000000000..f350b63f3
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ProbedLeft.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.apache.spark.sql.Row
+
+/**
+ * Tuple shipped from the probe stage to the merge stage. Carries the left row alongside the
+ * top-K row references so the materialize stage can reconstruct join output without a separate
+ * join back to the left source.
+ *
+ * Phase 1 trades shuffle bandwidth for simplicity: shipping the left payload through the shuffle
+ * loses the "refs only ~24B per ref" bandwidth advantage that the IMPL_PLAN positions as the
+ * eventual win. Phase 2/3 plan: pre-partition the left RDD by `leftId` and ship only
+ * `(leftId, refs)` through the shuffle, then cogroup against the co-partitioned left payload at
+ * materialize time. Doing so requires either a stable user-supplied join key or a synthetic
+ * `leftId` carried alongside the payload — orthogonal to the staging refactor done here.
+ */
+final case class ProbedLeft(leftRow: Row, refs: Array[ScoredRowRef]) extends Serializable
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ScoredRowRef.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ScoredRowRef.scala
new file mode 100644
index 000000000..ac1b194fc
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/ScoredRowRef.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+/**
+ * A reference to a single right-side row produced by a vector index probe, along with the ranking
+ * score. Carries no payload — payloads are fetched in the materialize stage by row address. Pairing
+ * a tiny ref with a score is the unit of work passing through the shuffle and is what keeps the
+ * shuffle volume to `O(|L| × tasks × K × ~24B)` instead of `O(|L| × tasks × K × payload_bytes)`.
+ *
+ * @param rowAddr Lance row address (`_rowaddr`): packed `(frag_id << 32) | row_in_frag`. Stable
+ * within a Lance dataset version.
+ * @param score Distance or similarity returned by Lance's vector search. Smaller-is-better for
+ * distance metrics (L2), larger-is-better for similarity metrics (cosine/dot).
+ * Direction is carried out-of-band in the operator config; this struct stays metric-
+ * agnostic.
+ */
+final case class ScoredRowRef(rowAddr: Long, score: Float)
+
+object ScoredRowRef {
+
+ /** Order best-first for distance metrics (smallest score wins). */
+ val distanceOrdering: Ordering[ScoredRowRef] =
+ Ordering.by[ScoredRowRef, Float](_.score)
+
+ /** Order best-first for similarity metrics (largest score wins). */
+ val similarityOrdering: Ordering[ScoredRowRef] =
+ Ordering.by[ScoredRowRef, Float](-_.score)
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/TopKHeap.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/TopKHeap.scala
new file mode 100644
index 000000000..673179c8b
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/TopKHeap.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import scala.collection.mutable
+
+/**
+ * Bounded top-K heap with metric-aware ordering. Used by the probe stage for map-side combine
+ * across fragments owned by a single task — keeps the per-left-row state at exactly K entries no
+ * matter how many fragments contribute, and again on the reduce side to merge contributions from
+ * different tasks for the same `leftId`.
+ *
+ * Semantics:
+ * - `smallerIsBetter = true` (distance, e.g. L2): retain the K smallest-score entries.
+ * - `smallerIsBetter = false` (similarity, e.g. cosine): retain the K largest-score entries.
+ *
+ * Internally, the heap's *head* holds the worst surviving element so eviction is O(log K). Scala's
+ * `mutable.PriorityQueue` is a max-heap by the supplied `Ordering`, so the ordering is chosen to
+ * place "worst surviving" at the top:
+ * - distance → max-heap on `score` (largest score is worst)
+ * - similarity → max-heap on `-score` (smallest score is worst)
+ *
+ * Not thread-safe. Each left row in a probe stage gets its own heap.
+ */
+final class TopKHeap(k: Int, smallerIsBetter: Boolean) {
+ require(k > 0, "k must be positive")
+
+ private val ord: Ordering[ScoredRowRef] =
+ if (smallerIsBetter) Ordering.by[ScoredRowRef, Float](_.score)
+ else Ordering.by[ScoredRowRef, Float](-_.score)
+
+ private val heap = new mutable.PriorityQueue[ScoredRowRef]()(ord)
+
+ /**
+ * Insert `ref` if it would survive the top-K cut. Either grows the heap up to K or evicts the
+ * current worst-surviving element if `ref` is strictly better than it.
+ */
+ def offer(ref: ScoredRowRef): Unit = {
+ if (heap.size < k) {
+ heap.enqueue(ref)
+ } else {
+ val worst = heap.head
+ val isBetter =
+ if (smallerIsBetter) ref.score < worst.score
+ else ref.score > worst.score
+ if (isBetter) {
+ heap.dequeue()
+ heap.enqueue(ref)
+ }
+ }
+ }
+
+ def offerAll(refs: TraversableOnce[ScoredRowRef]): Unit = refs.foreach(offer)
+
+ /**
+ * Drain the heap into a best-first sorted Array. After this call the heap is empty. Best-first
+ * means index 0 is the top-ranked entry (smallest score for distance, largest for similarity).
+ */
+ def drain(): Array[ScoredRowRef] = {
+ val out = new Array[ScoredRowRef](heap.size)
+ var i = heap.size - 1
+ // PriorityQueue.dequeue returns the worst surviving element first; walking the array in
+ // reverse places best at index 0.
+ while (i >= 0) {
+ out(i) = heap.dequeue()
+ i -= 1
+ }
+ out
+ }
+
+ def size: Int = heap.size
+ def isEmpty: Boolean = heap.isEmpty
+}
+
+object TopKHeap {
+
+ /**
+ * Convenience: merge several already-sorted (best-first) ref arrays into one top-K array. Used
+ * by the merge stage as the `reduceByKey` combine function.
+ */
+ def merge(
+ a: Array[ScoredRowRef],
+ b: Array[ScoredRowRef],
+ k: Int,
+ smallerIsBetter: Boolean): Array[ScoredRowRef] = {
+ if (a.isEmpty) return takeBest(b, k, smallerIsBetter)
+ if (b.isEmpty) return takeBest(a, k, smallerIsBetter)
+ val heap = new TopKHeap(k, smallerIsBetter)
+ heap.offerAll(a)
+ heap.offerAll(b)
+ heap.drain()
+ }
+
+ private def takeBest(
+ arr: Array[ScoredRowRef],
+ k: Int,
+ smallerIsBetter: Boolean): Array[ScoredRowRef] = {
+ if (arr.length <= k) return arr
+ val heap = new TopKHeap(k, smallerIsBetter)
+ heap.offerAll(arr)
+ heap.drain()
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/LanceKnnStagedStrategy.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/LanceKnnStagedStrategy.scala
new file mode 100644
index 000000000..acdaef16f
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/LanceKnnStagedStrategy.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal.staged
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
+
+/**
+ * Maps the three staged-pipeline logical plans to their physical execs. Registered
+ * lazily on the active `SparkSession`'s `experimentalMethods.extraStrategies` the first
+ * time `IndexedNearestJoin.apply` runs against a session — see
+ * [[LanceKnnStagedStrategy.ensureRegistered]].
+ *
+ * The strategy is `object`-singleton (no per-call state) so registration can use
+ * reference-equality to avoid duplicate entries on repeated calls within the same session.
+ */
+private[knn] object LanceKnnStagedStrategy extends SparkStrategy {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case p: LanceProbeLogicalPlan =>
+ LanceProbeExec(
+ child = planLater(p.child),
+ stageConf = p.stageConf,
+ fragmentGroups = p.fragmentGroups,
+ leftSchema = p.leftSchema,
+ output = p.output) :: Nil
+
+ case p: LanceMergeLogicalPlan =>
+ LanceMergeExec(
+ child = planLater(p.child),
+ stageConf = p.stageConf,
+ leftSchema = p.leftSchema,
+ output = p.output) :: Nil
+
+ case p: LanceMaterializeLogicalPlan =>
+ LanceMaterializeExec(
+ child = planLater(p.child),
+ stageConf = p.stageConf,
+ leftSchema = p.leftSchema,
+ finalSchema = p.finalSchema,
+ output = p.output) :: Nil
+
+ case _ => Nil
+ }
+
+ /**
+ * Idempotently install this strategy on the session's planner. Called from
+ * `IndexedNearestJoin.apply` so users don't have to wire up Spark session extensions
+ * just to use the DataFrame API.
+ *
+ * `experimentalMethods.extraStrategies` is mutable but not thread-safe in its setter.
+ * Synchronising on a private monitor (the singleton object itself) keeps concurrent
+ * `IndexedNearestJoin.apply` calls from racing and double-installing. Idempotency uses
+ * reference equality on the strategy singleton — straightforward since the strategy is
+ * an `object`, not a `class`.
+ */
+ def ensureRegistered(spark: SparkSession): Unit = synchronized {
+ val em = spark.sessionState.experimentalMethods
+ val current = em.extraStrategies
+ if (!current.exists(_ eq this)) {
+ em.extraStrategies = current :+ this
+ }
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/ProbedLeftCodec.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/ProbedLeftCodec.scala
new file mode 100644
index 000000000..e8bafc0c2
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/ProbedLeftCodec.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal.staged
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{ArrayType, DataType, FloatType, LongType, StructField, StructType}
+import org.lance.spark.knn.internal.{ProbedLeft, ScoredRowRef}
+
+/**
+ * Catalyst-level codec for `(leftId: Long, ProbedLeft)` tuples flowing between the staged
+ * physical operators (`LanceProbeExec → LanceMergeExec → LanceMaterializeExec`). The single-
+ * exec wrapper used by Phase 2 keeps the tuple as a typed Scala value through `RDD[(Long,
+ * ProbedLeft)]`; once we split into three separate `SparkPlan` operators the inter-stage RDD
+ * has to be `RDD[InternalRow]`, so each boundary needs encode + decode.
+ *
+ * == Schema ==
+ *
+ * The inter-stage row carries:
+ *
+ * - `leftId: long` — synthetic per-row id used as the merge `reduceByKey` key
+ * - `leftRow: struct` — the original left DataFrame row, schema preserved so
+ * `LanceMaterializeStage` can reconstruct the join output without re-fetching the left
+ * side
+ * - `refs: array>` — the probe's top-K row references
+ * for this left row
+ *
+ * == Why Catalyst struct, not Kryo blob ==
+ *
+ * The microbench (`InterStagePayloadOverheadBench`) showed both encodings cost <1 % of total
+ * SQL benchmark wall-clock at every realistic scale, so the choice is code-aesthetic, not
+ * performance. Catalyst struct wins on aesthetics:
+ *
+ * - `df.explain()` shows the inter-stage operators with their schema readable
+ * (`[leftId#0L, leftRow#1, refs#2]`) instead of an opaque `[leftId, blob: binary]`.
+ * - Spark's shuffle path is native to Catalyst's UnsafeRow encoding — UnsafeRow shuffle
+ * write/read is heavily optimised. Kryo blobs would still ride that path but with an
+ * extra serialize-into-binary hop on top.
+ *
+ * == Encoder / decoder lifecycle ==
+ *
+ * `ExpressionEncoder(leftSchema).resolveAndBind()` is constructed driver-side and shipped to
+ * executors via task closure serialization. Each partition then calls `createSerializer()`
+ * / `createDeserializer()` once and reuses the resulting function across the partition's
+ * iterator — this matches Spark's standard encoder lifecycle.
+ */
+private[knn] object ProbedLeftCodec {
+
+ /** Schema of a single ref struct inside the `refs` array column. */
+ private val RefStructFields: Array[StructField] = Array(
+ StructField("rowAddr", LongType, nullable = false),
+ StructField("score", FloatType, nullable = false))
+ private val RefStruct: StructType = StructType(RefStructFields)
+ val RefsType: ArrayType = ArrayType(RefStruct, containsNull = false)
+
+ /**
+ * Inter-stage row schema parameterised by the left side's schema. The leftRow's fields
+ * are FLATTENED into the top-level schema (rather than nested as a sub-struct) — earlier
+ * iterations used a nested struct and triggered Spark's `UnsafeRowSerializer` + nested-
+ * struct reuse semantics into JVM-level SIGSEGV / unsafe-fault crashes inside the
+ * materialize stage's deserializer. Flattening sidesteps the nested-struct path entirely
+ * and keeps the binary layout to plain top-level fields plus one array-of-struct.
+ *
+ * Schema:
+ * - `_leftId: long` — synthetic per-row id
+ * - `` — every field of the user's left DataFrame, inlined
+ * - `_refs: array>`
+ *
+ * The leading underscore on `_leftId` and `_refs` keeps them out of the way of any
+ * reasonable user column name.
+ */
+ def interStageSchema(leftSchema: StructType): StructType = {
+ val leadField = StructField("_leftId", LongType, nullable = false)
+ val refsField = StructField("_refs", RefsType, nullable = false)
+ StructType(leadField +: leftSchema.fields :+ refsField)
+ }
+
+ /** Number of leading non-leftSchema columns at the top of `interStageSchema` (just `_leftId`). */
+ private val LeftIdColIndex: Int = 0
+ private val FirstLeftColIndex: Int = 1
+
+ /** The `_refs` column lives at the very end. */
+ private def refsColIndex(leftSchema: StructType): Int = 1 + leftSchema.length
+
+ /**
+ * Build the AttributeReference list that operators expose as `output`. Created once per
+ * plan tree and shared between probe and merge so attribute exprIds line up across the
+ * boundary — Catalyst attribute resolution rejects mid-tree exprId changes.
+ */
+ def interStageAttributes(leftSchema: StructType): Seq[AttributeReference] =
+ interStageSchema(leftSchema).fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+ }
+
+ /**
+ * Per-partition encoder/decoder. Construct once at the top of each task's `mapPartitions`
+ * closure and reuse for every row. The underlying serializer/deserializer functions are
+ * stateful (they cache writers/readers) so DO NOT share across partitions / threads.
+ *
+ * The codec uses a single `ExpressionEncoder(interStageSchema)` to handle the whole row
+ * end-to-end. Earlier iterations tried pre-serialising the `leftRow` to `UnsafeRow` and
+ * then re-wrapping in a `GenericInternalRow + UnsafeProjection` — that produced JVM-level
+ * unsafe-memory faults at the materialize stage's deserializer (the array length read
+ * from the nested struct's binary layout was corrupt). Letting Catalyst's encoder handle
+ * the entire `Row` → `InternalRow` conversion in one pass avoids the nesting-mismatch
+ * pitfall and produces UnsafeRow output that the shuffle path can serialize directly.
+ *
+ * Output is `UnsafeRow` (Spark's shuffle path requires it — `UnsafeRowSerializer` casts
+ * every shuffled row to `UnsafeRow`).
+ */
+ final class Encoder(leftSchema: StructType) extends Serializable {
+ @transient private lazy val outerSer =
+ ExpressionEncoder(interStageSchema(leftSchema)).resolveAndBind().createSerializer()
+ private val leftFieldCount: Int = leftSchema.length
+
+ def encode(leftId: Long, pl: ProbedLeft): InternalRow = {
+ // Flatten: outer row is `[leftId, leftField0, leftField1, ..., refs]`.
+ val cols = new Array[Any](2 + leftFieldCount)
+ cols(0) = java.lang.Long.valueOf(leftId)
+ var i = 0
+ while (i < leftFieldCount) {
+ cols(1 + i) = pl.leftRow.get(i)
+ i += 1
+ }
+ cols(1 + leftFieldCount) = pl.refs.map(r => Row(r.rowAddr, r.score)).toSeq
+ outerSer(Row.fromSeq(cols.toSeq)).copy()
+ }
+ }
+
+ /**
+ * Decoder reads fields directly from the input `InternalRow` without going through
+ * `ExpressionEncoder.Deserializer`. Earlier iterations did go through the deserializer
+ * and tripped a JIT-compiled SIGSEGV at `UnsafeRow.getArray` in generated `MapObjects`
+ * code on the `_refs` array column under sustained load (after JIT C2 had compiled the
+ * inner loop). The deserializer's generated code interacts poorly with UnsafeRow array
+ * accessors in some Spark 3.5 setups — known fragility.
+ *
+ * Reading via `InternalRow.get(i, dataType)` + `CatalystTypeConverters.createToScalaConverter`
+ * sidesteps the problematic generated code entirely. The Catalyst→Scala converters handle
+ * primitive unwrapping, `UnsafeArrayData` → `Seq`, etc., which is what the encoder's
+ * deserializer would have done — just via the direct converter API rather than codegen.
+ */
+ final class Decoder(leftSchema: StructType) extends Serializable {
+ private val leftFieldCount: Int = leftSchema.length
+ // Per-column converters: build once on driver, reused on executors. The `Any => Any`
+ // function is what `CatalystTypeConverters` exposes for converting raw InternalRow
+ // values into idiomatic Scala/Java values.
+ private val leftConverters: Array[Any => Any] = leftSchema.fields.map { f =>
+ CatalystTypeConverters.createToScalaConverter(f.dataType)
+ }
+ private val leftDataTypes: Array[DataType] = leftSchema.fields.map(_.dataType)
+ private val refsIdx: Int = refsColIndex(leftSchema)
+
+ def decode(ir: InternalRow): (Long, ProbedLeft) = {
+ val leftId = ir.getLong(LeftIdColIndex)
+
+ // Read each leftRow field directly from the InternalRow.
+ val leftValues = new Array[Any](leftFieldCount)
+ var i = 0
+ while (i < leftFieldCount) {
+ val colIdx = FirstLeftColIndex + i
+ if (ir.isNullAt(colIdx)) {
+ leftValues(i) = null
+ } else {
+ val raw = ir.get(colIdx, leftDataTypes(i))
+ leftValues(i) = leftConverters(i)(raw)
+ }
+ i += 1
+ }
+ val leftRow = Row.fromSeq(leftValues.toSeq)
+
+ // Refs array: ArrayData of struct.
+ val arr: ArrayData = ir.getArray(refsIdx)
+ val n = arr.numElements()
+ val refs = new Array[ScoredRowRef](n)
+ var r = 0
+ while (r < n) {
+ val refStruct = arr.getStruct(r, 2)
+ refs(r) = ScoredRowRef(refStruct.getLong(0), refStruct.getFloat(1))
+ r += 1
+ }
+
+ (leftId, ProbedLeft(leftRow, refs))
+ }
+ }
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedExecs.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedExecs.scala
new file mode 100644
index 000000000..877eea0c1
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedExecs.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal.staged
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.StructType
+import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, ProbedLeft, TopKHeap}
+
+import scala.collection.mutable
+
+/**
+ * Three physical operators corresponding to [[LanceProbeLogicalPlan]],
+ * [[LanceMergeLogicalPlan]], and [[LanceMaterializeLogicalPlan]]. Each `doExecute` decodes
+ * its child's `RDD[InternalRow]` to typed `(Long, ProbedLeft)` tuples, runs the existing
+ * staged-RDD pipeline op, and re-encodes back to `RDD[InternalRow]`.
+ *
+ * The decode → run → re-encode dance is what gives `df.explain()` an honest 3-stage shape.
+ * The microbench (`InterStagePayloadOverheadBench`) measured the per-row encoding cost at
+ * <1 % of total wall-clock at every realistic SQL benchmark scale, so the runtime cost is
+ * below the noise floor of repeated runs.
+ *
+ * == AQE compatibility ==
+ *
+ * `LanceMergeExec` declares `requiredChildDistribution = ClusteredDistribution(leftId)`, so
+ * Catalyst's `EnsureRequirements` rule inserts a `ShuffleExchangeExec` between the probe
+ * and merge execs. That exchange is what AQE wraps as a `ShuffleQueryStageExec` — and once
+ * AQE has the wrapper it can apply the usual rules (`CoalesceShufflePartitions`,
+ * `OptimizeSkewJoin`, `OptimizeShuffleWithLocalRead`) to the merge stage's shuffle.
+ *
+ * The merge exec itself does NO shuffle internally — `doExecute` is just a per-partition
+ * group-by-leftId aggregation, since after the exchange every leftId's contributions are
+ * co-located in one partition.
+ *
+ * Caveat: when `probeParallelism > 1`, the probe stage uses
+ * `LanceProbeStage.runWithFragmentGroups` which still does an internal RDD-level
+ * `partitionBy` shuffle (to replicate left rows across fragment groups). That shuffle
+ * remains AQE-invisible. Fixing it would need a different shape — replication is a
+ * `flatMap` plus a partitioning, which doesn't fit Catalyst's `requiredChildDistribution`
+ * model cleanly. Local-laptop benchmarks showed the fragment-grouped path doesn't help at
+ * single-machine scale anyway, so we leave it as-is for now.
+ */
+private[knn] case class LanceProbeExec(
+ override val child: SparkPlan,
+ stageConf: LanceProbeStage.Conf,
+ fragmentGroups: Option[Seq[Seq[Int]]],
+ leftSchema: StructType,
+ override val output: Seq[Attribute])
+ extends UnaryExecNode {
+
+ // The inter-stage attrs (leftId, leftRow, refs) appear in `output` but not in
+ // `child.output`. We synthesise them from per-row probe results — declare so Spark's
+ // `missingInput` check (and the `!` bang in tree-string output) doesn't flag this node.
+ override def producedAttributes: AttributeSet = AttributeSet(output) -- child.outputSet
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val childRdd = child.execute()
+ val schemaCaptured = leftSchema // capture for closure serialization
+ val confCaptured = stageConf
+ val groupsCaptured = fragmentGroups
+
+ // Decode user's left-side InternalRows into Rows (matches the shape the existing
+ // LanceProbeStage takes). copy() because Spark may reuse the InternalRow buffer
+ // across iterations of the upstream operator.
+ val leftEnc = ExpressionEncoder(schemaCaptured).resolveAndBind()
+ val rowLeftRdd: RDD[Row] = childRdd.mapPartitions { iter =>
+ val deser = leftEnc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+ val leftKeyed: RDD[(Long, Row)] =
+ rowLeftRdd.zipWithUniqueId().map { case (row, id) => (id, row) }
+
+ val probed: RDD[(Long, ProbedLeft)] = groupsCaptured match {
+ case Some(groups) => LanceProbeStage.runWithFragmentGroups(leftKeyed, confCaptured, groups)
+ case None => LanceProbeStage.run(leftKeyed, confCaptured)
+ }
+
+ probed.mapPartitions { iter =>
+ val enc = new ProbedLeftCodec.Encoder(schemaCaptured)
+ iter.map { case (lid, pl) => enc.encode(lid, pl) }
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): LanceProbeExec =
+ copy(child = newChild)
+}
+
+private[knn] case class LanceMergeExec(
+ override val child: SparkPlan,
+ stageConf: LanceMergeStage.Conf,
+ leftSchema: StructType,
+ override val output: Seq[Attribute])
+ extends UnaryExecNode {
+
+ /**
+ * Force a hash-partitioned shuffle on `leftId` between probe and merge. Without this,
+ * the merge would have to do its own RDD-level shuffle (the original design) which is
+ * invisible to AQE. With `ClusteredDistribution(leftId)` declared, Catalyst's
+ * `EnsureRequirements` rule inserts a `ShuffleExchangeExec` automatically; AQE wraps
+ * that exchange as a `ShuffleQueryStageExec` and can coalesce / re-balance / etc.
+ *
+ * `leftId` is always the first column of the inter-stage schema. Pulling it from
+ * `child.output.head` matches what `ProbedLeftCodec.interStageSchema` produces.
+ */
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(Seq(child.output.head)) :: Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val childRdd = child.execute()
+ val schemaCaptured = leftSchema
+ val finalK = stageConf.finalK
+ val smallerIsBetter = stageConf.smallerIsBetter
+
+ // After the upstream `ShuffleExchangeExec`, all rows with the same leftId are co-located
+ // on the same partition. We just group within partition and apply `TopKHeap.merge`.
+ // No reduceByKey needed.
+ childRdd.mapPartitions { iter =>
+ val dec = new ProbedLeftCodec.Decoder(schemaCaptured)
+ val enc = new ProbedLeftCodec.Encoder(schemaCaptured)
+ val byLid = mutable.LinkedHashMap.empty[Long, ProbedLeft]
+
+ while (iter.hasNext) {
+ val ir = iter.next().copy()
+ val (lid, pl) = dec.decode(ir)
+ byLid.get(lid) match {
+ case Some(prev) =>
+ val mergedRefs = TopKHeap.merge(prev.refs, pl.refs, finalK, smallerIsBetter)
+ byLid(lid) = ProbedLeft(prev.leftRow, mergedRefs)
+ case None =>
+ // First contribution for this lid. Trim if it already overflows finalK
+ // (probe stage may emit `internalK = k * overfetch` per ref array).
+ val refs =
+ if (pl.refs.length <= finalK) pl.refs
+ else {
+ val heap = new TopKHeap(finalK, smallerIsBetter)
+ heap.offerAll(pl.refs)
+ heap.drain()
+ }
+ byLid(lid) = ProbedLeft(pl.leftRow, refs)
+ }
+ }
+
+ byLid.iterator.map { case (lid, pl) => enc.encode(lid, pl) }
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): LanceMergeExec =
+ copy(child = newChild)
+}
+
+private[knn] case class LanceMaterializeExec(
+ override val child: SparkPlan,
+ stageConf: LanceMaterializeStage.Conf,
+ leftSchema: StructType,
+ finalSchema: StructType,
+ override val output: Seq[Attribute])
+ extends UnaryExecNode {
+
+ // Final-schema attrs (left.* ++ right.* ++ score) are synthesised here, not in any child.
+ override def producedAttributes: AttributeSet = AttributeSet(output) -- child.outputSet
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val childRdd = child.execute()
+ val leftSchemaCaptured = leftSchema
+ val finalSchemaCaptured = finalSchema
+ val confCaptured = stageConf
+
+ val keyed: RDD[(Long, ProbedLeft)] = childRdd.mapPartitions { iter =>
+ val dec = new ProbedLeftCodec.Decoder(leftSchemaCaptured)
+ iter.map(ir => dec.decode(ir.copy()))
+ }
+
+ val joinedRows: RDD[Row] = LanceMaterializeStage.run(keyed, confCaptured)
+
+ val finalEnc = ExpressionEncoder(finalSchemaCaptured).resolveAndBind()
+ joinedRows.mapPartitions { iter =>
+ val ser = finalEnc.createSerializer()
+ iter.map(row => ser(row).copy())
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): LanceMaterializeExec =
+ copy(child = newChild)
+}
diff --git a/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedPlans.scala b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedPlans.scala
new file mode 100644
index 000000000..b655aa3ac
--- /dev/null
+++ b/lance-spark-knn_2.12/src/main/scala/org/lance/spark/knn/internal/staged/StagedPlans.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal.staged
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage}
+
+/**
+ * Three logical plan nodes that both the DataFrame API path (`IndexedNearestJoin.apply`)
+ * and the SQL path (`IndexedNearestByJoinRule` in `lance-spark-knn-4.2_2.13`) build, so
+ * each stage of the staged pipeline shows up as its own operator in `df.explain()`. They
+ * are deliberately Lance-specific (the matching `Strategy` only knows how to lower these
+ * three). Both user-facing paths emit this exact tree and share `LanceKnnStagedStrategy`
+ * for the lowering.
+ *
+ * == Why they are not children of each other in the obvious "tree" way ==
+ *
+ * Probe and Merge nodes share the inter-stage attribute references (created once via
+ * `ProbedLeftCodec.interStageAttributes(leftSchema)`) so Catalyst's attribute resolution
+ * does not fight us — a parent that references attributes of a child requires those
+ * attributes to be reachable from the child's `outputSet`. By keeping the same
+ * `Seq[AttributeReference]` instances on both probe.output and merge.output we sidestep
+ * having to copy or rewrite expr-ids across the inter-stage boundary.
+ *
+ * == producedAttributes ==
+ *
+ * For each node, `producedAttributes` is `output -- child.outputSet` — Catalyst convention
+ * for "attributes this node introduces that don't come from its child." Probe introduces
+ * the inter-stage triple from the user-shaped left input. Merge passes through. Materialize
+ * introduces the final join output (which doesn't exist in any child) so all of its output
+ * is produced.
+ */
+private[knn] case class LanceProbeLogicalPlan(
+ override val child: LogicalPlan,
+ stageConf: LanceProbeStage.Conf,
+ fragmentGroups: Option[Seq[Seq[Int]]],
+ leftSchema: StructType,
+ interStageOutput: Seq[Attribute])
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = interStageOutput
+
+ override def producedAttributes: AttributeSet = AttributeSet(output) -- child.outputSet
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LanceProbeLogicalPlan =
+ copy(child = newChild)
+}
+
+private[knn] case class LanceMergeLogicalPlan(
+ override val child: LogicalPlan,
+ stageConf: LanceMergeStage.Conf,
+ leftSchema: StructType,
+ interStageOutput: Seq[Attribute])
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = interStageOutput
+
+ // Pass-through schema: same attrs as child. AttributeSet subtraction on identical
+ // attribute references yields the empty set — merge produces nothing new.
+ override def producedAttributes: AttributeSet = AttributeSet(output) -- child.outputSet
+
+ // Every child attribute is load-bearing — the matching `LanceMergeExec.doExecute`
+ // decodes the full inter-stage row (leftId + leftRow + refs) on every input. Without
+ // this override, Catalyst's `ColumnPruning` sees downstream consumers that reference
+ // nothing (`count(*)`, `Aggregate`, etc.) and wraps this node in `Project(Nil)`; that
+ // project codegens to 0-field UnsafeRows which then crash `ProbedLeftCodec.Decoder`
+ // at `ir.getLong(0)` (AssertionError under interpreter/C1, SIGSEGV under C2 JIT).
+ // This pattern was initially misdiagnosed as a JVM-aarch64 bug; the actual cause
+ // is this Catalyst rule. See IMPL_PLAN.md "3-exec staged split — root cause and fix".
+ override lazy val references: AttributeSet = child.outputSet
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LanceMergeLogicalPlan =
+ copy(child = newChild)
+}
+
+private[knn] case class LanceMaterializeLogicalPlan(
+ override val child: LogicalPlan,
+ stageConf: LanceMaterializeStage.Conf,
+ leftSchema: StructType,
+ finalSchema: StructType,
+ finalOutput: Seq[Attribute])
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = finalOutput
+
+ // Final schema attrs do not appear in the inter-stage child, so all of them are produced.
+ override def producedAttributes: AttributeSet = AttributeSet(output) -- child.outputSet
+
+ // Same rationale as `LanceMergeLogicalPlan.references`: the matching exec decodes every
+ // child attribute to rebuild the `ProbedLeft` tuple, so nothing in the child can be
+ // pruned.
+ override lazy val references: AttributeSet = child.outputSet
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): LanceMaterializeLogicalPlan =
+ copy(child = newChild)
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinAqeVisibilityTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinAqeVisibilityTest.scala
new file mode 100644
index 000000000..96cf98118
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinAqeVisibilityTest.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{DataFrame, RowFactory, SparkSession}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Proves that the 3-exec staged pipeline (LanceProbeExec → ShuffleExchangeExec →
+ * LanceMergeExec → LanceMaterializeExec) is Catalyst-visible end-to-end, so AQE
+ * engages on the merge-side shuffle.
+ *
+ * The `LanceMergeExec.requiredChildDistribution = ClusteredDistribution(leftId)` on the
+ * physical exec is what makes `EnsureRequirements` insert a `ShuffleExchangeExec` between
+ * probe and merge. That exchange is the AQE wrap point.
+ *
+ * Unlike the previous `InterStageShuffle` (`repartition`-based) approach — which called
+ * `.rdd` on an intermediate DataFrame, collapsing the Catalyst plan to an RDD before the
+ * final join wrapped it — here the three execs live in ONE `joined.queryExecution.executedPlan`
+ * tree. The Exchange is directly inspectable in that tree and AQE's `AdaptiveSparkPlanExec`
+ * wraps the whole thing when AQE is enabled.
+ */
+class IndexedNearestJoinAqeVisibilityTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ private def newSparkSession(aqeEnabled: Boolean): SparkSession = {
+ spark = SparkSession.builder()
+ .appName(s"aqe-visibility-aqe-$aqeEnabled")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.sql.adaptive.enabled", aqeEnabled.toString)
+ .config("spark.sql.shuffle.partitions", "4")
+ .getOrCreate()
+ spark
+ }
+
+ /**
+ * The joined DataFrame's executed plan contains a `ShuffleExchangeExec` with
+ * `hashpartitioning` on the `_leftId` attribute — inserted by `EnsureRequirements` in
+ * response to `LanceMergeExec.requiredChildDistribution = ClusteredDistribution`.
+ * This is the shape-level proof the exchange is Catalyst-planned.
+ */
+ @Test def testExchangeHashpartitioningOnLeftIdWithAqeEnabled(): Unit = {
+ val s = newSparkSession(aqeEnabled = true)
+ val joined = buildJoined(s)
+ joined.collect()
+ assertExchangeOnLeftId(joined)
+ }
+
+ @Test def testExchangeHashpartitioningOnLeftIdWithAqeDisabled(): Unit = {
+ val s = newSparkSession(aqeEnabled = false)
+ val joined = buildJoined(s)
+ joined.collect()
+ assertExchangeOnLeftId(joined)
+ }
+
+ /**
+ * With AQE enabled, the executed plan is wrapped in an `AdaptiveSparkPlanExec` — AQE's
+ * entry point. This is the end-to-end proof AQE engaged on the full tree (not just an
+ * inner sub-plan like the previous `InterStageShuffle` approach produced).
+ */
+ @Test def testAqeAdaptivePlanWrapsEntireExecution(): Unit = {
+ val s = newSparkSession(aqeEnabled = true)
+ val joined = buildJoined(s)
+ joined.collect()
+ val plan = joined.queryExecution.executedPlan
+ val aqeNodes = collectAqe(plan)
+ assertTrue(
+ aqeNodes.nonEmpty,
+ s"Expected AdaptiveSparkPlanExec at the top of executed plan; got:\n${plan.treeString}")
+ }
+
+ /**
+ * `LanceProbe`, `LanceMerge`, and `LanceMaterialize` are all named in the executed
+ * plan's tree string. Confirms the three custom execs are actually wired in.
+ */
+ @Test def testAllThreeCustomExecsInTree(): Unit = {
+ val s = newSparkSession(aqeEnabled = true)
+ val joined = buildJoined(s)
+ joined.collect()
+ val tree = joined.queryExecution.executedPlan.treeString
+ assertTrue(tree.contains("LanceProbe"), s"Expected LanceProbe in executed plan; got:\n$tree")
+ assertTrue(tree.contains("LanceMerge"), s"Expected LanceMerge in executed plan; got:\n$tree")
+ assertTrue(
+ tree.contains("LanceMaterialize"),
+ s"Expected LanceMaterialize in executed plan; got:\n$tree")
+ }
+
+ /**
+ * Regression for the `missingInput` bug caught during the initial 3-exec split: if
+ * `producedAttributes` is not set, Spark's tree-string prefixes each custom node with
+ * `!` to flag "this node references attrs not in child.outputSet". A clean tree string
+ * has no `!` prefix. Asserts we haven't regressed.
+ */
+ @Test def testNoMissingInputBangPrefix(): Unit = {
+ val s = newSparkSession(aqeEnabled = false)
+ val joined = buildJoined(s)
+ joined.collect()
+ val tree = joined.queryExecution.executedPlan.treeString
+ // Every line that starts with optional whitespace + "!" is a missingInput warning.
+ // Allow `!` elsewhere (e.g. inside parenthesized text). Check for the known pattern:
+ // "+- !LanceProbe" or "!LanceMerge" etc.
+ assertFalse(
+ tree.contains("!LanceProbe") || tree.contains("!LanceMerge") ||
+ tree.contains("!LanceMaterialize"),
+ s"Found `!` missingInput prefix on a custom exec; got:\n$tree")
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def assertExchangeOnLeftId(joined: DataFrame): Unit = {
+ val plan = joined.queryExecution.executedPlan
+ val treeString = plan.treeString
+ // Use string-match on the tree. With AQE enabled, the Exchange can be nested inside
+ // `ShuffleQueryStageExec` / `AQEShuffleRead` wrappers whose `children` relationship
+ // isn't a plain SparkPlan child — walking via `.children` misses them. The string
+ // form is stable across AQE on/off and is what `df.explain()` prints, so matching
+ // the same text the user would see is both simpler and correct.
+ assertTrue(
+ treeString.contains("Exchange hashpartitioning(_leftId") ||
+ treeString.contains("hashpartitioning(_leftId"),
+ s"Expected Exchange hashpartitioning on _leftId; executedPlan:\n$treeString")
+ }
+
+ private def collectAqe(plan: SparkPlan): Seq[AdaptiveSparkPlanExec] = {
+ val hits = scala.collection.mutable.ArrayBuffer.empty[AdaptiveSparkPlanExec]
+ def walk(p: SparkPlan): Unit = {
+ p match {
+ case aqe: AdaptiveSparkPlanExec =>
+ hits += aqe
+ walk(aqe.executedPlan)
+ case _ =>
+ }
+ p.children.foreach(walk)
+ }
+ walk(plan)
+ hits.toSeq
+ }
+
+ private def buildJoined(s: SparkSession): DataFrame = {
+ val rng = new Random(17L)
+ val leftDf = buildLeft(s, rng, n = 8, dim = Dim)
+ val rightUri = writeRight(s, rng, n = 16, dim = Dim)
+ IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 2,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ }
+
+ private def buildLeft(s: SparkSession, rng: Random, n: Int, dim: Int) = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i), randomVector(rng, dim))
+ }
+ s.createDataFrame(rows.asJava, schema)
+ }
+
+ private def writeRight(s: SparkSession, rng: Random, n: Int, dim: Int): String = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000), randomVector(rng, dim))
+ }
+ val df = s.createDataFrame(rows.asJava, schema)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinConsumerShapeTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinConsumerShapeTest.scala
new file mode 100644
index 000000000..b31c2ad12
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinConsumerShapeTest.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{DataFrame, RowFactory, SparkSession}
+import org.apache.spark.sql.functions.{count, lit}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Covers the exact consumer shapes that crashed an early 3-exec staged plan iteration
+ * during development: `count(*)`, `Aggregate`, and operators that reference NONE of
+ * the join output columns. Those shapes drove Catalyst's `ColumnPruning` to insert
+ * `Project(Nil)` wrappers between the custom staged operators, which codegen to 0-field
+ * `UnsafeRow`s and crashed the custom decoder (`AssertionError` / SIGSEGV under C2).
+ *
+ * `InterStageShuffle.mergeViaCatalystShuffle` sidesteps that entirely — no custom
+ * `LogicalPlan` / `SparkPlan` exists for `ColumnPruning` to wrap — but these tests
+ * confirm the property rather than relying on reasoning.
+ */
+class IndexedNearestJoinConsumerShapeTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+ private val NumRight = 32
+ private val NumLeft = 8
+ private val K = 3
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-join-consumer-shape")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /** `df.count()` — the simplest case that crashed the reverted path. */
+ @Test def testCountSucceeds(): Unit = {
+ val joined = buildJoined()
+ val n = joined.count()
+ assertEquals((NumLeft * K).toLong, n, s"Expected ${NumLeft * K} rows; got $n")
+ }
+
+ /** `df.agg(count("*"))` — same conceptual shape, different entry point. */
+ @Test def testAggCountSucceeds(): Unit = {
+ val joined = buildJoined()
+ val result = joined.agg(count("*")).collect()
+ assertEquals(1, result.length)
+ assertEquals((NumLeft * K).toLong, result.head.getLong(0))
+ }
+
+ /**
+ * `df.select(lit(1))` — references NONE of the join's output columns. Under
+ * `ColumnPruning` this is the strongest form of "prune everything from my child";
+ * if any custom plan node were going to get wrapped in `Project(Nil)`, this is where.
+ */
+ @Test def testSelectLiteralSucceeds(): Unit = {
+ val joined = buildJoined()
+ val result = joined.select(lit(1).as("one")).collect()
+ assertEquals(NumLeft * K, result.length)
+ assertTrue(result.forall(_.getInt(0) == 1))
+ }
+
+ /** `df.collect()` — the normal case, asserts count and that real data is materialised. */
+ @Test def testCollectMaterialisesAllColumns(): Unit = {
+ val joined = buildJoined()
+ val rows = joined.collect()
+ assertEquals(NumLeft * K, rows.length)
+ // Assert nothing is empty/corrupt — every row should have non-null lid, qvec, rid,
+ // rvec, __score at the column positions the join output schema defines.
+ rows.foreach { row =>
+ assertNotNull(row.get(0), "lid (col 0) should be non-null")
+ assertNotNull(row.get(1), "qvec (col 1) should be non-null")
+ assertNotNull(row.get(2), "rid (col 2) should be non-null")
+ }
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildJoined(): DataFrame = {
+ val rng = new Random(23L)
+ val leftDf = buildLeft(rng)
+ val rightUri = writeRight(rng)
+ IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "qvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2")
+ }
+
+ private def buildLeft(rng: Random) = {
+ val schema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = (0 until NumLeft).map { i =>
+ RowFactory.create(java.lang.Long.valueOf(i.toLong), randomVector(rng, Dim).toSeq.asJava)
+ }
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ private def writeRight(rng: Random): String = {
+ val schema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = (0 until NumRight).map { i =>
+ RowFactory.create(
+ java.lang.Long.valueOf((i + 1000).toLong),
+ randomVector(rng, Dim).toSeq.asJava)
+ }
+ val df = spark.createDataFrame(rows.asJava, schema)
+ val uri = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(uri)
+ uri
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinCorrectnessTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinCorrectnessTest.scala
new file mode 100644
index 000000000..96ff6dee8
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinCorrectnessTest.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * End-to-end correctness regression for the `InterStageShuffle.mergeViaCatalystShuffle`
+ * path. Builds a right side without any vector index — Lance then falls back to an exact
+ * per-fragment scan, which makes the join a recall = 1.0 oracle: the top-K refs per
+ * left row MUST equal the brute-force Scala top-K computed on the driver.
+ *
+ * This is the strongest correctness check available short of setting up an actual
+ * vector index. If the `repartition(col(_leftId))` → per-partition merge path dropped,
+ * reordered, or corrupted rows in any way, this would detect it.
+ */
+class IndexedNearestJoinCorrectnessTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 16
+ private val NumRight = 1000
+ private val NumLeft = 100
+ private val K = 10
+ private val Seed = 31L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-join-correctness")
+ .master("local[4]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.sql.shuffle.partitions", "8")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /** Every left row's top-K must match the brute-force Scala oracle exactly. */
+ @Test def testTopKMatchesBruteForceOracle(): Unit = {
+ val rng = new Random(Seed)
+ val (rightRows, rightVecs) = generateRows(rng, NumRight, Dim, idOffset = 1000)
+ val (leftRows, leftVecs) = generateRows(rng, NumLeft, Dim, idOffset = 0)
+
+ val rightUri = writeLance(rightRows, "rid", "rvec")
+ val leftDf = buildDf(leftRows, "lid", "qvec")
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "qvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid"))).collect()
+
+ // Build expected map: leftLid → sorted list of rid ASCENDING by distance.
+ val expected: Map[Long, Seq[Long]] = leftVecs.zipWithIndex.map { case (qvec, leftIdx) =>
+ val dists = rightVecs.zipWithIndex.map { case (rvec, rightIdx) =>
+ ((rightIdx + 1000).toLong, l2DistanceSquared(qvec, rvec))
+ }
+ val topK = dists.sortBy(_._2).take(K).map(_._1)
+ leftIdx.toLong -> topK
+ }.toMap
+
+ // Group actual rows by lid, sort by __score ASC, extract rid sequence.
+ val actualByLid: Map[Long, Seq[Long]] = joined
+ .groupBy(_.getLong(0))
+ .map { case (lid, rows) =>
+ lid -> rows.toSeq.sortBy(_.getFloat(3)).map(_.getLong(2))
+ }
+
+ assertEquals(NumLeft, actualByLid.size, s"Expected $NumLeft distinct leftIds in output")
+
+ expected.foreach { case (lid, expectedRids) =>
+ val actualRids = actualByLid.getOrElse(lid, Seq.empty)
+ assertEquals(
+ expectedRids,
+ actualRids,
+ s"Top-$K rids mismatch for lid=$lid:\n expected=$expectedRids\n actual=$actualRids")
+ }
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def generateRows(
+ rng: Random,
+ n: Int,
+ dim: Int,
+ idOffset: Int): (Seq[Row], Seq[Array[Float]]) = {
+ val vecs = (0 until n).map(_ => randomVector(rng, dim))
+ val rows = vecs.zipWithIndex.map { case (v, i) =>
+ RowFactory.create(java.lang.Long.valueOf((i + idOffset).toLong), v.toSeq.asJava)
+ }
+ (rows, vecs)
+ }
+
+ private def writeLance(rows: Seq[Row], idCol: String, vecCol: String): String = {
+ val schema = new StructType(Array(
+ StructField(idCol, LongType, nullable = false),
+ StructField(
+ vecCol,
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val df = spark.createDataFrame(rows.asJava, schema)
+ val uri = tempDir.resolve(s"ds_${System.nanoTime()}").toString
+ df.write.format("lance").save(uri)
+ uri
+ }
+
+ private def buildDf(rows: Seq[Row], idCol: String, vecCol: String) = {
+ val schema = new StructType(Array(
+ StructField(idCol, LongType, nullable = false),
+ StructField(
+ vecCol,
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2DistanceSquared(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) {
+ val d = a(i) - b(i)
+ s += d * d
+ i += 1
+ }
+ s
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinFragmentGroupingTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinFragmentGroupingTest.scala
new file mode 100644
index 000000000..33fff1cf3
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinFragmentGroupingTest.scala
@@ -0,0 +1,282 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Phase 1.5 — fragment-grouped probing.
+ *
+ * The substantive change vs. Phase 1: with `probeParallelism > 1`, the rule splits Lance
+ * fragments into N groups, replicates each left row across the groups, and lets the merge stage
+ * aggregate N contributions per leftId via [[org.lance.spark.knn.internal.TopKHeap]].
+ *
+ * Coverage:
+ * - Oracle equivalence: top-K matches the brute-force oracle when probeParallelism = N. With
+ * no vector index, every per-fragment-group probe is exact (recall = 1.0), so the merged
+ * result must match exact brute force.
+ * - Plan-shape: the lineage contains TWO `ShuffledRDD`s — one from the replicate-and-
+ * partition-by-group step (probe), one from `reduceByKey` (merge). Phase 1 had only one.
+ * - Falls back gracefully when probeParallelism > numFragments (extra groups are empty).
+ */
+class IndexedNearestJoinFragmentGroupingTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+ private val NumRight = 64
+ private val NumLeft = 16
+ private val Seed = 31L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-fragment-grouping-test")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * Top-K matches the brute-force oracle when `probeParallelism = 4` and the right dataset has
+ * at least 4 fragments. Confirms the merge function correctly combines per-fragment-group
+ * contributions.
+ */
+ @Test def testOracleEquivalenceWithFragmentGrouping(): Unit = {
+ val rng = new Random(Seed)
+ val leftDf = buildLeft(rng, NumLeft, Dim)
+ val rightUri = writeRight(rng, NumRight, Dim, fragments = 4)
+
+ val k = 5
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = k,
+ metric = "l2",
+ rightProjection = Some(Seq("rid", "rvec")),
+ probeParallelism = 4)
+
+ val result = joined.collect()
+ assertEquals(NumLeft * k, result.length, "expected k results per left row")
+
+ val byLid = result.groupBy(_.getAs[Int]("lid"))
+ val rightVecs = readRightVectors(rightUri)
+ val rightIds = readRightIds(rightUri)
+
+ byLid.foreach { case (lid, rows) =>
+ val sortedRows = rows.sortBy(_.getAs[Float]("__score"))
+ val leftVec = leftVectorFor(leftDf, lid)
+ val oracle = rightVecs.indices
+ .map(idx => (rightIds(idx), l2(leftVec, rightVecs(idx))))
+ .sortBy(_._2)
+ .take(k)
+
+ assertEquals(
+ oracle.map(_._1).toSet,
+ sortedRows.map(_.getAs[Int]("rid")).toSet,
+ s"top-K right ids for lid=$lid mismatch oracle (probeParallelism = 4)")
+
+ oracle.map(_._2).zip(sortedRows.map(_.getAs[Float]("__score"))).foreach {
+ case (expected, actual) =>
+ assertEquals(expected, actual, 1e-4f, s"score mismatch for lid=$lid")
+ }
+ }
+ }
+
+ /**
+ * Plan-shape: with `probeParallelism > 1` the lineage gains a second `ShuffledRDD` (the
+ * replicate-and-partition-by-group step). Phase 1's degenerate single-task path has only the
+ * merge-side shuffle.
+ */
+ @Test def testFragmentGroupingAddsExtraShuffleToLineage(): Unit = {
+ val rng = new Random(Seed + 1)
+ val leftDf = buildLeft(rng, 4, Dim)
+ val rightUri = writeRight(rng, 16, Dim, fragments = 2)
+
+ val phase1 = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 2,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 1)
+ val phase15 = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 2,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 2)
+
+ val phase1Lineage = phase1.rdd.toDebugString
+ val phase15Lineage = phase15.rdd.toDebugString
+
+ val countShuffles = (s: String) => "ShuffledRDD".r.findAllIn(s).length
+ val phase1Shuffles = countShuffles(phase1Lineage)
+ val phase15Shuffles = countShuffles(phase15Lineage)
+ assertTrue(
+ phase15Shuffles > phase1Shuffles,
+ s"Phase 1.5 should add at least one more ShuffledRDD; phase1=$phase1Shuffles, " +
+ s"phase15=$phase15Shuffles\nlineage:\n$phase15Lineage")
+ }
+
+ /**
+ * Phase 3 — skew handling. With `balanceFragmentsByRowCount = true`, fragment groups are
+ * balanced via LPT bin-packing on per-fragment row counts. The oracle equivalence test still
+ * holds — the ordering of fragments within a group doesn't change top-K results, only the
+ * load distribution.
+ */
+ @Test def testOracleEquivalenceWithRowCountBalancing(): Unit = {
+ val rng = new Random(Seed + 3)
+ val leftDf = buildLeft(rng, NumLeft, Dim)
+ val rightUri = writeRight(rng, NumRight, Dim, fragments = 4)
+
+ val k = 5
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = k,
+ metric = "l2",
+ rightProjection = Some(Seq("rid", "rvec")),
+ probeParallelism = 4,
+ balanceFragmentsByRowCount = true)
+
+ val result = joined.collect()
+ assertEquals(NumLeft * k, result.length)
+
+ val byLid = result.groupBy(_.getAs[Int]("lid"))
+ val rightVecs = readRightVectors(rightUri)
+ val rightIds = readRightIds(rightUri)
+ byLid.foreach { case (lid, rows) =>
+ val sorted = rows.sortBy(_.getAs[Float]("__score"))
+ val leftVec = leftVectorFor(leftDf, lid)
+ val oracle = rightVecs.indices
+ .map(idx => (rightIds(idx), l2(leftVec, rightVecs(idx))))
+ .sortBy(_._2)
+ .take(k)
+ assertEquals(
+ oracle.map(_._1).toSet,
+ sorted.map(_.getAs[Int]("rid")).toSet,
+ s"top-K mismatch with balanceFragmentsByRowCount = true (lid=$lid)")
+ }
+ }
+
+ /**
+ * `probeParallelism` > num fragments → extra groups are empty → result is still correct.
+ * Specifically, the rule degenerates to the Phase 1 path when only one non-empty group exists.
+ */
+ @Test def testProbeParallelismExceedingFragmentsStillCorrect(): Unit = {
+ val rng = new Random(Seed + 2)
+ val leftDf = buildLeft(rng, 4, Dim)
+ // Dataset with a single fragment — probeParallelism = 8 must still produce correct results.
+ val rightUri = writeRight(rng, 16, Dim, fragments = 1)
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ probeParallelism = 8)
+ assertEquals(4 * 3, joined.collect().length)
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildLeft(rng: Random, n: Int, dim: Int) = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i), randomVector(rng, dim))
+ }
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ /**
+ * Write the right dataset, repartitioning the source DataFrame into `fragments` partitions
+ * before save so that Lance produces approximately one fragment per Spark partition.
+ */
+ private def writeRight(rng: Random, n: Int, dim: Int, fragments: Int): String = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000), randomVector(rng, dim))
+ }
+ val df = spark.createDataFrame(rows.asJava, schema).repartition(fragments)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def readRightVectors(uri: String): Array[Array[Float]] =
+ spark.read.format("lance").load(uri).orderBy("rid").collect().map { r =>
+ r.getAs[Seq[Float]]("rvec").toArray
+ }
+
+ private def readRightIds(uri: String): Array[Int] =
+ spark.read.format("lance").load(uri).orderBy("rid").collect().map(_.getAs[Int]("rid"))
+
+ private def leftVectorFor(left: org.apache.spark.sql.DataFrame, lid: Int): Array[Float] = {
+ val r = left.filter(s"lid = $lid").collect().head
+ r.getAs[Seq[Float]]("lvec").toArray
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) {
+ val d = a(i) - b(i); s += d * d; i += 1
+ }
+ s
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinIvfPqRecallTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinIvfPqRecallTest.scala
new file mode 100644
index 000000000..8eb5e64c8
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinIvfPqRecallTest.scala
@@ -0,0 +1,324 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+import org.lance.spark.knn.internal.LanceVectorIndexBuilder
+import org.lance.spark.knn.testutil.ClusteredEmbeddings
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Phase 3 — real-recall validation against an IVF-PQ-indexed Lance dataset.
+ *
+ * Builds an IVF-PQ vector index via Lance's `Dataset.createIndex` Java binding, then runs
+ * `IndexedNearestJoin` and measures recall@K vs. the brute-force ground truth. With an index
+ * Lance returns *approximate* top-K, so recall is < 1.0 — the point of this test is to verify:
+ *
+ * 1. The indexed path actually engages (Lance's `useIndex` defaults to true on a Query
+ * against an indexed column; our `LanceProbe.probe` doesn't override it).
+ * 2. Recall at the default settings is in a sane range — our small synthetic dataset is
+ * small enough that recall should be high (most rows survive the IVF cluster cut).
+ * 3. `refineFactor > 1` improves recall by re-ranking more candidates with exact distance.
+ *
+ * Until this test, the 608x / 17.4x benchmark headlines were on a NO-INDEX Lance dataset where
+ * Lance's brute-force per-fragment scan made everything exact (recall = 1.0). The
+ * approximate-vs-exact recall trade-off that an indexed connector exposes was unmeasured. This
+ * test closes that gap.
+ *
+ * Setup specifics: 1024 right rows, dim 32, 4 IVF partitions, 8 PQ sub-vectors. The dataset
+ * is intentionally tiny so the test runs in a few seconds — production-realistic dataset
+ * sizes would need much larger N.
+ */
+class IndexedNearestJoinIvfPqRecallTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 32
+ private val NumRight = 1024
+ private val NumLeft = 32
+ private val K = 10
+ private val Seed = 0xCAFEL
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-ivfpq-recall")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * The headline test: build IVF-PQ, run IndexedNearestJoin, measure recall@10 against the
+ * brute-force oracle. With 1024 rows × 4 IVF partitions, each partition holds ~256 rows;
+ * a default-`nprobes` query hits ~1 partition, so we expect recall to be lower than 1.0
+ * but still substantial.
+ */
+ @Test def testIvfPqRecallReasonableAtDefaults(): Unit = {
+ val (leftDf, leftIds, leftVecs) = buildLeft()
+ val (rightUri, rightIds, rightVecs) = writeRight()
+ LanceVectorIndexBuilder.buildIvfPq(
+ datasetUri = rightUri,
+ vectorColumn = "rvec",
+ numPartitions = 4,
+ numSubVectors = 8,
+ numBits = 8)
+ assertEquals(
+ 1,
+ LanceVectorIndexBuilder.listIndexCount(rightUri),
+ "expected exactly one index after build")
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+
+ val rows = joined.collect()
+ val recall = computeRecallAtK(rows, leftIds, leftVecs, rightIds, rightVecs, K)
+ println(s" IVF-PQ recall@$K (no refine, default nprobes): $recall")
+ // With 1024 rows and 4 IVF partitions, default nprobes = 1, the index returns ~256
+ // candidates per query. Recall should be substantially > 0; our acceptance threshold
+ // is loose because IVF-PQ recall depends on the random data layout — anything < 0.3
+ // would suggest a real bug, not an inherent IVF limitation.
+ assertTrue(recall > 0.3, s"recall@$K=$recall too low; index path probably not engaging")
+ }
+
+ /**
+ * Production-realistic distribution: clustered Gaussian mixture, unit-sphere-normalized —
+ * the geometry of typical sentence-transformer / image-feature embeddings. The benchmark and
+ * the recall test elsewhere use uniform-random vectors over [0, 1]^d, which is the WORST
+ * case for IVF (k-means has no natural cluster structure to latch onto). This test exercises
+ * the indexed path on a more realistic distribution and asserts:
+ *
+ * 1. Recall@K on clustered data >= 0.5 at default IVF-PQ settings. If realistic data
+ * collapsed to coin-flip recall, the indexed path wouldn't be useful in production.
+ * 2. Both uniform and clustered recall numbers are printed, so a reviewer can see whether
+ * the realistic case actually helps in practice (it should — see the file's preamble).
+ *
+ * Why we don't `assert(clustered >= uniform)`: Lance's IVF training (k-means initialization)
+ * is non-deterministic across JVM sessions, so on a tiny 1024-row dataset the run-to-run
+ * noise in either recall number routinely exceeds the structural advantage of realistic
+ * data. A reliable comparison would need either (a) averaging over many seeds, which is
+ * slow and fragile in CI, or (b) much larger N where the structural effect dominates noise.
+ * We chose (c): print both, assert only the realistic-floor invariant.
+ */
+ @Test def testClusteredEmbeddingsRecallSurvives(): Unit = {
+ val (uniformDf, uniformIds, uniformVecs) =
+ buildLeftFromVectors(generateUniform(NumLeft, Dim, Seed))
+ val (uniformUri, uniformRightIds, uniformRightVecs) =
+ writeRightFromVectors(generateUniform(NumRight, Dim, Seed + 1))
+ LanceVectorIndexBuilder.buildIvfPq(uniformUri, "rvec", numPartitions = 4, numSubVectors = 8)
+
+ val (clusteredDf, clusteredIds, clusteredVecs) = buildLeftFromVectors(
+ ClusteredEmbeddings.generate(NumLeft, Dim, numClusters = 4, seed = Seed + 2))
+ val (clusteredUri, clusteredRightIds, clusteredRightVecs) = writeRightFromVectors(
+ ClusteredEmbeddings.generate(NumRight, Dim, numClusters = 16, seed = Seed + 3))
+ LanceVectorIndexBuilder.buildIvfPq(
+ clusteredUri,
+ "rvec",
+ numPartitions = 4,
+ numSubVectors = 8)
+
+ val uniformRecall = recallAgainst(
+ uniformDf,
+ uniformUri,
+ uniformIds,
+ uniformVecs,
+ uniformRightIds,
+ uniformRightVecs)
+ val clusteredRecall = recallAgainst(
+ clusteredDf,
+ clusteredUri,
+ clusteredIds,
+ clusteredVecs,
+ clusteredRightIds,
+ clusteredRightVecs)
+ println(
+ s" IVF-PQ recall@$K: uniform=$uniformRecall, clustered=$clusteredRecall " +
+ "(uniform = IVF worst case; clustered = production-shaped)")
+
+ assertTrue(
+ clusteredRecall >= 0.5,
+ s"clustered-data recall@$K=$clusteredRecall is unexpectedly low; " +
+ "defaults should comfortably exceed 0.5 on production-shaped embeddings — " +
+ "if this fails, suspect a regression in Lance's index path or in our probe wiring")
+ }
+
+ /**
+ * `refineFactor > 1` engages Lance's exact-distance re-rank: fetch `K * refineFactor`
+ * approximate candidates, re-rank, trim back to K. Strictly improves (or matches) recall
+ * vs. no refine. We assert the >= relation rather than a strict > so the test isn't flaky
+ * on tiny datasets where both paths happen to find the same K rows.
+ */
+ @Test def testRefineFactorImprovesRecall(): Unit = {
+ val (leftDf, leftIds, leftVecs) = buildLeft()
+ val (rightUri, rightIds, rightVecs) = writeRight()
+ LanceVectorIndexBuilder.buildIvfPq(rightUri, "rvec", numPartitions = 4)
+
+ val baseline = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ val refined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ refineFactor = Some(8))
+
+ val recallBaseline =
+ computeRecallAtK(baseline.collect(), leftIds, leftVecs, rightIds, rightVecs, K)
+ val recallRefined =
+ computeRecallAtK(refined.collect(), leftIds, leftVecs, rightIds, rightVecs, K)
+ println(s" IVF-PQ recall@$K: no refine = $recallBaseline, refineFactor=8 = $recallRefined")
+ assertTrue(
+ recallRefined >= recallBaseline,
+ s"refineFactor should not hurt recall: baseline=$recallBaseline, refined=$recallRefined")
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildLeft(): (org.apache.spark.sql.DataFrame, Array[Int], Array[Array[Float]]) =
+ buildLeftFromVectors(generateUniform(NumLeft, Dim, Seed))
+
+ private def writeRight(): (String, Array[Int], Array[Array[Float]]) =
+ writeRightFromVectors(generateUniform(NumRight, Dim, Seed + 1))
+
+ /** Build a left-side DataFrame from a pre-generated vector array. */
+ private def buildLeftFromVectors(
+ vectors: Array[Array[Float]])
+ : (org.apache.spark.sql.DataFrame, Array[Int], Array[Array[Float]]) = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val ids = (0 until vectors.length).toArray
+ val rows = ids.zip(vectors).map { case (id, v) =>
+ RowFactory.create(Integer.valueOf(id), v)
+ }
+ val df = spark.createDataFrame(rows.toSeq.asJava, schema)
+ (df, ids, vectors)
+ }
+
+ /** Write a right-side Lance dataset from a pre-generated vector array. */
+ private def writeRightFromVectors(
+ vectors: Array[Array[Float]]): (String, Array[Int], Array[Array[Float]]) = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val ids = (0 until vectors.length).map(_ + 100000).toArray
+ val rows = ids.zip(vectors).map { case (id, v) =>
+ RowFactory.create(Integer.valueOf(id), v)
+ }
+ val df = spark.createDataFrame(rows.toSeq.asJava, schema)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ (out, ids, vectors)
+ }
+
+ /** Run an indexed nearest join against the given right dataset and compute recall@K. */
+ private def recallAgainst(
+ leftDf: org.apache.spark.sql.DataFrame,
+ rightUri: String,
+ leftIds: Array[Int],
+ leftVecs: Array[Array[Float]],
+ rightIds: Array[Int],
+ rightVecs: Array[Array[Float]]): Double = {
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ computeRecallAtK(joined.collect(), leftIds, leftVecs, rightIds, rightVecs, K)
+ }
+
+ /** Uniform-random vectors over the unit hypercube — the IVF-worst-case data distribution. */
+ private def generateUniform(n: Int, dim: Int, seed: Long): Array[Array[Float]] = {
+ val rng = new Random(seed)
+ Array.fill(n)(randomVector(rng, dim))
+ }
+
+ /**
+ * Mean recall@K across all left rows: |intersection of indexed top-K with brute-force
+ * top-K| divided by K. A value of 1.0 means the indexed path returned the same K rows as
+ * brute force; lower values mean the IVF cluster cut excluded some true neighbors.
+ */
+ private def computeRecallAtK(
+ joinedRows: Array[org.apache.spark.sql.Row],
+ leftIds: Array[Int],
+ leftVecs: Array[Array[Float]],
+ rightIds: Array[Int],
+ rightVecs: Array[Array[Float]],
+ k: Int): Double = {
+ val byLid = joinedRows.groupBy(_.getAs[Int]("lid"))
+ val perLidRecall = leftIds.zip(leftVecs).map { case (lid, lvec) =>
+ val oracle = rightVecs.indices
+ .map(i => (rightIds(i), l2(lvec, rightVecs(i))))
+ .sortBy(_._2)
+ .take(k)
+ .map(_._1)
+ .toSet
+ val actual = byLid.getOrElse(lid, Array.empty).map(_.getAs[Int]("rid")).toSet
+ val hit = (oracle intersect actual).size.toDouble
+ hit / k
+ }
+ perLidRecall.sum / perLidRecall.length
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
+ s
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinJitStressTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinJitStressTest.scala
new file mode 100644
index 000000000..ef204097c
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinJitStressTest.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Runs the join at the exact scale that SIGSEGV'd the reverted 3-exec staged code:
+ * 10K right × 100 left × dim=128, K=10, with a crossJoin JIT-warmup preceding the
+ * join iterations to force C2 to compile all the hot UnsafeRow accessors.
+ *
+ * The originally-reported crash signature (`UnsafeRow.getLong(I)J` SIGSEGV in C2-compiled
+ * code, hs_err from early-development reproducer) fires on this exact shape. A clean pass here is the strongest
+ * available evidence that `InterStageShuffle.mergeViaCatalystShuffle` doesn't inherit the
+ * fragility.
+ *
+ * Right side kept at 10K (not the reverted benchmark's 100K) because we also run
+ * correctness + count tests in the same module and don't need the extra scan cost —
+ * the crash was codec-fragility, not a scale-dependent race. The revert commit's repro
+ * fired at 100K too.
+ */
+class IndexedNearestJoinJitStressTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val NumRight = 10000
+ private val NumLeft = 100
+ private val Dim = 128
+ private val K = 10
+ private val Seed = 1337L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-join-jit-stress")
+ .master("local[4]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.driver.memory", "4g")
+ .config("spark.sql.shuffle.partitions", "4")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * JIT warmup via crossJoin + group-by (mirrors the benchmark's baseline config A
+ * which ran first and produced the JIT state the staged config B/C/D/E then crashed
+ * in), followed by 20 iterations of the join at benchmark scale. Each iteration
+ * collects the full result set to force the whole pipeline to run end-to-end.
+ */
+ @Test def testRepeatedJoinAtBenchmarkScale(): Unit = {
+ warmupJit()
+
+ val rng = new Random(Seed)
+ val rightUri = writeRight(rng)
+ val leftDf = buildLeft(rng)
+
+ var iter = 0
+ val iterations = 20
+ while (iter < iterations) {
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "qvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ val rows = joined.collect()
+ assertEquals(NumLeft * K, rows.length, s"iteration $iter wrong row count")
+ iter += 1
+ }
+ }
+
+ /**
+ * Count-based variant at the same scale. This is what exercises `ColumnPruning` and
+ * was the proximate cause of the revert's crash (pruning inserted Project(Nil) that
+ * emitted 0-field UnsafeRows). Running count() 20 times at this scale is the tightest
+ * analog of the reverted repro.
+ */
+ @Test def testRepeatedCountAtBenchmarkScale(): Unit = {
+ warmupJit()
+
+ val rng = new Random(Seed + 1L)
+ val rightUri = writeRight(rng)
+ val leftDf = buildLeft(rng)
+
+ var iter = 0
+ val iterations = 20
+ while (iter < iterations) {
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "qvec",
+ rightVecCol = "rvec",
+ k = K,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ val n = joined.count()
+ assertEquals((NumLeft * K).toLong, n, s"iteration $iter wrong count")
+ iter += 1
+ }
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def warmupJit(): Unit = {
+ // ~250K-row crossJoin-groupBy to build JIT state on UnsafeRow accessors, Exchange,
+ // HashAggregate. Mirrors what `IndexedNearestJoinBenchmark.runSparkCrossJoinBaseline`
+ // runs as config A immediately before the staged configs that crashed.
+ val a = spark.range(0, 500L).toDF("a")
+ val b = spark.range(0, 500L).toDF("b")
+ a.crossJoin(b).groupBy(col("a")).count().count()
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def writeRight(rng: Random): String = {
+ val schema = new StructType(Array(
+ StructField("rid", LongType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = new java.util.ArrayList[Row](NumRight)
+ var i = 0
+ while (i < NumRight) {
+ rows.add(RowFactory.create(
+ java.lang.Long.valueOf(i.toLong),
+ randomVector(rng, Dim).toSeq.asJava))
+ i += 1
+ }
+ val df = spark.createDataFrame(rows, schema)
+ val uri = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(uri)
+ uri
+ }
+
+ private def buildLeft(rng: Random) = {
+ val schema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = new java.util.ArrayList[Row](NumLeft)
+ var i = 0
+ while (i < NumLeft) {
+ rows.add(RowFactory.create(
+ java.lang.Long.valueOf(i.toLong),
+ randomVector(rng, Dim).toSeq.asJava))
+ i += 1
+ }
+ spark.createDataFrame(rows, schema)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinPlanShapeTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinPlanShapeTest.scala
new file mode 100644
index 000000000..b8c933ca1
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinPlanShapeTest.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Plan-shape assertions for the 3-exec staged pipeline.
+ *
+ * `IndexedNearestJoin.apply` builds a `LanceMaterializeLogicalPlan → LanceMergeLogicalPlan →
+ * LanceProbeLogicalPlan` tree, which lowers to `LanceMaterializeExec → LanceMergeExec →
+ * ShuffleExchangeExec(inserted by EnsureRequirements) → LanceProbeExec → user-plan`.
+ *
+ * This test asserts the shape at the executed-plan level. Deeper AQE / correctness checks
+ * live in [[IndexedNearestJoinAqeVisibilityTest]] and [[IndexedNearestJoinCorrectnessTest]].
+ */
+class IndexedNearestJoinPlanShapeTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-join-plan-shape")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * The executed plan's tree string must contain all three custom exec names. The
+ * strategy (`LanceKnnStagedStrategy`) must have lowered each logical node to its exec.
+ */
+ @Test def testExecutedPlanContainsAllThreeCustomExecs(): Unit = {
+ val rng = new Random(11L)
+ val leftDf = buildLeft(rng, n = 4, dim = Dim)
+ val rightUri = writeRight(rng, n = 8, dim = Dim)
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 2,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+
+ joined.collect() // stabilise AQE plan
+ val tree = joined.queryExecution.executedPlan.treeString
+ assertTrue(tree.contains("LanceProbe"), s"Expected LanceProbe exec in plan; got:\n$tree")
+ assertTrue(tree.contains("LanceMerge"), s"Expected LanceMerge exec in plan; got:\n$tree")
+ assertTrue(
+ tree.contains("LanceMaterialize"),
+ s"Expected LanceMaterialize exec in plan; got:\n$tree")
+ assertTrue(
+ tree.contains("Exchange"),
+ s"Expected Exchange (from ClusteredDistribution) in plan; got:\n$tree")
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildLeft(rng: Random, n: Int, dim: Int) = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i), randomVector(rng, dim))
+ }
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ private def writeRight(rng: Random, n: Int, dim: Int): String = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000), randomVector(rng, dim))
+ }
+ val df = spark.createDataFrame(rows.asJava, schema)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinTest.scala
new file mode 100644
index 000000000..760ceb391
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/IndexedNearestJoinTest.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * End-to-end correctness test for [[IndexedNearestJoin]].
+ *
+ * The right side is a Lance dataset written without a vector index, which means Lance falls back
+ * to brute-force per-fragment search. That makes the result a recall = 1.0 oracle: the join's
+ * top-K must equal the top-K we compute in plain Scala. Mismatches are real bugs, not recall
+ * issues.
+ *
+ * These tests intentionally don't exercise indexed (approximate) search yet — that's the next
+ * test class to add once we wire vector index DDL through Lance's Java API for the test setup.
+ */
+class IndexedNearestJoinTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+ private val NumRight = 64
+ private val NumLeft = 16
+ private val Seed = 7L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("indexed-nearest-join-test")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * Top-K result for every left row matches the brute-force oracle exactly. With no vector index
+ * Lance does an exact scan, so this is the strictest correctness check we can write.
+ */
+ @Test def testInnerJoinMatchesBruteForceOracle(): Unit = {
+ val rng = new Random(Seed)
+ val leftDf = buildLeft(rng, NumLeft, Dim)
+ val rightUri = writeRight(rng, NumRight, Dim)
+
+ val k = 5
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = k,
+ metric = "l2",
+ // Project a small subset of right columns to keep the test grounded in the expected shape.
+ rightProjection = Some(Seq("rid", "rvec")))
+
+ val result = joined.collect()
+ // Every left row produced exactly k matches.
+ assertEquals(NumLeft * k, result.length, "expected k results per left row")
+
+ // Group by the left id (`lid`), compute oracle, compare.
+ val byLid = result.groupBy(_.getAs[Int]("lid"))
+ val rightVecs: Array[Array[Float]] = readRightVectors(rightUri)
+ val rightIds: Array[Int] = readRightIds(rightUri)
+
+ byLid.foreach { case (lid, rows) =>
+ val sortedRows = rows.sortBy(_.getAs[Float]("__score"))
+ val leftVec = leftVectorFor(leftDf, lid)
+ val oracle = rightVecs.indices
+ .map(idx => (rightIds(idx), l2(leftVec, rightVecs(idx))))
+ .sortBy(_._2)
+ .take(k)
+
+ val actualIds = sortedRows.map(_.getAs[Int]("rid"))
+ val actualScores = sortedRows.map(_.getAs[Float]("__score"))
+
+ // Compare ids (set equality up to ties is enough — score equality below catches ordering).
+ assertEquals(
+ oracle.map(_._1).toSet,
+ actualIds.toSet,
+ s"top-K right ids for lid=$lid mismatch oracle")
+ // Score values match within float tolerance.
+ oracle.map(_._2).zip(actualScores).foreach { case (expected, actualScore) =>
+ assertEquals(
+ expected,
+ actualScore,
+ 1e-4f,
+ s"score mismatch for lid=$lid: oracle=${oracle.map(_._2)} actual=$actualScores")
+ }
+ }
+ }
+
+ /**
+ * Output schema is `left.* ++ right.* ++ __score`. Verifies the column carry-through machinery,
+ * including projection-driven right schema selection.
+ */
+ @Test def testOutputSchemaCarriesLeftThenRightThenScore(): Unit = {
+ val leftDf = buildLeft(new Random(Seed), 4, Dim)
+ val rightUri = writeRight(new Random(Seed + 1), 8, Dim)
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 2,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+
+ val expectedFieldNames = Seq("lid", "lvec", "rid", "__score")
+ assertEquals(expectedFieldNames, joined.schema.fieldNames.toSeq)
+ // Right columns are widened to nullable to support left-outer; left fields keep their
+ // declared nullability.
+ assertTrue(joined.schema("rid").nullable, "right-side `rid` should be widened to nullable")
+ assertTrue(joined.schema("__score").nullable, "score should be nullable")
+ }
+
+ /**
+ * Phase 3 — `refineFactor` parameter passes through the pipeline without affecting correctness
+ * on an unindexed dataset. Lance's brute-force scan is already exact, so any refine factor
+ * yields the same result as without it. The point of this test is wiring: confirm
+ * `IndexedNearestJoin.apply` plumbs the parameter to `LanceProbe.probe` via the stage Conf
+ * without throwing. Real recall improvement only kicks in once an IVF-PQ index is built — that
+ * test is a Phase 3.x follow-up.
+ */
+ @Test def testRefineFactorPassesThroughWithoutBreakingCorrectness(): Unit = {
+ val rng = new Random(Seed)
+ val leftDf = buildLeft(rng, NumLeft, Dim)
+ val rightUri = writeRight(rng, NumRight, Dim)
+ val k = 5
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = k,
+ metric = "l2",
+ rightProjection = Some(Seq("rid", "rvec")),
+ refineFactor = Some(3))
+
+ val result = joined.collect()
+ assertEquals(NumLeft * k, result.length)
+
+ val byLid = result.groupBy(_.getAs[Int]("lid"))
+ val rightVecs = readRightVectors(rightUri)
+ val rightIds = readRightIds(rightUri)
+ byLid.foreach { case (lid, rows) =>
+ val sorted = rows.sortBy(_.getAs[Float]("__score"))
+ val leftVec = leftVectorFor(leftDf, lid)
+ val oracle = rightVecs.indices
+ .map(idx => (rightIds(idx), l2(leftVec, rightVecs(idx))))
+ .sortBy(_._2)
+ .take(k)
+ assertEquals(
+ oracle.map(_._1).toSet,
+ sorted.map(_.getAs[Int]("rid")).toSet,
+ s"top-K mismatch with refineFactor=3 (lid=$lid)")
+ }
+ }
+
+ /**
+ * `outerJoin = true` should preserve a left row when no right rows match — but with an unindexed
+ * right side every probe always returns k results, so the no-match case can't happen
+ * organically. We approximate it by passing a left row with a NULL vector, which `extractVector`
+ * surfaces as zero-score "no result" and the outer path emits with NULL right columns.
+ */
+ @Test def testLeftOuterPreservesUnmatchedLeftRowsWithNullVector(): Unit = {
+ val nullSchema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = true,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val rows = Seq(
+ RowFactory.create(Integer.valueOf(1), null) // null vector; should surface as a no-match left
+ )
+ val leftDf = spark.createDataFrame(rows.asJava, nullSchema)
+ val rightUri = writeRight(new Random(Seed + 2), 8, Dim)
+
+ val joined = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")),
+ outerJoin = true)
+
+ val rows2 = joined.collect()
+ assertEquals(1, rows2.length, "outer join should preserve the single null-vector left row")
+ val r = rows2.head
+ assertEquals(1, r.getAs[Int]("lid"))
+ assertTrue(
+ r.isNullAt(joined.schema.fieldIndex("rid")),
+ "rid should be NULL on outer-join no-match")
+ assertTrue(
+ r.isNullAt(joined.schema.fieldIndex("__score")),
+ "score should be NULL on outer-join no-match")
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildLeft(rng: Random, n: Int, dim: Int) = {
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i), randomVector(rng, dim))
+ }
+ spark.createDataFrame(rows.asJava, schema)
+ }
+
+ private def writeRight(rng: Random, n: Int, dim: Int): String = {
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", dim.toLong).build())))
+ val rows = (0 until n).map { i =>
+ RowFactory.create(Integer.valueOf(i + 1000), randomVector(rng, dim))
+ }
+ val df = spark.createDataFrame(rows.asJava, schema)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ out
+ }
+
+ private def readRightVectors(uri: String): Array[Array[Float]] = {
+ val df = spark.read.format("lance").load(uri).orderBy("rid")
+ df.collect().map { r =>
+ r.getAs[Seq[Float]]("rvec").toArray
+ }
+ }
+
+ private def readRightIds(uri: String): Array[Int] = {
+ val df = spark.read.format("lance").load(uri).orderBy("rid")
+ df.collect().map(_.getAs[Int]("rid"))
+ }
+
+ private def leftVectorFor(left: org.apache.spark.sql.DataFrame, lid: Int): Array[Float] = {
+ val r = left.filter(s"lid = $lid").collect().head
+ r.getAs[Seq[Float]]("lvec").toArray
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) {
+ val d = a(i) - b(i); s += d * d; i += 1
+ }
+ s
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala
new file mode 100644
index 000000000..4f66f0840
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/LanceKnnImplicitsTest.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn
+
+import org.apache.spark.sql.{RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * End-to-end tests for the `df.kNearestJoin(rightDf, ...)` extension. Three things to
+ * verify:
+ *
+ * 1. The extension returns the same rows as `IndexedNearestJoin.apply(uri, ...)` — the
+ * wrapper just changes the call site, not the semantics.
+ * 2. URI extraction handles a plain Lance `spark.read.load` — the common case.
+ * 3. URI extraction throws cleanly when the right DataFrame isn't backed by a Lance scan
+ * (e.g. created from in-memory rows). Bad input must fail fast with a helpful message,
+ * not surface as a confusing runtime error inside the probe.
+ *
+ * The Phase 0 oracle test in `LanceProbeValidationTest` covers correctness of the underlying
+ * probe; we can keep these tests light and not re-validate that.
+ */
+class LanceKnnImplicitsTest {
+
+ import LanceKnnImplicits._
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ private val Dim = 8
+ private val NumRight = 64
+ private val NumLeft = 8
+ private val Seed = 17L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("lance-knn-implicits-test")
+ .master("local[2]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ /**
+ * Happy path: a Lance-backed right DataFrame. Extension returns the same row-count and
+ * rid set per left row as the URI-string `IndexedNearestJoin.apply` form.
+ */
+ @Test def testKNearestJoinAgainstLanceScanMatchesUriForm(): Unit = {
+ val (leftDf, _, _) = buildLeft()
+ val (rightUri, _, _) = writeRight()
+ val rightDf = spark.read.format("lance").load(rightUri)
+
+ val viaExtension = leftDf
+ .kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 5,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ .collect()
+ val viaUri = IndexedNearestJoin(
+ left = leftDf,
+ rightLanceUri = rightUri,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 5,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ .collect()
+
+ assertEquals(viaUri.length, viaExtension.length)
+ val byLid = (rs: Array[org.apache.spark.sql.Row]) =>
+ rs.groupBy(_.getAs[Int]("lid")).map { case (lid, group) =>
+ lid -> group.map(_.getAs[Int]("rid")).toSet
+ }
+ assertEquals(byLid(viaUri), byLid(viaExtension))
+ }
+
+ /**
+ * The right DataFrame is wrapped in a `Filter` (e.g. user wrote `docs.filter("rid > 0")`).
+ * The URI extractor must walk past it to the underlying Lance relation.
+ */
+ @Test def testFilterOnRightStillExtractsUri(): Unit = {
+ val (leftDf, _, _) = buildLeft()
+ val (rightUri, _, _) = writeRight()
+ val rightDf = spark.read.format("lance").load(rightUri).filter("rid > 0")
+
+ val joined = leftDf
+ .kNearestJoin(
+ right = rightDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2",
+ rightProjection = Some(Seq("rid")))
+ .collect()
+ assertEquals(NumLeft * 3, joined.length, "expected k results per left row")
+ }
+
+ /**
+ * A DataFrame backed by Parquet (or any non-Lance format) must also fail the Lance-only
+ * guard — the API contract is `format("lance").load(...)` specifically, not "any DataFrame
+ * Spark can read." Catches the case where a user wires in the wrong reader by mistake.
+ */
+ @Test def testParquetRightThrowsClearError(): Unit = {
+ val (leftDf, _, _) = buildLeft()
+ val parquetSchema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false)))
+ val rows = Seq(
+ RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f)),
+ RowFactory.create(Integer.valueOf(2), Array.fill(Dim)(0.5f)))
+ val parquetPath = tempDir.resolve(s"docs_${System.nanoTime()}.parquet").toString
+ spark.createDataFrame(rows.asJava, parquetSchema).write.parquet(parquetPath)
+ val parquetDf = spark.read.parquet(parquetPath)
+
+ val ex = assertThrows(
+ classOf[IllegalArgumentException],
+ () =>
+ leftDf.kNearestJoin(
+ right = parquetDf,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2"))
+ assertTrue(
+ ex.getMessage.contains("Lance scan"),
+ s"expected error message to mention Lance scan for parquet input; got: ${ex.getMessage}")
+ }
+
+ /**
+ * Non-Lance DataFrame wrapped in a `SubqueryAlias` (via `as("d")`) must still fail. The
+ * URI extractor walks `SubqueryAlias` to find the underlying relation; if the underlying
+ * is not Lance, alias unwrapping must NOT silently accept it.
+ */
+ @Test def testNonLanceUnderAliasThrowsClearError(): Unit = {
+ val (leftDf, _, _) = buildLeft()
+ val rows = Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f)))
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false)))
+ val notLance = spark.createDataFrame(rows.asJava, schema).as("d")
+
+ val ex = assertThrows(
+ classOf[IllegalArgumentException],
+ () =>
+ leftDf.kNearestJoin(
+ right = notLance,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2"))
+ assertTrue(
+ ex.getMessage.contains("Lance scan"),
+ s"alias-wrapped non-Lance must still fail; got: ${ex.getMessage}")
+ }
+
+ /**
+ * A DataFrame built from in-memory rows is NOT a Lance scan — the extension should throw
+ * an `IllegalArgumentException` with a message naming the constraint, so the user knows
+ * to hand a real Lance DataFrame instead.
+ */
+ @Test def testNonLanceRightThrowsClearError(): Unit = {
+ val (leftDf, _, _) = buildLeft()
+ val ridSchema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField("rvec", ArrayType(FloatType, containsNull = false), nullable = false)))
+ val notLance = spark.createDataFrame(
+ Seq(RowFactory.create(Integer.valueOf(1), Array.fill(Dim)(0.0f))).asJava,
+ ridSchema)
+
+ val ex = assertThrows(
+ classOf[IllegalArgumentException],
+ () =>
+ leftDf.kNearestJoin(
+ right = notLance,
+ leftVecCol = "lvec",
+ rightVecCol = "rvec",
+ k = 3,
+ metric = "l2"))
+ assertTrue(
+ ex.getMessage.contains("Lance scan"),
+ s"expected error message to mention Lance scan; got: ${ex.getMessage}")
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ private def buildLeft(): (org.apache.spark.sql.DataFrame, Array[Int], Array[Array[Float]]) = {
+ val rng = new Random(Seed)
+ val schema = new StructType(Array(
+ StructField("lid", IntegerType, nullable = false),
+ StructField(
+ "lvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val ids = (0 until NumLeft).toArray
+ val vecs = ids.map(_ => randomVector(rng, Dim))
+ val rows = ids.zip(vecs).map { case (id, v) => RowFactory.create(Integer.valueOf(id), v) }
+ val df = spark.createDataFrame(rows.toSeq.asJava, schema)
+ (df, ids, vecs)
+ }
+
+ private def writeRight(): (String, Array[Int], Array[Array[Float]]) = {
+ val rng = new Random(Seed + 1)
+ val schema = new StructType(Array(
+ StructField("rid", IntegerType, nullable = false),
+ StructField(
+ "rvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+ val ids = (0 until NumRight).map(_ + 1000).toArray
+ val vecs = ids.map(_ => randomVector(rng, Dim))
+ val rows = ids.zip(vecs).map { case (id, v) => RowFactory.create(Integer.valueOf(id), v) }
+ val df = spark.createDataFrame(rows.toSeq.asJava, schema)
+ val out = tempDir.resolve(s"right_${System.nanoTime()}").toString
+ df.write.format("lance").save(out)
+ (out, ids, vecs)
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/benchmark/InterStagePayloadOverheadBench.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/benchmark/InterStagePayloadOverheadBench.scala
new file mode 100644
index 000000000..6972e0429
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/benchmark/InterStagePayloadOverheadBench.scala
@@ -0,0 +1,242 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.benchmark
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.types._
+import org.lance.spark.knn.internal.{ProbedLeft, ScoredRowRef}
+
+import java.util.{Locale, Random}
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+
+/**
+ * Microbenchmark for inter-stage payload encoding cost. The question this answers:
+ *
+ * If we split the 2.12 module's RDD pipeline into 3 explicit SparkPlan operators
+ * (LanceProbeExec → LanceMergeExec → LanceMaterializeExec), each boundary needs the
+ * `ProbedLeft` payload encoded as `InternalRow`. How much wall-clock does that cost
+ * relative to the actual probe + merge + materialize work?
+ *
+ * Key insight before we even measure: between probe and merge, the payload is per left row,
+ * NOT per `(left × K)` pair. We carry the left vector + K row refs (rowAddr + score), not
+ * K full right-side rows. The right rows only get fetched in the materialize stage. So the
+ * encoding cost scales with |L|, not |L| × K × |right_row_size|.
+ *
+ * == Two encoding schemes compared ==
+ *
+ * A) Catalyst struct encoding via ExpressionEncoder. Schema:
+ * `struct, refs: array>>`. UnsafeRow-encoded; native to Catalyst's row-shuffle path.
+ *
+ * B) Binary blob via Kryo. Schema: `struct` where `blob` is
+ * the Kryo-serialized `ProbedLeft`. Simpler implementation but pays Kryo's per-call
+ * overhead.
+ *
+ * The "winner" is whichever cost is small enough to ignore at the realistic scales we
+ * benchmark (small = 100 left rows, medium = 1000). If both are negligible, the choice is
+ * code-complexity, not performance.
+ *
+ * Invocation:
+ * {{{
+ * MAVEN_OPTS="-Xmx4g " \
+ * ./mvnw -pl lance-spark-knn_2.12 -q exec:java \
+ * -Dexec.classpathScope=test \
+ * -Dexec.mainClass=org.lance.spark.knn.benchmark.InterStagePayloadOverheadBench
+ * }}}
+ */
+object InterStagePayloadOverheadBench {
+
+ private val Dim: Int = 128
+ private val K: Int = 10
+ private val Seed: Long = 42L
+ private val Warmup: Int = 3
+ private val Iterations: Int = 5
+
+ // Scales matching the SQL benchmark's `numLeft` settings.
+ private val Scales: Seq[(String, Int)] = Seq(
+ "small (numLeft=100)" -> 100,
+ "medium (numLeft=1000)" -> 1000,
+ "stress (numLeft=10000)" -> 10000)
+
+ def main(args: Array[String]): Unit = {
+ println("=" * 78)
+ println("Inter-stage ProbedLeft encoding overhead — A: Catalyst struct vs B: Kryo blob")
+ println(s"Dim=$Dim, K=$K, ${Iterations} iters/${Warmup} warmups per cell")
+ println("=" * 78)
+
+ // SparkSession just for the encoder + Kryo registry; no jobs run here.
+ val spark = SparkSession.builder()
+ .appName("inter-stage-payload-overhead")
+ .master("local[1]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.serializer", classOf[KryoSerializer].getName)
+ .getOrCreate()
+ spark.sparkContext.setLogLevel("ERROR")
+
+ try {
+ val rng = new Random(Seed)
+
+ Scales.foreach { case (name, n) =>
+ println()
+ println(s"--- $name ---")
+
+ val payloads = generatePayloads(rng, n)
+
+ // Catalyst-struct encoder.
+ val schema = catalystSchema()
+ val enc = ExpressionEncoder(schema).resolveAndBind()
+ val ser = enc.createSerializer()
+ val deser = enc.createDeserializer()
+
+ val schemeA = bench(
+ "A: Catalyst struct (encode + decode)",
+ () => {
+ var sink = 0L
+ var i = 0
+ while (i < payloads.length) {
+ val ir = ser(toCatalystRow(i.toLong, payloads(i)))
+ val back = deser(ir)
+ sink ^= back.getLong(0)
+ i += 1
+ }
+ sink
+ })
+
+ // Kryo encoder.
+ val kryoSerializerInstance = new KryoSerializer(spark.sparkContext.getConf).newInstance()
+ val schemeB = bench(
+ "B: Kryo binary blob (encode + decode)",
+ () => {
+ var sink = 0L
+ var i = 0
+ while (i < payloads.length) {
+ val bytes = kryoSerializerInstance.serialize(payloads(i))
+ val back = kryoSerializerInstance.deserialize[ProbedLeft](bytes)
+ sink ^= back.refs.length.toLong
+ i += 1
+ }
+ sink
+ })
+
+ val medianA = median(schemeA)
+ val medianB = median(schemeB)
+ printf(
+ " A: Catalyst struct median=%6.2f ms per-row=%6.2f µs (across both inter-stage boundaries: %6.2f ms total)%n",
+ medianA / 1e6,
+ medianA / 1e3 / n,
+ 2.0 * medianA / 1e6)
+ printf(
+ " B: Kryo binary blob median=%6.2f ms per-row=%6.2f µs (across both inter-stage boundaries: %6.2f ms total)%n",
+ medianB / 1e6,
+ medianB / 1e3 / n,
+ 2.0 * medianB / 1e6)
+
+ // Sanity: serialized blob size for B (rough — actual rows may vary slightly).
+ val sampleBlob = new KryoSerializer(spark.sparkContext.getConf).newInstance()
+ .serialize(payloads.head)
+ val avgBlobBytes = sampleBlob.remaining()
+ printf(
+ " Per-row serialized size (Kryo blob): ~%d bytes (×$n rows ≈ %.1f KB total payload)%n",
+ avgBlobBytes,
+ avgBlobBytes * n / 1024.0)
+ }
+
+ println()
+ println("=" * 78)
+ println("Conclusion guide:")
+ println(" - Encoding overhead < 5%% of total wall-clock at the relevant SQL benchmark")
+ println(" cell ⇒ splitting into 3 execs is essentially free; do it for explainability.")
+ println(" - Encoding overhead > 20%% ⇒ the 3-exec split costs more than it informs;")
+ println(" keep the single-exec wrapper and consider RDD.setName() for Spark UI clarity.")
+ println(" - Anywhere between, judgement call.")
+ println("=" * 78)
+ } finally {
+ spark.stop()
+ }
+ }
+
+ // -- payload generation ------------------------------------------------------------------
+
+ private def generatePayloads(rng: Random, n: Int): Array[ProbedLeft] = {
+ val out = new Array[ProbedLeft](n)
+ var i = 0
+ while (i < n) {
+ val vec = new Array[Float](Dim)
+ var d = 0
+ while (d < Dim) { vec(d) = rng.nextFloat(); d += 1 }
+ val leftRow: Row = RowFactory.create(Integer.valueOf(i), vec)
+
+ val refs = new Array[ScoredRowRef](K)
+ var r = 0
+ while (r < K) {
+ refs(r) = ScoredRowRef(rng.nextLong(), rng.nextFloat())
+ r += 1
+ }
+ out(i) = ProbedLeft(leftRow, refs)
+ i += 1
+ }
+ out
+ }
+
+ // Catalyst schema mirroring the candidate Plan-A struct encoding.
+ private def catalystSchema(): StructType = StructType(Seq(
+ StructField("leftId", LongType, nullable = false),
+ StructField(
+ "leftRow",
+ StructType(Seq(
+ StructField("lid", IntegerType, nullable = true),
+ StructField("lvec", ArrayType(FloatType, containsNull = false), nullable = true))),
+ nullable = true),
+ StructField(
+ "refs",
+ ArrayType(
+ StructType(Seq(
+ StructField("rowAddr", LongType, nullable = false),
+ StructField("score", FloatType, nullable = false))),
+ containsNull = false),
+ nullable = false)))
+
+ private def toCatalystRow(leftId: Long, pl: ProbedLeft): Row = {
+ val leftStruct = Row(pl.leftRow.getInt(0), pl.leftRow.get(1))
+ val refStructs = pl.refs.map(r => Row(r.rowAddr, r.score)).toSeq
+ Row(leftId, leftStruct, refStructs)
+ }
+
+ // -- timing helper -----------------------------------------------------------------------
+
+ private def bench(label: String, body: () => Long): Seq[Long] = {
+ var w = 0
+ while (w < Warmup) { val _ = body(); w += 1 }
+ val out = mutable.ArrayBuffer.empty[Long]
+ var i = 0
+ while (i < Iterations) {
+ val t0 = System.nanoTime()
+ val _ = body()
+ out += (System.nanoTime() - t0)
+ i += 1
+ }
+ out.toSeq
+ }
+
+ private def median(xs: Seq[Long]): Long = {
+ val sorted = xs.sorted
+ sorted(sorted.length / 2)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceFragmentsTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceFragmentsTest.scala
new file mode 100644
index 000000000..34abcae78
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceFragmentsTest.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.Test
+
+/**
+ * Unit tests for [[LanceFragments.roundRobin]]. The actual Lance-backed enumeration is exercised
+ * indirectly by the Phase 1.5 oracle test (which writes a dataset and reads its fragment list).
+ * Here we just check the partitioning math.
+ */
+class LanceFragmentsTest {
+
+ @Test def testRoundRobinBalanced(): Unit = {
+ val groups = LanceFragments.roundRobin(Seq(10, 11, 12, 13, 14, 15), 3)
+ assertEquals(3, groups.size)
+ assertEquals(Seq(10, 13), groups(0))
+ assertEquals(Seq(11, 14), groups(1))
+ assertEquals(Seq(12, 15), groups(2))
+ }
+
+ /**
+ * `groupCount > numFragments` produces empty trailing groups. The probe stage must tolerate
+ * empty groups (skip them) — this is the contract the empty result encodes.
+ */
+ @Test def testMoreGroupsThanFragments(): Unit = {
+ val groups = LanceFragments.roundRobin(Seq(7, 8), 5)
+ assertEquals(5, groups.size)
+ assertEquals(Seq(7), groups(0))
+ assertEquals(Seq(8), groups(1))
+ assertTrue(groups(2).isEmpty)
+ assertTrue(groups(3).isEmpty)
+ assertTrue(groups(4).isEmpty)
+ }
+
+ @Test def testSingleGroupReturnsAll(): Unit = {
+ val groups = LanceFragments.roundRobin(Seq(1, 2, 3, 4), 1)
+ assertEquals(Seq(Seq(1, 2, 3, 4)), groups)
+ }
+
+ @Test def testEmptyInputProducesEmptyGroups(): Unit = {
+ val groups = LanceFragments.roundRobin(Seq.empty, 3)
+ assertEquals(3, groups.size)
+ assertTrue(groups.forall(_.isEmpty))
+ }
+
+ // -- greedyBalance / Phase 3 skew handling -------------------------------------------------
+
+ /**
+ * LPT greedy: given imbalanced fragments, the worst group's total should be no more than
+ * 4/3 of the optimal. With 4 frags of weights (10, 10, 10, 1) split into 2 groups, optimal
+ * makespan = 16. LPT places 10 in g0, 10 in g1, 10 in g0 (now 20), 1 in g1 (now 11) — so
+ * g0=20, g1=11. Best balance achievable is g0=11, g1=20 (or symmetric); LPT happens to
+ * arrive at one of those orderings here. Either way, no group exceeds 21 which is well within
+ * the 4/3 bound (~21.3).
+ */
+ @Test def testGreedyBalanceKeepsHeaviestGroupBoundedFor4_3OptOpt(): Unit = {
+ val groups = LanceFragments.greedyBalance(
+ Seq((1, 10L), (2, 10L), (3, 10L), (4, 1L)),
+ groupCount = 2)
+ assertEquals(2, groups.size)
+ val totals = groups.map(g => g.map(id => Map(1 -> 10L, 2 -> 10L, 3 -> 10L, 4 -> 1L)(id)).sum)
+ val maxTotal = totals.max
+ val sumAll = 31L
+ val optimal = math.ceil(sumAll.toDouble / 2).toInt // 16
+ val bound = math.ceil(optimal * 4.0 / 3.0).toInt // 22
+ assertTrue(maxTotal <= bound, s"LPT maxTotal=$maxTotal exceeded 4/3 bound $bound")
+ }
+
+ /**
+ * LPT collapses to round-robin-style behavior when all weights are equal — every group ends
+ * up with the same number of items.
+ */
+ @Test def testGreedyBalanceEqualWeightsBalancesItemCount(): Unit = {
+ val groups = LanceFragments.greedyBalance(
+ Seq((1, 5L), (2, 5L), (3, 5L), (4, 5L), (5, 5L), (6, 5L)),
+ groupCount = 3)
+ assertEquals(3, groups.size)
+ groups.foreach(g => assertEquals(2, g.size, "equal weights should yield equal item counts"))
+ }
+
+ /**
+ * One huge fragment + many small ones: the huge one anchors a group on its own, and the
+ * smalls fill the others. This is the textbook skew case Phase 3 cares about.
+ */
+ @Test def testGreedyBalanceIsolatesSkewedFragment(): Unit = {
+ val groups = LanceFragments.greedyBalance(
+ Seq((10, 100L), (20, 5L), (30, 5L), (40, 5L), (50, 5L)),
+ groupCount = 2)
+ assertEquals(2, groups.size)
+ val groupWith10 = groups.find(_.contains(10)).get
+ assertEquals(Seq(10), groupWith10, "skewed fragment should land in its own group")
+ val otherGroup = groups.find(!_.contains(10)).get
+ assertEquals(Set(20, 30, 40, 50), otherGroup.toSet)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceProbeValidationTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceProbeValidationTest.scala
new file mode 100644
index 000000000..0ae287138
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceProbeValidationTest.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * End-to-end validation of [[LanceProbe]] against a real Lance dataset written by Spark. These are
+ * the day-1 validation tasks the implementation plan calls out:
+ *
+ * - Per-probe call should succeed and return Lance's nearest neighbors.
+ * - Repeated probes against the same `LanceProbe` instance should reuse the open dataset
+ * handle; the second call should not re-pay the dataset open cost.
+ * - `fragmentIds` restriction should narrow the search to specified fragments only.
+ * - Without an explicit vector index the probe falls back to a brute-force per-fragment scan,
+ * which gives recall = 1.0 — making the no-index path the natural correctness oracle.
+ *
+ * These tests do NOT require an actual vector index; that is exercised in the indexed test
+ * suites which build IVF-PQ via Lance's index DDL. Validating the brute-force path first lets us
+ * isolate any LanceProbe bugs from index-quality issues.
+ */
+class LanceProbeValidationTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ // Small synthetic dataset: 64 vectors, dim 8. Enough to exercise the probe loop without making
+ // the test slow.
+ private val NumRows = 64
+ private val VectorDim = 8
+ private val Seed = 42L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("lance-probe-validation")
+ .master("local[2]")
+ // Pin the driver to loopback so test JVMs in restricted networks (CI sandboxes, dev
+ // containers) can bind without scanning the host's interfaces.
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = {
+ if (spark != null) spark.stop()
+ }
+
+ /**
+ * Smoke test: write a dataset, probe it, get K rows back. No correctness assertion beyond
+ * "result has the right shape" — the brute-force-equivalence test below covers semantics.
+ */
+ @Test def testProbeReturnsKResults(): Unit = {
+ val datasetUri = writeSyntheticDataset()
+ val query = randomVector(new Random(7L), VectorDim)
+
+ val probe = new LanceProbe(datasetUri, fragmentIds = None)
+ try {
+ val results = probe.probe(vectorColumn = "vec", query, k = 5, metric = Metric.L2)
+ assertEquals(5, results.size, "probe should return exactly k results")
+ // Distances must be monotonically non-decreasing for L2 (best-first).
+ val scores = results.map(_.score)
+ assertEquals(scores, scores.sorted, "L2 results should be sorted ascending by distance")
+ // Row addresses are stable u64s; we just sanity-check they aren't all zero.
+ assertTrue(results.exists(_.rowAddr != 0L), "row addresses should be populated")
+ } finally probe.close()
+ }
+
+ /**
+ * Without a vector index, Lance does an exact per-fragment scan. That makes it a recall = 1.0
+ * oracle: the probe result should equal the ground-truth top-K computed in plain Scala.
+ */
+ @Test def testProbeMatchesBruteForceOracle(): Unit = {
+ val rng = new Random(Seed)
+ val (rows, vectors) = generateRows(rng, NumRows, VectorDim)
+ val datasetUri = writeRows(rows)
+
+ val query = randomVector(new Random(123L), VectorDim)
+ val k = 10
+
+ val oracle: Seq[(Int, Float)] = vectors.zipWithIndex
+ .map { case (v, idx) => (idx, l2Distance(query, v)) }
+ .sortBy(_._2)
+ .take(k)
+
+ val probe = new LanceProbe(datasetUri, fragmentIds = None)
+ val actual =
+ try probe.probe("vec", query, k, Metric.L2)
+ finally probe.close()
+
+ assertEquals(k, actual.size)
+ // Compare scores within float tolerance.
+ val expectedScores = oracle.map(_._2)
+ val actualScores = actual.map(_.score)
+ expectedScores.zip(actualScores).foreach { case (expected, actualScore) =>
+ assertEquals(
+ expected,
+ actualScore,
+ 1e-4f,
+ s"top-K distance mismatch: oracle=$expectedScores actual=$actualScores")
+ }
+ }
+
+ /**
+ * Validate the dataset handle is reused across calls. The exact perf invariant ("second call
+ * faster than first by some factor") is too brittle for CI, so we only assert that repeated
+ * probes succeed and don't OOM — i.e., no JNI handle / Arrow buffer leak per call.
+ */
+ @Test def testRepeatedProbesShareDatasetHandle(): Unit = {
+ val datasetUri = writeSyntheticDataset()
+ val probe = new LanceProbe(datasetUri, None)
+ try {
+ val rng = new Random(99L)
+ val k = 4
+ var i = 0
+ while (i < 50) {
+ val results = probe.probe("vec", randomVector(rng, VectorDim), k, Metric.L2)
+ assertEquals(k, results.size, s"iteration $i returned wrong size")
+ i += 1
+ }
+ } finally probe.close()
+ }
+
+ /** Empty fragment-id list ⇒ no rows match. Confirms the pushdown actually narrows search. */
+ @Test def testEmptyFragmentRestrictionReturnsNothing(): Unit = {
+ val datasetUri = writeSyntheticDataset()
+ val probe = new LanceProbe(datasetUri, Some(Seq.empty))
+ try {
+ val results = probe.probe("vec", randomVector(new Random(1L), VectorDim), 5, Metric.L2)
+ assertTrue(results.isEmpty, s"empty fragmentIds should yield no results, got ${results.size}")
+ } finally probe.close()
+ }
+
+ // -- helpers ------------------------------------------------------------------------------
+
+ /** Write a fresh dataset and return its file:// URI. */
+ private def writeSyntheticDataset(): String = {
+ val rng = new Random(Seed)
+ val (rows, _) = generateRows(rng, NumRows, VectorDim)
+ writeRows(rows)
+ }
+
+ private def writeRows(rows: Seq[Row]): String = {
+ val schema = new StructType(Array(
+ StructField("id", IntegerType, nullable = false),
+ StructField(
+ "vec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", VectorDim.toLong).build())))
+ val df = spark.createDataFrame(rows.asJava, schema)
+
+ val outDir = tempDir.resolve(s"probe_test_${System.nanoTime()}").toString
+ df.write.format("lance").save(outDir)
+ outDir
+ }
+
+ private def generateRows(rng: Random, n: Int, dim: Int): (Seq[Row], Seq[Array[Float]]) = {
+ val vectors = (0 until n).map(_ => randomVector(rng, dim))
+ val rows = vectors.zipWithIndex.map { case (v, idx) =>
+ RowFactory.create(Integer.valueOf(idx), v)
+ }
+ (rows, vectors)
+ }
+
+ private def randomVector(rng: Random, dim: Int): Array[Float] = {
+ val v = new Array[Float](dim)
+ var i = 0
+ while (i < dim) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ private def l2Distance(a: Array[Float], b: Array[Float]): Float = {
+ var s = 0.0f
+ var i = 0
+ while (i < a.length) {
+ val d = a(i) - b(i)
+ s += d * d
+ i += 1
+ }
+ s
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceVectorIndexBuilder.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceVectorIndexBuilder.scala
new file mode 100644
index 000000000..cd35e13ba
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/LanceVectorIndexBuilder.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.lance.{Dataset, ReadOptions}
+import org.lance.index.{IndexOptions, IndexParams, IndexType}
+import org.lance.index.vector.VectorIndexParams
+import org.lance.spark.LanceRuntime
+
+import scala.collection.JavaConverters._
+
+/**
+ * Test-only helper to build an IVF-PQ vector index on a Lance dataset via
+ * `Dataset.createIndex`. Exists so recall tests can construct the indexed scan path
+ * without writing the Lance Java boilerplate inline.
+ *
+ * Lives in `src/test/scala` because the production code path doesn't need to build
+ * indexes — users build them via Lance's Python / Rust / SQL DDL on their own datasets,
+ * and we just probe whatever's there. The helper exists for closed-loop recall validation.
+ */
+object LanceVectorIndexBuilder {
+
+ /**
+ * Build an IVF-PQ index on `vectorColumn` of the dataset at `datasetUri`. Defaults are
+ * tuned for tiny test datasets — production users would size these much larger.
+ *
+ * @param numPartitions IVF cluster count. Should divide cleanly into the dataset row count.
+ * For a 4K-row dataset, 4-8 partitions is reasonable.
+ * @param numSubVectors PQ sub-vector count. Must divide vector dim evenly.
+ * @param numBits PQ bits per sub-vector. 8 is the standard.
+ * @param metric distance type. Must match the metric used at probe time.
+ * @param maxIters KMeans iteration cap during IVF training. 50 is enough for tests.
+ */
+ def buildIvfPq(
+ datasetUri: String,
+ vectorColumn: String,
+ numPartitions: Int = 4,
+ numSubVectors: Int = 8,
+ numBits: Int = 8,
+ metric: Metric = Metric.L2,
+ maxIters: Int = 50): Unit = {
+ val dataset = openDataset(datasetUri)
+ try {
+ val vectorParams =
+ VectorIndexParams.ivfPq(numPartitions, numSubVectors, numBits, metric.lanceType, maxIters)
+ val indexParams = IndexParams.builder().setVectorIndexParams(vectorParams).build()
+ val opts = IndexOptions
+ .builder(java.util.Collections.singletonList(vectorColumn), IndexType.VECTOR, indexParams)
+ .build()
+ dataset.createIndex(opts)
+ } finally dataset.close()
+ }
+
+ /**
+ * Build an IVF_FLAT index — IVF clustering without PQ compression. Exact distances within
+ * visited clusters (no PQ noise), so recall depends purely on `nprobes` coverage. Higher
+ * memory/disk footprint than IVF-PQ (full vectors stored per cluster) but better recall on
+ * high-dim or random workloads where PQ compression drops too much information.
+ */
+ def buildIvfFlat(
+ datasetUri: String,
+ vectorColumn: String,
+ numPartitions: Int = 4,
+ metric: Metric = Metric.L2): Unit = {
+ val dataset = openDataset(datasetUri)
+ try {
+ val vectorParams = VectorIndexParams.ivfFlat(numPartitions, metric.lanceType)
+ val indexParams = IndexParams.builder().setVectorIndexParams(vectorParams).build()
+ val opts = IndexOptions
+ .builder(java.util.Collections.singletonList(vectorColumn), IndexType.VECTOR, indexParams)
+ .build()
+ dataset.createIndex(opts)
+ } finally dataset.close()
+ }
+
+ private def openDataset(uri: String): Dataset = {
+ Dataset
+ .open()
+ .uri(uri)
+ .allocator(LanceRuntime.allocator())
+ .readOptions(new ReadOptions.Builder().build())
+ .build()
+ }
+
+ /** Number of indexes on the dataset (sanity check after building). */
+ def listIndexCount(datasetUri: String): Int = {
+ val dataset = openDataset(datasetUri)
+ try dataset.listIndexes.asScala.size
+ finally dataset.close()
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/TopKHeapTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/TopKHeapTest.scala
new file mode 100644
index 000000000..41bd1de7d
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/TopKHeapTest.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal
+
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.Test
+
+/**
+ * Unit tests for [[TopKHeap]]. The heap's correctness is the foundation of the merge stage —
+ * any off-by-one or wrong-direction ordering would silently corrupt top-K results. We test both
+ * metric directions explicitly.
+ */
+class TopKHeapTest {
+
+ private def ref(addr: Long, score: Float): ScoredRowRef = ScoredRowRef(addr, score)
+
+ /** Distance metric: smaller score is better. Top-K must hold the K smallest. */
+ @Test def testDistanceKeepsKSmallest(): Unit = {
+ val heap = new TopKHeap(k = 3, smallerIsBetter = true)
+ Seq(5.0f, 1.0f, 4.0f, 2.0f, 8.0f, 0.5f).zipWithIndex.foreach { case (s, i) =>
+ heap.offer(ref(i.toLong, s))
+ }
+ val out = heap.drain()
+ val scores = out.map(_.score).toSeq
+ assertEquals(Seq(0.5f, 1.0f, 2.0f), scores, "distance heap should retain three smallest")
+ }
+
+ /** Similarity metric: larger score is better. Top-K must hold the K largest. */
+ @Test def testSimilarityKeepsKLargest(): Unit = {
+ val heap = new TopKHeap(k = 3, smallerIsBetter = false)
+ Seq(5.0f, 1.0f, 4.0f, 2.0f, 8.0f, 0.5f).zipWithIndex.foreach { case (s, i) =>
+ heap.offer(ref(i.toLong, s))
+ }
+ val out = heap.drain()
+ val scores = out.map(_.score).toSeq
+ assertEquals(Seq(8.0f, 5.0f, 4.0f), scores, "similarity heap should retain three largest")
+ }
+
+ /** Drain order is best-first regardless of insertion order. */
+ @Test def testDrainOrderIsBestFirst(): Unit = {
+ val heap = new TopKHeap(k = 4, smallerIsBetter = true)
+ heap.offerAll(Seq(ref(1, 9f), ref(2, 1f), ref(3, 5f), ref(4, 3f), ref(5, 2f)))
+ val drained = heap.drain()
+ val scores = drained.map(_.score).toSeq
+ assertEquals(Seq(1f, 2f, 3f, 5f), scores)
+ assertTrue(heap.isEmpty, "drain should leave the heap empty")
+ }
+
+ /** Heap with fewer than K elements drains them all in best-first order. */
+ @Test def testFewerThanKReturnsAll(): Unit = {
+ val heap = new TopKHeap(k = 10, smallerIsBetter = true)
+ heap.offerAll(Seq(ref(1, 3f), ref(2, 1f), ref(3, 2f)))
+ assertEquals(Seq(1f, 2f, 3f), heap.drain().map(_.score).toSeq)
+ }
+
+ /** A worse-than-current-worst candidate is rejected. */
+ @Test def testRejectsWorseCandidate(): Unit = {
+ val heap = new TopKHeap(k = 2, smallerIsBetter = true)
+ heap.offer(ref(1, 1f))
+ heap.offer(ref(2, 2f))
+ heap.offer(ref(3, 5f)) // worse than existing 2 → rejected
+ val drained = heap.drain()
+ assertEquals(Seq(1f, 2f), drained.map(_.score).toSeq)
+ assertEquals(Seq(1L, 2L), drained.map(_.rowAddr).toSeq)
+ }
+
+ /** `merge` combines two pre-sorted arrays preserving top-K. */
+ @Test def testMergeCombinesTwoArrays(): Unit = {
+ val a = Array(ref(1, 1f), ref(2, 3f), ref(3, 5f))
+ val b = Array(ref(4, 2f), ref(5, 4f), ref(6, 6f))
+ val merged = TopKHeap.merge(a, b, k = 4, smallerIsBetter = true)
+ assertEquals(Seq(1f, 2f, 3f, 4f), merged.map(_.score).toSeq)
+ }
+
+ /** Merging with one empty input is a noop modulo trim to K. */
+ @Test def testMergeWithEmpty(): Unit = {
+ val a = Array(ref(1, 1f), ref(2, 2f), ref(3, 3f))
+ val merged = TopKHeap.merge(a, Array.empty[ScoredRowRef], k = 2, smallerIsBetter = true)
+ assertEquals(Seq(1f, 2f), merged.map(_.score).toSeq)
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/staged/StagedPlansReferencesTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/staged/StagedPlansReferencesTest.scala
new file mode 100644
index 000000000..0203a0a0b
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/internal/staged/StagedPlansReferencesTest.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.internal.staged
+
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types.{ArrayType, FloatType, LongType, StructField, StructType}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.Test
+import org.lance.spark.knn.internal.{LanceMaterializeStage, LanceMergeStage, LanceProbeStage, Metric}
+
+/**
+ * Regression test pinning the `references = child.outputSet` override on
+ * `LanceMergeLogicalPlan` and `LanceMaterializeLogicalPlan`.
+ *
+ * If someone ever removes, narrows, or weakens these overrides, Catalyst's
+ * `ColumnPruning` rule will insert `Project(Nil)` wrappers between nodes when downstream
+ * consumers (`count(*)`, `Aggregate`, etc.) reference none of the node's output columns.
+ * Those empty projections codegen to 0-field `UnsafeRow`s, which crash
+ * `ProbedLeftCodec.Decoder.decode` with either `AssertionError: index (0) should < 0`
+ * (interpreter/C1) or a SIGSEGV in `UnsafeRow.getLong` (C2 JIT). That bug was originally
+ * misdiagnosed as a JVM-aarch64 codec interaction; the actual cause was missing
+ * `references` overrides.
+ *
+ * Functional coverage of the same property lives in
+ * [[org.lance.spark.knn.IndexedNearestJoinConsumerShapeTest]] (count / agg / lit-select
+ * all succeed end-to-end). This test is the cheap structural pin: if the override goes
+ * away, this test fails instantly instead of waiting for a slow ColumnPruning-driven
+ * end-to-end crash.
+ */
+class StagedPlansReferencesTest {
+
+ private def dummyLeftSchema: StructType = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField("qvec", ArrayType(FloatType, containsNull = false), nullable = false)))
+
+ private def dummyChild(attrs: Seq[AttributeReference]): LocalRelation =
+ LocalRelation(attrs)
+
+ @Test def testLanceMergeLogicalPlanReferencesIsChildOutputSet(): Unit = {
+ val leftSchema = dummyLeftSchema
+ val interStageAttrs = ProbedLeftCodec.interStageAttributes(leftSchema)
+ val child = dummyChild(interStageAttrs)
+ val merge = LanceMergeLogicalPlan(
+ child = child,
+ stageConf = LanceMergeStage.Conf(finalK = 1, smallerIsBetter = true),
+ leftSchema = leftSchema,
+ interStageOutput = interStageAttrs)
+
+ assertEquals(
+ child.outputSet,
+ merge.references,
+ "LanceMergeLogicalPlan.references must equal child.outputSet so Catalyst " +
+ "ColumnPruning cannot insert Project(Nil) above it")
+ }
+
+ @Test def testLanceMaterializeLogicalPlanReferencesIsChildOutputSet(): Unit = {
+ val leftSchema = dummyLeftSchema
+ val interStageAttrs = ProbedLeftCodec.interStageAttributes(leftSchema)
+ val finalAttrs = Seq(
+ AttributeReference("lid", LongType, nullable = false)(),
+ AttributeReference("rid", LongType, nullable = true)(),
+ AttributeReference("__score", FloatType, nullable = true)())
+
+ val child = dummyChild(interStageAttrs)
+ val materialize = LanceMaterializeLogicalPlan(
+ child = child,
+ stageConf = LanceMaterializeStage.Conf(
+ datasetUri = "/tmp/unused",
+ version = None,
+ rightProjection = Seq("rid"),
+ rightFields = Seq(StructField("rid", LongType, nullable = true)),
+ leftFieldCount = 2,
+ outerJoin = false),
+ leftSchema = leftSchema,
+ finalSchema = new StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField("rid", LongType, nullable = true),
+ StructField("__score", FloatType, nullable = true))),
+ finalOutput = finalAttrs)
+
+ assertEquals(
+ child.outputSet,
+ materialize.references,
+ "LanceMaterializeLogicalPlan.references must equal child.outputSet so Catalyst " +
+ "ColumnPruning cannot insert Project(Nil) above it")
+ }
+
+ /**
+ * Explicit positive check: `child.outputSet` must be a subset of `references`. This is
+ * the literal predicate `ColumnPruning`'s `Aggregate(_, _, child, _) if !child.outputSet
+ * .subsetOf(a.references)` uses to decide whether to insert a pruning Project. Our
+ * override makes the subset relation hold (equality ⇒ subset), which short-circuits
+ * the rule.
+ */
+ @Test def testColumnPruningPredicateShortCircuits(): Unit = {
+ val leftSchema = dummyLeftSchema
+ val interStageAttrs = ProbedLeftCodec.interStageAttributes(leftSchema)
+ val child = dummyChild(interStageAttrs)
+ val merge = LanceMergeLogicalPlan(
+ child = child,
+ stageConf = LanceMergeStage.Conf(finalK = 1, smallerIsBetter = true),
+ leftSchema = leftSchema,
+ interStageOutput = interStageAttrs)
+
+ assertTrue(
+ child.outputSet.subsetOf(merge.references),
+ "ColumnPruning's guard is `!child.outputSet.subsetOf(references)`; " +
+ "override must make the subset relation hold so pruning doesn't fire")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleReproTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleReproTest.scala
new file mode 100644
index 000000000..4b23a0a55
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleReproTest.scala
@@ -0,0 +1,349 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.staged
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Repro harness for the JVM SIGSEGV observed by the reverted 3-exec staged split
+ * (commits 882fcdb / 4b68ee3 / 6218b1c, reverted in 2e2ba94). Goal: narrow the fault to
+ * either (a) Spark/JVM, or (b) the staged-exec + Lance interaction.
+ *
+ * Each test isolates one variable of the staged-exec pipeline:
+ *
+ * test1 — baseline: no shuffle, schema with ArrayType(FloatType, 128) + array
+ * test2 — encoder round-trip at the same benchmark scale, no shuffle
+ * test3 — shuffle (repartition by leftId) + encoder round-trip, same schema
+ * test4 — same as test3 but payload comes from Row → InternalRow via ExpressionEncoder,
+ * mirroring the staged-exec codec's hot path
+ * test5 — test4 + a downstream mapPartitions that decodes via direct InternalRow
+ * accessors, mirroring ProbedLeftCodec.Decoder
+ *
+ * The benchmark reported the SIGSEGV at "stage 79 task 0" at 100K rows × 128-dim. We run at
+ * 100K here — the reverted codec consistently crashed at that scale on M5 Max + Temurin 17.
+ * If any of these tests SIGSEGV, it's a Spark/JVM bug. If they all pass, the fault is
+ * specific to how the staged execs wire themselves into Lance's output.
+ */
+class InterStageShuffleReproTest {
+
+ private var spark: SparkSession = _
+
+ // Benchmark-scale knobs. The original crash fired at leftN = 100K with a 128-dim qvec on
+ // the left side. K = 10 refs per row matches the join's top-K. Shuffle over 4 partitions
+ // to force cross-partition movement of every row.
+ private val LeftN: Int = 100000
+ private val Dim: Int = 128
+ private val K: Int = 10
+ private val ShuffleParts: Int = 4
+ private val Seed: Long = 1337L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("inter-stage-shuffle-repro")
+ .master("local[4]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.driver.memory", "4g")
+ .config("spark.sql.shuffle.partitions", ShuffleParts.toString)
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ // ---- schemas ----------------------------------------------------------------------------
+
+ /**
+ * Matches `ProbedLeftCodec.interStageSchema` shape:
+ * - leading `_leftId: long`
+ * - user left-schema fields flattened — here: one `ArrayType(FloatType)` vec column
+ * - trailing `_refs: array>`
+ */
+ private val RefStruct: StructType = StructType(Array(
+ StructField("rowAddr", LongType, nullable = false),
+ StructField("score", FloatType, nullable = false)))
+
+ private val InterStageSchema: StructType = StructType(Array(
+ StructField("_leftId", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build()),
+ StructField(
+ "_refs",
+ ArrayType(RefStruct, containsNull = false),
+ nullable = false)))
+
+ // ---- data generation --------------------------------------------------------------------
+
+ private def randomFloats(rng: Random, n: Int): Array[Float] = {
+ val v = new Array[Float](n)
+ var i = 0
+ while (i < n) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ /** Build `LeftN` rows driver-side. Values are Java-shaped to match `RowFactory.create`. */
+ private def buildInterStageRows(): java.util.List[Row] = {
+ val rng = new Random(Seed)
+ val rows = new java.util.ArrayList[Row](LeftN)
+ var i = 0
+ while (i < LeftN) {
+ val qvec = randomFloats(rng, Dim).toSeq.asJava
+ val refs = new java.util.ArrayList[Row](K)
+ var r = 0
+ while (r < K) {
+ refs.add(RowFactory.create(
+ java.lang.Long.valueOf(rng.nextLong() & 0x7FFFFFFFFFFFFFFFL),
+ java.lang.Float.valueOf(rng.nextFloat())))
+ r += 1
+ }
+ rows.add(RowFactory.create(
+ java.lang.Long.valueOf(i.toLong),
+ qvec,
+ refs))
+ i += 1
+ }
+ rows
+ }
+
+ // ---- tests ------------------------------------------------------------------------------
+
+ /** Baseline: no shuffle. Just build a DF and `count()`. If this crashes, schema alone is broken. */
+ @Test def test1_buildAndCountNoShuffle(): Unit = {
+ val df = spark.createDataFrame(buildInterStageRows(), InterStageSchema)
+ val n = df.count()
+ assertEquals(LeftN.toLong, n)
+ }
+
+ /**
+ * Force an `ExpressionEncoder(InterStageSchema)` round-trip. Row → InternalRow → Row via
+ * `as[Row]` on a Dataset forces encoder codegen. No shuffle. If this crashes, encoder
+ * codegen + this specific schema (array, array) is what trips the JIT.
+ */
+ @Test def test2_encoderRoundTripNoShuffle(): Unit = {
+ val df = spark.createDataFrame(buildInterStageRows(), InterStageSchema)
+ // Force encoder deserialization by collecting via rdd → row. This is the simplest
+ // mirror of what ProbedLeftCodec.Decoder does on each task.
+ val collected = df.rdd.count()
+ assertEquals(LeftN.toLong, collected)
+ }
+
+ /**
+ * Shuffle by `_leftId`, then materialize. Mimics `EnsureRequirements` inserting a
+ * `ShuffleExchangeExec` when `LanceMergeExec` declares `ClusteredDistribution(_leftId)`.
+ * Encoder round-trips on both sides of the shuffle. If this crashes, Spark's UnsafeRow
+ * shuffle of `(long, array, array)` is the bug.
+ */
+ @Test def test3_shufflePlusCount(): Unit = {
+ val df = spark.createDataFrame(buildInterStageRows(), InterStageSchema)
+ val shuffled = df.repartition(ShuffleParts, col("_leftId"))
+ val n = shuffled.count()
+ assertEquals(LeftN.toLong, n)
+ }
+
+ /**
+ * Same shuffle shape as test3 but the payload starts life as `RDD[InternalRow]` produced
+ * by an `ExpressionEncoder(InterStageSchema).createSerializer()` in a `mapPartitions` —
+ * exactly what `LanceProbeExec.doExecute` does. This is the closest non-Lance mirror of
+ * the reverted staged probe stage's output.
+ *
+ * We then read back through `df.rdd` which triggers the decoder. If this crashes, the
+ * encoder-driven inter-stage row path is the bug independent of Lance.
+ */
+ @Test def test4_encodeToInternalRowThenShuffle(): Unit = {
+ val schemaCaptured = InterStageSchema
+ val leftN = LeftN
+ val dim = Dim
+ val k = K
+ val seed = Seed
+
+ // Build an RDD[Row] driver-side (small synthetic gen, no Lance). Each partition gets
+ // ~leftN/parallelism rows. Parallelism is local[4] ⇒ 4 partitions upstream of the shuffle.
+ val rows: Seq[Row] = {
+ val rng = new Random(seed)
+ (0 until leftN).map { i =>
+ val qvec = randomFloats(rng, dim).toSeq
+ val refs = (0 until k).map { _ =>
+ Row(rng.nextLong() & 0x7FFFFFFFFFFFFFFFL, rng.nextFloat())
+ }
+ Row(i.toLong, qvec, refs)
+ }
+ }
+ val rowRdd: RDD[Row] = spark.sparkContext.parallelize(rows, 4)
+
+ // Encode Row → InternalRow in a mapPartitions, mirroring LanceProbeExec.doExecute.
+ val internalRdd: RDD[InternalRow] = rowRdd.mapPartitions { iter =>
+ val enc = ExpressionEncoder(schemaCaptured).resolveAndBind()
+ val ser = enc.createSerializer()
+ iter.map(r => ser(r).copy())
+ }
+
+ // Wrap back into a DataFrame via internalCreateDataFrame-style path. This is the
+ // public equivalent: go RDD[InternalRow] → RDD[Row] → createDataFrame.
+ val backToRow: RDD[Row] = internalRdd.mapPartitions { iter =>
+ val enc = ExpressionEncoder(schemaCaptured).resolveAndBind()
+ val deser = enc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+
+ val df = spark.createDataFrame(backToRow, schemaCaptured)
+ val shuffled = df.repartition(ShuffleParts, col("_leftId"))
+ val n = shuffled.count()
+ assertEquals(leftN.toLong, n)
+ }
+
+ /**
+ * Adds a post-shuffle consumer that decodes each `InternalRow` via direct accessors —
+ * `ir.getLong(0)`, `ir.getArray(1)`, `ir.getArray(2).getStruct(i, 2)`. This is what
+ * `ProbedLeftCodec.Decoder` does, and what the revert commit said was SIGSEGV-ing in
+ * `UnsafeRow.getArray` under C2.
+ *
+ * We iterate through the shuffle's output InternalRows directly (via `queryExecution.toRdd`,
+ * which is the Catalyst-internal `RDD[InternalRow]` — same shape LanceMergeExec sees from
+ * its upstream ShuffleExchangeExec). Then sum the leftIds + refs length to force JIT to
+ * compile the hot loop over the shuffled UnsafeRows.
+ *
+ * Runs the inner loop multiple times to give C2 a chance to compile and mis-speculate.
+ */
+ @Test def test5_directInternalRowAccessorsPostShuffle(): Unit = {
+ val schemaCaptured = InterStageSchema
+ val leftN = LeftN
+
+ val df = spark.createDataFrame(buildInterStageRows(), schemaCaptured)
+ val shuffled = df.repartition(ShuffleParts, col("_leftId"))
+
+ // Catalyst-internal RDD[InternalRow] — what a physical child's execute() returns.
+ val shuffledInternal: RDD[InternalRow] = shuffled.queryExecution.toRdd
+
+ // Direct-accessor consumer, matching ProbedLeftCodec.Decoder's hot loop. Run it a few
+ // times so JIT C2 has a chance to compile + speculate on UnsafeRow.getArray.
+ var totalRows = 0L
+ var trial = 0
+ while (trial < 3) {
+ val count = shuffledInternal.mapPartitions { iter =>
+ var sum = 0L
+ while (iter.hasNext) {
+ val ir = iter.next().copy()
+ val leftId = ir.getLong(0)
+ // qvec: ArrayType(FloatType). getArray returns UnsafeArrayData.
+ val qvec = ir.getArray(1)
+ var qsum = 0.0f
+ var j = 0
+ while (j < qvec.numElements()) {
+ qsum += qvec.getFloat(j)
+ j += 1
+ }
+ // refs: ArrayType(StructType(...)). Iterate via getArray + getStruct.
+ val refs = ir.getArray(2)
+ var refSum = 0L
+ var r = 0
+ while (r < refs.numElements()) {
+ val s = refs.getStruct(r, 2)
+ refSum += s.getLong(0)
+ r += 1
+ }
+ sum += leftId + qsum.toLong + refSum
+ }
+ Iterator.single(sum)
+ }.count()
+ totalRows += count
+ trial += 1
+ }
+
+ // Each of 3 trials visits ShuffleParts partitions ⇒ 3 * ShuffleParts rows of count output.
+ assertEquals((3 * ShuffleParts).toLong, totalRows)
+ }
+
+ /**
+ * JIT-warmup stress. The reverted PR reported the SIGSEGV at Spark's "stage 79 task 0.0"
+ * right after a crossJoin baseline finished — i.e. after the JVM had been running hot for
+ * minutes and C2 had compiled essentially everything in sight. Tests 1-5 each only execute
+ * the hot loop a handful of times before asserting; that's not enough for C2 to compile +
+ * mis-speculate. This test runs:
+ *
+ * 1. An upstream crossJoin-esque warmup (generates JIT pressure on encoder / shuffle /
+ * UnsafeRow paths, similar to what the benchmark's config A does first).
+ * 2. 50 iterations of encode → shuffle → direct-accessor-decode at benchmark scale.
+ *
+ * If the reverted codec's fault is reproducible on this machine via Spark's codegen path
+ * alone, this is where it will surface.
+ */
+ @Test def test6_jitWarmupStress(): Unit = {
+ val schemaCaptured = InterStageSchema
+ val leftN = LeftN
+
+ // --- Warmup: produce JIT pressure on the shuffle/UnsafeRow paths. -------------------
+ // A small crossJoin-ish workload whose shape touches UnsafeRow getters repeatedly.
+ val warmupA = spark.range(0, 500L).toDF("a")
+ val warmupB = spark.range(0, 500L).toDF("b")
+ // 500 × 500 = 250K rows; filter + group, crossJoin-shaped, touches codegen paths.
+ warmupA.crossJoin(warmupB)
+ .groupBy(col("a"))
+ .count()
+ .count()
+
+ // --- Main loop: 50 iterations of the staged-codec hot path. -------------------------
+ val df = spark.createDataFrame(buildInterStageRows(), schemaCaptured)
+ val shuffled = df.repartition(ShuffleParts, col("_leftId"))
+ val shuffledInternal: RDD[InternalRow] = shuffled.queryExecution.toRdd
+
+ var iter = 0
+ val iterations = 50
+ var observedSum = 0L
+ while (iter < iterations) {
+ val perIterSum = shuffledInternal.mapPartitions { it =>
+ var sum = 0L
+ while (it.hasNext) {
+ val ir = it.next().copy()
+ val leftId = ir.getLong(0)
+ val qvec = ir.getArray(1)
+ var j = 0
+ var qsum = 0.0f
+ while (j < qvec.numElements()) {
+ qsum += qvec.getFloat(j)
+ j += 1
+ }
+ val refs = ir.getArray(2)
+ var r = 0
+ var refSum = 0L
+ while (r < refs.numElements()) {
+ val s = refs.getStruct(r, 2)
+ refSum += s.getLong(0)
+ r += 1
+ }
+ sum += leftId + qsum.toLong + refSum
+ }
+ Iterator.single(sum)
+ }.collect().sum
+ observedSum += perIterSum
+ iter += 1
+ }
+
+ // Weak check: every iteration sees the same `leftN` rows, so observedSum should be
+ // nonzero and equal across iterations. A meaningful crash would prevent reaching here.
+ assertTrue(observedSum != 0L, s"Expected non-zero observed sum after $iterations iters")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleWithLanceReproTest.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleWithLanceReproTest.scala
new file mode 100644
index 000000000..288e74b11
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/staged/InterStageShuffleWithLanceReproTest.scala
@@ -0,0 +1,365 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.staged
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.io.TempDir
+
+import java.nio.file.Path
+import java.util.Random
+
+import scala.collection.JavaConverters._
+
+/**
+ * Follow-up to [[InterStageShuffleReproTest]] — all six synthetic (driver-built) tests
+ * there passed at benchmark scale on the JVM/arch the original crash fired on. This suite
+ * adds the missing input provenance: left rows read from an actual Lance dataset via
+ * `spark.read.format("lance")`, then run through the same encode → shuffle → decode path
+ * that [[org.lance.spark.knn.internal.staged.ProbedLeftCodec]] exercises.
+ *
+ * If the synthetic tests pass but Lance-backed tests at the same scale crash, the fault is
+ * at the Lance → Spark boundary, not in Spark/JVM generally. Candidate causes:
+ *
+ * 1. Arrow-backed `ColumnarBatch` reads leaving JVM refs into off-heap buffers that are
+ * freed when the scanner closes. Subsequent `UnsafeArrayData.getFloat` reads would
+ * hit unmapped memory → SIGSEGV in native.
+ * 2. Double/triple encoder round-trip (Arrow columnar → InternalRow → Row → InternalRow
+ * → UnsafeRow → shuffle) corrupting the nested array length header.
+ * 3. Thread-safety: `local[4]` runs four tasks concurrently in the same JVM; any shared
+ * state in the per-task encoder would corrupt UnsafeRow writes.
+ *
+ * Each test isolates one step of the staged-exec pipeline against a Lance-source left.
+ */
+class InterStageShuffleWithLanceReproTest {
+
+ @TempDir var tempDir: Path = _
+ private var spark: SparkSession = _
+
+ // Match the synthetic test for direct comparison.
+ private val LeftN: Int = 100000
+ private val Dim: Int = 128
+ private val K: Int = 10
+ private val ShuffleParts: Int = 4
+ private val Seed: Long = 1337L
+
+ @BeforeEach def setup(): Unit = {
+ spark = SparkSession.builder()
+ .appName("inter-stage-shuffle-with-lance-repro")
+ .master("local[4]")
+ .config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.driver.host", "127.0.0.1")
+ .config("spark.driver.memory", "4g")
+ .config("spark.sql.shuffle.partitions", ShuffleParts.toString)
+ .getOrCreate()
+ }
+
+ @AfterEach def teardown(): Unit = if (spark != null) spark.stop()
+
+ // ---- schemas ----------------------------------------------------------------------------
+
+ /** Left-side Lance schema: (lid, qvec). Matches the benchmark's left shape. */
+ private val LeftSchema: StructType = StructType(Array(
+ StructField("lid", LongType, nullable = false),
+ StructField(
+ "qvec",
+ ArrayType(FloatType, containsNull = false),
+ nullable = false,
+ new MetadataBuilder().putLong("arrow.fixed-size-list.size", Dim.toLong).build())))
+
+ /** Inter-stage shape: same as `ProbedLeftCodec.interStageSchema(LeftSchema)`. */
+ private val RefStruct: StructType = StructType(Array(
+ StructField("rowAddr", LongType, nullable = false),
+ StructField("score", FloatType, nullable = false)))
+
+ private val InterStageSchema: StructType = StructType(
+ StructField("_leftId", LongType, nullable = false) +:
+ LeftSchema.fields :+
+ StructField("_refs", ArrayType(RefStruct, containsNull = false), nullable = false))
+
+ // ---- data generation --------------------------------------------------------------------
+
+ private def randomFloats(rng: Random, n: Int): Array[Float] = {
+ val v = new Array[Float](n)
+ var i = 0
+ while (i < n) { v(i) = rng.nextFloat(); i += 1 }
+ v
+ }
+
+ /** Write a Lance dataset with `LeftN` rows at the benchmark's left-schema shape. */
+ private def writeLanceLeft(): String = {
+ val rng = new Random(Seed)
+ val rows = new java.util.ArrayList[Row](LeftN)
+ var i = 0
+ while (i < LeftN) {
+ val qvec = randomFloats(rng, Dim).toSeq.asJava
+ rows.add(RowFactory.create(java.lang.Long.valueOf(i.toLong), qvec))
+ i += 1
+ }
+ val df = spark.createDataFrame(rows, LeftSchema)
+ val uri = tempDir.resolve(s"left_${System.nanoTime()}").toString
+ df.write.format("lance").save(uri)
+ uri
+ }
+
+ /** Build fake refs (no Lance probe — the crash candidates sit on the LEFT side's read). */
+ private def fakeRefs(rng: Random): Seq[Row] = {
+ val buf = new scala.collection.mutable.ArrayBuffer[Row](K)
+ var r = 0
+ while (r < K) {
+ buf += Row(rng.nextLong() & 0x7FFFFFFFFFFFFFFFL, rng.nextFloat())
+ r += 1
+ }
+ buf.toSeq
+ }
+
+ // ---- tests ------------------------------------------------------------------------------
+
+ /**
+ * Read left side from Lance, count. If the Lance columnar read path itself can't sustain
+ * 100K × 128-dim, this is where that shows up. Doesn't exercise any of the codec.
+ */
+ @Test def test1_lanceReadThenCount(): Unit = {
+ val uri = writeLanceLeft()
+ val left = spark.read.format("lance").load(uri)
+ val n = left.count()
+ assertEquals(LeftN.toLong, n)
+ }
+
+ /**
+ * Read left from Lance → `ExpressionEncoder(LeftSchema).createDeserializer()` → `Row`
+ * in each partition. Mirrors what `LanceProbeExec.doExecute` does BEFORE adding refs:
+ * it takes `child.execute()` (which for a Lance table is `RDD[InternalRow]` backed by
+ * ColumnarBatch) and deserializes via encoder into `Row`.
+ *
+ * If off-heap lifetime is the bug, this is close to the root — the Arrow buffer may
+ * still be live here, but the deserialized `Row` should no longer reference it.
+ */
+ @Test def test2_lanceReadThenEncoderRoundTrip(): Unit = {
+ val uri = writeLanceLeft()
+ val left = spark.read.format("lance").load(uri)
+ // `.rdd` on a DataFrame triggers DeserializeToObject via the row encoder — same code
+ // path the staged codec uses via `createDeserializer()`.
+ val n = left.rdd.count()
+ assertEquals(LeftN.toLong, n)
+ }
+
+ /**
+ * The full staged-codec hot path against a Lance-source left:
+ * Lance scan → InternalRow child.execute()
+ * → mapPartitions { deserialize to Row via ExpressionEncoder(leftSchema) }
+ * → mapPartitions { zipWithUniqueId + attach fake refs → encode via ExpressionEncoder(InterStageSchema) }
+ * → repartition by _leftId (ShuffleExchange)
+ * → mapPartitions { direct InternalRow decode: getLong, getArray, getStruct }
+ * → count
+ *
+ * This is the closest non-trivial mirror of `LanceProbeExec.doExecute` feeding
+ * `LanceMergeExec.doExecute` across a `ShuffleExchangeExec`. If the reverted codec's
+ * crash is Lance-boundary-induced, this should SIGSEGV.
+ */
+ @Test def test3_lanceSourceFullCodecRoundTripThenShuffle(): Unit = {
+ val leftSchema = LeftSchema
+ val interStageSchema = InterStageSchema
+ val shuffleParts = ShuffleParts
+ val seed = Seed
+ val k = K
+
+ val uri = writeLanceLeft()
+ val left = spark.read.format("lance").load(uri)
+
+ // Step 1: Lance → InternalRow (Catalyst toRdd) → Row via encoder deserialize.
+ // This is LanceProbeExec.doExecute lines "Decode user's left-side InternalRows into Rows".
+ val leftInternal: RDD[InternalRow] = left.queryExecution.toRdd
+ val rowRdd: RDD[Row] = leftInternal.mapPartitions { iter =>
+ val enc = ExpressionEncoder(leftSchema).resolveAndBind()
+ val deser = enc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+
+ // Step 2: attach a synthetic leftId + fake refs; encode to InterStageSchema via the
+ // same single ExpressionEncoder path the codec uses (the fix from commit 6218b1c).
+ val encodedRdd: RDD[InternalRow] = rowRdd
+ .zipWithUniqueId()
+ .map { case (row, id) => (id, row) }
+ .mapPartitionsWithIndex { case (partIdx, iter) =>
+ val interEnc = ExpressionEncoder(interStageSchema).resolveAndBind()
+ val ser = interEnc.createSerializer()
+ val rng = new Random(seed + partIdx.toLong)
+ val leftFieldCount = leftSchema.length
+ iter.map { case (leftId, leftRow) =>
+ // Flatten: [_leftId, leftField0, ..., leftFieldN, _refs]
+ val cols = new Array[Any](2 + leftFieldCount)
+ cols(0) = java.lang.Long.valueOf(leftId)
+ var i = 0
+ while (i < leftFieldCount) {
+ cols(1 + i) = leftRow.get(i)
+ i += 1
+ }
+ val refs = new scala.collection.mutable.ArrayBuffer[Row](k)
+ var r = 0
+ while (r < k) {
+ refs += Row(rng.nextLong() & 0x7FFFFFFFFFFFFFFFL, rng.nextFloat())
+ r += 1
+ }
+ cols(1 + leftFieldCount) = refs.toSeq
+ ser(Row.fromSeq(cols.toSeq)).copy()
+ }
+ }
+
+ // Step 3: put into a DataFrame so repartition-by-column (which requires a Catalyst
+ // shuffle) can consume it. Going RDD[InternalRow] → RDD[Row] → createDataFrame
+ // mirrors the original production code path's `createDataFrame(rdd, schema)` shape.
+ val backToRow: RDD[Row] = encodedRdd.mapPartitions { iter =>
+ val enc = ExpressionEncoder(interStageSchema).resolveAndBind()
+ val deser = enc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+ val df = spark.createDataFrame(backToRow, interStageSchema)
+
+ // Step 4: shuffle by _leftId — this is what ClusteredDistribution(leftId) produces.
+ val shuffled = df.repartition(shuffleParts, col("_leftId"))
+
+ // Step 5: consume via direct InternalRow accessors (ProbedLeftCodec.Decoder shape).
+ // Inter-stage schema has 4 cols: [_leftId:long, lid:long, qvec:array, _refs:array]
+ val shuffledInternal: RDD[InternalRow] = shuffled.queryExecution.toRdd
+ val n = shuffledInternal.mapPartitions { iter =>
+ var sum = 0L
+ while (iter.hasNext) {
+ val ir = iter.next().copy()
+ val leftId = ir.getLong(0)
+ val lid = ir.getLong(1)
+ val qvec = ir.getArray(2)
+ var j = 0
+ var qsum = 0.0f
+ while (j < qvec.numElements()) {
+ qsum += qvec.getFloat(j)
+ j += 1
+ }
+ val refs = ir.getArray(3)
+ var r = 0
+ var refSum = 0L
+ while (r < refs.numElements()) {
+ val s = refs.getStruct(r, 2)
+ refSum += s.getLong(0)
+ r += 1
+ }
+ sum += leftId + lid + qsum.toLong + refSum
+ }
+ Iterator.single(sum)
+ }.count()
+
+ // Count is per-partition: ShuffleParts partitions emit one Long each.
+ assertEquals(shuffleParts.toLong, n)
+ }
+
+ /**
+ * Same pipeline as test3 but with JIT warmup + 20 iterations, to give C2 time to compile
+ * the hot loop over Lance-sourced UnsafeRows. The reverted crash fired at Spark stage
+ * 79 — well past any first-pass interpreter/C1 execution.
+ */
+ @Test def test4_lanceSourceFullCodecJitStress(): Unit = {
+ val leftSchema = LeftSchema
+ val interStageSchema = InterStageSchema
+ val shuffleParts = ShuffleParts
+ val seed = Seed
+ val k = K
+
+ val uri = writeLanceLeft()
+ val left = spark.read.format("lance").load(uri)
+
+ // JIT warmup: small crossJoin shape (mirrors config A in the benchmark).
+ val wA = spark.range(0, 500L).toDF("a")
+ val wB = spark.range(0, 500L).toDF("b")
+ wA.crossJoin(wB).groupBy(col("a")).count().count()
+
+ // Build the pipeline once — we re-execute it per iteration.
+ val leftInternal: RDD[InternalRow] = left.queryExecution.toRdd
+ val rowRdd: RDD[Row] = leftInternal.mapPartitions { iter =>
+ val enc = ExpressionEncoder(leftSchema).resolveAndBind()
+ val deser = enc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+ val encodedRdd: RDD[InternalRow] = rowRdd
+ .zipWithUniqueId()
+ .map { case (row, id) => (id, row) }
+ .mapPartitionsWithIndex { case (partIdx, iter) =>
+ val interEnc = ExpressionEncoder(interStageSchema).resolveAndBind()
+ val ser = interEnc.createSerializer()
+ val rng = new Random(seed + partIdx.toLong)
+ val leftFieldCount = leftSchema.length
+ iter.map { case (leftId, leftRow) =>
+ val cols = new Array[Any](2 + leftFieldCount)
+ cols(0) = java.lang.Long.valueOf(leftId)
+ var i = 0
+ while (i < leftFieldCount) { cols(1 + i) = leftRow.get(i); i += 1 }
+ val refs = new scala.collection.mutable.ArrayBuffer[Row](k)
+ var r = 0
+ while (r < k) {
+ refs += Row(rng.nextLong() & 0x7FFFFFFFFFFFFFFFL, rng.nextFloat())
+ r += 1
+ }
+ cols(1 + leftFieldCount) = refs.toSeq
+ ser(Row.fromSeq(cols.toSeq)).copy()
+ }
+ }
+ val backToRow: RDD[Row] = encodedRdd.mapPartitions { iter =>
+ val enc = ExpressionEncoder(interStageSchema).resolveAndBind()
+ val deser = enc.createDeserializer()
+ iter.map(ir => deser(ir.copy()))
+ }
+ val df = spark.createDataFrame(backToRow, interStageSchema)
+ val shuffled = df.repartition(shuffleParts, col("_leftId"))
+ val shuffledInternal: RDD[InternalRow] = shuffled.queryExecution.toRdd
+
+ var iter = 0
+ val iterations = 100
+ var observedSum = 0L
+ while (iter < iterations) {
+ val perIterSum = shuffledInternal.mapPartitions { it =>
+ var sum = 0L
+ while (it.hasNext) {
+ val ir = it.next().copy()
+ val leftId = ir.getLong(0)
+ val lid = ir.getLong(1)
+ val qvec = ir.getArray(2)
+ var j = 0
+ var qsum = 0.0f
+ while (j < qvec.numElements()) {
+ qsum += qvec.getFloat(j)
+ j += 1
+ }
+ val refs = ir.getArray(3)
+ var r = 0
+ var refSum = 0L
+ while (r < refs.numElements()) {
+ val s = refs.getStruct(r, 2)
+ refSum += s.getLong(0)
+ r += 1
+ }
+ sum += leftId + lid + qsum.toLong + refSum
+ }
+ Iterator.single(sum)
+ }.collect().sum
+ observedSum += perIterSum
+ iter += 1
+ }
+ assertTrue(observedSum != 0L, s"Expected non-zero observed sum after $iterations iterations")
+ }
+}
diff --git a/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/testutil/ClusteredEmbeddings.scala b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/testutil/ClusteredEmbeddings.scala
new file mode 100644
index 000000000..47ccc3c6d
--- /dev/null
+++ b/lance-spark-knn_2.12/src/test/scala/org/lance/spark/knn/testutil/ClusteredEmbeddings.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.spark.knn.testutil
+
+import java.util.Random
+
+/**
+ * Generate a clustered Gaussian-mixture embedding sample as a stand-in for real production
+ * embeddings (SIFT / sentence-transformer / image features). Real embeddings are not uniform
+ * over the unit hypercube — they cluster around a small number of topic centroids with each
+ * cluster occupying a relatively narrow region of the space. Uniform-random vectors are the
+ * worst case for IVF: there's no natural cluster structure for k-means to latch onto, so the
+ * IVF partitions cover the space arbitrarily and the per-cluster recall is essentially random.
+ *
+ * Method:
+ * 1. Pick `numClusters` cluster centers, each drawn uniformly from the unit hypercube.
+ * 2. For each row, pick a cluster (round-robin so each cluster gets equal mass) and sample
+ * a Gaussian centered on it with standard deviation `sigma * cluster_separation`.
+ * 3. L2-normalize so vectors live on the unit sphere — the natural geometry for cosine /
+ * inner-product retrieval, and what most production embedding models produce.
+ *
+ * The cluster-separation factor is the median pairwise distance between centers; scaling sigma
+ * by it keeps the cluster radius proportional to inter-cluster spacing regardless of `dim` or
+ * `numClusters`. With sigma ≈ 0.15 the clusters overlap a little but stay distinguishable —
+ * a reasonable proxy for production embedding distributions.
+ *
+ * The generator is deterministic given the seed so test runs are reproducible.
+ */
+object ClusteredEmbeddings {
+
+ /**
+ * Build a clustered-Gaussian-mixture sample.
+ *
+ * @param n number of vectors to generate
+ * @param dim vector dimension
+ * @param numClusters number of cluster centers (small relative to `n` — typical 16-64)
+ * @param sigma per-cluster standard deviation, in units of inter-cluster distance.
+ * 0.05 = tight clusters (high recall floor); 0.5 = loose, near-uniform
+ * @param seed RNG seed for reproducibility
+ * @return an array of `n` float vectors of dimension `dim`, L2-normalized
+ */
+ def generate(
+ n: Int,
+ dim: Int,
+ numClusters: Int,
+ sigma: Double = 0.15,
+ seed: Long = 0L): Array[Array[Float]] = {
+ require(n > 0 && dim > 0 && numClusters > 0, "n, dim, numClusters must all be positive")
+ require(numClusters <= n, "numClusters cannot exceed n")
+ val rng = new Random(seed)
+
+ // Step 1: cluster centers, uniform on [0, 1]^dim. Stored as Doubles so the noise pass keeps
+ // numerical headroom — L2 normalization at the end folds back to Float precision.
+ val centers = Array.fill(numClusters)(Array.fill(dim)(rng.nextDouble()))
+
+ // Step 2: median pairwise distance between centers, used to scale sigma. We don't want sigma
+ // expressed in absolute distance units — the right notion is "fraction of cluster spacing,"
+ // which keeps clustering tightness behavior stable across (dim, numClusters) settings.
+ val sep = medianPairwiseDistance(centers)
+ val scaledSigma = sigma * sep
+
+ // Step 3: sample each row from a Gaussian centered on a round-robin cluster. Round-robin
+ // (rather than uniformly random cluster choice) gives every cluster the same mass — a more
+ // controlled benchmark setup than letting some clusters get sparsely populated.
+ val out = new Array[Array[Float]](n)
+ var i = 0
+ while (i < n) {
+ val center = centers(i % numClusters)
+ val v = new Array[Float](dim)
+ var d = 0
+ while (d < dim) {
+ v(d) = (center(d) + rng.nextGaussian() * scaledSigma).toFloat
+ d += 1
+ }
+ l2Normalize(v)
+ out(i) = v
+ i += 1
+ }
+ out
+ }
+
+ /**
+ * Median pairwise L2 distance between centers. We sample up to 1024 random center pairs
+ * rather than computing all `O(K^2)` of them — for `numClusters = 64` that's 2016 pairs,
+ * trivial; for larger K we'd otherwise pay cost the rest of the test doesn't need.
+ */
+ private def medianPairwiseDistance(centers: Array[Array[Double]]): Double = {
+ val k = centers.length
+ if (k < 2) return 1.0
+ val rng = new Random(0L)
+ val numPairs = math.min(1024, k * (k - 1) / 2)
+ val dists = new Array[Double](numPairs)
+ var p = 0
+ while (p < numPairs) {
+ var i = rng.nextInt(k)
+ var j = rng.nextInt(k)
+ while (j == i) j = rng.nextInt(k)
+ dists(p) = euclidean(centers(i), centers(j))
+ p += 1
+ }
+ java.util.Arrays.sort(dists)
+ dists(dists.length / 2)
+ }
+
+ private def euclidean(a: Array[Double], b: Array[Double]): Double = {
+ var s = 0.0
+ var i = 0
+ while (i < a.length) {
+ val d = a(i) - b(i)
+ s += d * d
+ i += 1
+ }
+ math.sqrt(s)
+ }
+
+ private def l2Normalize(v: Array[Float]): Unit = {
+ var s = 0.0
+ var i = 0
+ while (i < v.length) { s += v(i) * v(i); i += 1 }
+ val norm = math.sqrt(s).toFloat
+ if (norm > 0f) {
+ i = 0
+ while (i < v.length) { v(i) = v(i) / norm; i += 1 }
+ }
+ }
+}
diff --git a/lance-spark-knn_2.13/pom.xml b/lance-spark-knn_2.13/pom.xml
new file mode 100644
index 000000000..2b6d5b831
--- /dev/null
+++ b/lance-spark-knn_2.13/pom.xml
@@ -0,0 +1,140 @@
+
+
+ 4.0.0
+
+
+ org.lance
+ lance-spark-root
+ 0.4.0-beta.4
+ ../pom.xml
+
+
+ lance-spark-knn_2.13
+ ${project.artifactId}
+ Indexed nearest-neighbor join for Lance datasets in Spark
+ jar
+
+
+ ${scala213.version}
+ ${scala213.compat.version}
+
+
+
+
+ org.lance
+ lance-spark-base_2.13
+ ${project.version}
+
+
+ org.apache.spark
+ spark-sql_${scala.compat.version}
+ provided
+
+
+ org.lance
+ lance-spark-base_2.13
+ ${project.version}
+ test-jar
+ test
+
+
+
+ org.lance
+ lance-spark-3.5_2.13
+ ${project.version}
+ test
+
+
+ org.junit.jupiter
+ junit-jupiter
+ ${junit.version}
+ test
+
+
+
+
+ ../lance-spark-knn_2.12/src/main/scala
+ ../lance-spark-knn_2.12/src/test/scala
+
+
+ ../lance-spark-knn_2.12/src/test/resources
+
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+ 3.2.0
+
+
+ add-java-source
+ generate-sources
+
+ add-source
+
+
+
+ ../lance-spark-knn_2.12/src/main/java
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ ${scala-maven-plugin.version}
+
+
+ scala-compile-first
+ process-resources
+
+ add-source
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ -feature
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+ compile
+
+ compile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+
+ test-jar
+
+
+
+
+
+
+
diff --git a/pom.xml b/pom.xml
index bc569bfee..13e768a5a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -130,11 +130,13 @@
lance-spark-base_2.12
+ lance-spark-knn_2.12lance-spark-3.5_2.12lance-spark-bundle-3.5_2.12lance-spark-3.4_2.12lance-spark-bundle-3.4_2.12lance-spark-base_2.13
+ lance-spark-knn_2.13lance-spark-3.5_2.13lance-spark-bundle-3.5_2.13lance-spark-3.4_2.13
@@ -143,6 +145,7 @@
lance-spark-bundle-4.0_2.13lance-spark-4.1_2.13lance-spark-bundle-4.1_2.13
+ lance-spark-knn-4.2_2.13